机器学习笔记--鸢尾花分类(二)

狂风中的少年 提交于 2019-12-05 07:14:47

· 训练和测试数据

要验证模型是否成功,通常会把收集好带标签的数据分成两部分,一部分用来构建机器学习模型,叫做训练数据(training data),其余的用来测试,叫做测试数据(test data)。scikit-learn 中的 train_test_split 函数一般会把75%的数据作为训练集,25%的数据作为测试集。 根据train_test_split对数据分类:

from sklearn.model_selection import train_test_split 
X_train, X_test, y_train, y_test = train_test_split(     
    iris_dataset['data'], iris_dataset['target'], random_state=0)
#random_state = 0 是他的随机种子
print("X_train shape: {}".format(X_train.shape)) 
print("y_train shape: {}".format(y_train.shape))
print("X_test shape: {}".format(X_test.shape)) 
print("y_test shape: {}".format(y_test.shape))

得到结果:

X_train shape: (112, 4) 
y_train shape: (112,)
X_test shape: (38, 4) 
y_test shape: (38,)

说明训练集输入的是一个112*4的二维数组,得到的是一个长度112的一维数组;测试集输入的是一个38*4的二维数组,得到的是一个长度38的一维数组。

· 观察数据

沿用上面的代码,我们用pandas里一个绘制散点图矩阵的函数,叫作scatter_matrix绘制一下散点图:

import mglearn  
import pandas as pd 
from sklearn.model_selection import train_test_split 
from sklearn.datasets import load_iris 
iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(     
    iris_dataset['data'], iris_dataset['target'], random_state=0)

iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)

grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o',                        
                        hist_kwds={'bins': 20}, s=60, alpha=.8, cmap=mglearn.cm3)

可以得到下图:

可以看出根据任意两两特征基本都可以把这三个类别区分开来,说明机器学习模型很可能是可以被学会的。

· KNN算法

KNN算法总结起来就是保存训练集,然后有一个新点加入时寻找与他最近的k个点,然后根据这些邻居中数量最多的类别进行判断,这里我们设k为1。

import numpy as np
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=1)  #根据最近的一个点判断
knn.fit(X_train, y_train)  #用训练组建模
y_pred = knn.predict(X_test)  #用测试组得到预测的数据
print("Test set predictions:\n {}".format(y_pred))  #打印预测的数据
print("Test set score: {:.2f}".format(np.mean(y_pred == y_test)))  #和我们的测试集的结果比较

得到结果:

Test set predictions:
 [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
 2]
Test set score: 0.97

对于这个模型来说,测试集的精度约为 0.97,比较能够接受了。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!