《机器学习实战》学习笔记之k-近邻算法3

2.3 手写识别系统

从os模块中导入listdir函数,用来读取给定目录中的文件名

from os import listdir
关于zeros函数的使用,


代码及注释

#image convert to vector
def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr=fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('digits/trainingDigits')
    m = len(trainingFileList)#训练样本的个数
    trainingMat = zeros((m,1024))#创建训练矩阵,每行有1024个元素,表示一个训练样本
    for i in range(m):
        fileNameStr = trainingFileList[i]#第i个训练样本
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])#样本命名的第一个数字表示实际的分类
        hwLabels.append(classNumStr)#得到训练集的所有分类
        trainingMat[i,:] = img2vector('digits/trainingDigits/%s'%fileNameStr)#将所有样本转换成矩阵,得到训练样本集
    testFileList = listdir('digits/testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr =fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])#同样方法得到一个测试样本的分类
        vectorUnderTest = img2vector('digits/testDigits/%s'%fileNameStr)#将一个测试样本转成矩阵
        classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)#执行分类
        print "the classifier came back with:%d, the real answer is:%d"%(classifierResult,classNumStr)
        if(classifierResult != classNumStr):errorCount+=1.0
    print "\nthe total number of errors is:%d"%errorCount
    print "\nthe total error rate is:%f"%(errorCount/float(mTest))

终端结果

k-近邻算法总结:摘自《机器学习实战》

简单的说,该算法采用测量不同特征值之间的距离方法进行分类,缺陷:

1. 必须保存全部数据集,如果训练数据集很大,必须使用大量的存储空间

2.必须对每个数据计算距离值,耗时大

3.无法给出任何数据的基础结构信息,无法知晓平均实例样本和典型实例样本具有什么特征