手写数字识别 KNN

℡╲_俬逩灬. 提交于 2020-02-04 02:54:54
from numpy import*
import csv
import operator
from sklearn.neighbors import KNeighborsClassifier

def toInt(array):
    """Convert every cell of *array* to an integer value, returned as a float array.

    The input is first viewed as a 2-D matrix, so a 1-D input of length m
    yields a result of shape (1, m).

    Args:
        array: Array-like data, typically numpy string cells read from CSV.

    Returns:
        A new float ndarray of the same 2-D shape. Each cell holds ``int(cell)``;
        cells that ``int()`` rejects with ValueError (e.g. ``""`` or ``"x"``)
        are left as 0.
    """
    # np.mat was only an alias of np.asmatrix and was removed in NumPy 2.0;
    # asmatrix keeps the exact same 2-D promotion behaviour on every version.
    from numpy import asmatrix

    grid = asmatrix(array)
    result = zeros(grid.shape)
    # tolist() yields plain Python scalars whose int() conversion matches the
    # numpy scalars, but avoids slow per-cell matrix indexing.
    for i, row in enumerate(grid.tolist()):
        for j, cell in enumerate(row):
            try:
                result[i, j] = int(cell)
            except ValueError:
                # Deliberate best effort: unparsable cells stay 0.
                continue
    return result

def nomalizing(array):
    """Binarize *array* in place: every non-zero entry becomes 1.

    Args:
        array: A numpy array (the output of ``toInt`` in this script).

    Returns:
        The same object that was passed in, after modification.
    """
    from numpy import asarray

    # One vectorized masked assignment replaces the per-element Python loop;
    # asarray() ensures the mask is a real boolean array.
    array[asarray(array) != 0] = 1
    return array

def loadTrainData():
    """Load the labelled training set from ``train.csv``.

    The first row is discarded as a header. In each remaining row, column 0
    is the label and the remaining columns are the pixel values.

    Returns:
        (data, label): ``data`` is the pixel matrix after ``toInt`` and
        ``nomalizing`` (non-zero cells become 1). ``label`` is the label column
        after ``toInt``, i.e. a float array of shape (1, m).
    """
    # The csv module documents newline="" for file objects given to csv.reader;
    # the with-block closes the file, so no explicit close() is needed.
    with open("train.csv", newline="") as csv_file:
        reader = csv.reader(csv_file)
        next(reader, None)  # drop the header row
        rows = list(reader)
    table = array(rows)
    label = table[:, 0]
    data = table[:, 1:]
    return nomalizing(toInt(data)), toInt(label)

def loadTestData():
    """Load the unlabelled test images from ``test.csv``.

    The first row is discarded as a header. Every column of the remaining
    rows is treated as a pixel value.

    Returns:
        The pixel matrix after ``toInt`` and ``nomalizing`` (non-zero cells
        become 1).
    """
    # newline="" as documented for csv.reader; with-block handles closing.
    with open("test.csv", newline="") as csv_file:
        reader = csv.reader(csv_file)
        next(reader, None)  # drop the header row
        rows = list(reader)
    return nomalizing(toInt(array(rows)))

def loadTest_result():
    """Load the reference labels from column 1 of ``test_result.csv``.

    The first row is discarded as a header.
    NOTE(review): presumably column 0 is an image id and column 1 the
    expected digit; confirm against the data file.

    Returns:
        Column 1 after ``toInt``: a float array of shape (1, m), where m is
        the number of data rows.
    """
    # newline="" as documented for csv.reader; with-block handles closing.
    with open("test_result.csv", newline="") as csv_file:
        reader = csv.reader(csv_file)
        next(reader, None)  # drop the header row
        rows = list(reader)
    return toInt(array(rows)[:, 1])

def saveResult(result):
    """Write all predictions in *result* as a single CSV row to ``my_result.csv``.

    The file is overwritten on every call.

    Args:
        result: Any iterable of values (e.g. the predictions array).
    """
    # csv.writer already emits "\r\n" row terminators; without newline=""
    # text mode would turn them into "\r\r\n" on Windows.
    with open("my_result.csv", "w", newline="") as my_file:
        csv.writer(my_file).writerow(list(result))

def knnClassify(x_train, y_train, x_test):
    """Fit a default KNeighborsClassifier and predict labels for *x_test*.

    *y_train* is flattened with ``ravel`` before fitting, because ``toInt``
    produces labels shaped (1, m). The predictions are also written to disk
    via ``saveResult`` before they are returned.
    """
    # fit() returns the estimator itself, so fitting and predicting chain.
    predictions = KNeighborsClassifier().fit(x_train, ravel(y_train)).predict(x_test)
    saveResult(predictions)
    return predictions

def digitRecognition():
    """Train on train.csv, predict test.csv, and print accuracy against test_result.csv."""
    x_train, y_train = loadTrainData()
    x_test = loadTestData()
    predict = knnClassify(x_train, y_train, x_test)
    y_test = loadTest_result()

    total, _ = shape(x_test)
    # y_test has shape (1, m), so the reference label for sample k is y_test[0, k].
    wrong = sum(1 for k in range(total) if predict[k] != y_test[0, k])

    print("wrong = %d" % wrong)                                       # sample output noted in the original: 819
    print("right rate = %f%%" % (100.0 * (total - wrong) / float(total)))  # noted in the original: 97.075%

if __name__ == "__main__":
    # Run the full load / train / predict / score pipeline only when executed as a script.
    digitRecognition()
标签
易学教程内所有资源均来自网络或用户发布的内容，如有违反法律规定的内容，欢迎反馈。
该文章没有解决你所遇到的问题？点击提问，说说你的问题，让更多的人一起探讨吧！