本文共 5378 字,大约阅读时间需要 17 分钟。
KNN(K-近邻)算法是一种懒惰学习(lazy learning)算法:没有显式的训练过程,预测时直接使用已有标注的训练数据(分类学习),根据距离最近的 K 个样本的多数类别给出预测结果。
适用于多分类的学习任务。
"""k-nearest-neighbours (kNN) demo: dating-site data and handwritten digits.

kNN is a lazy learner: there is no training step, every prediction is made
directly from the labelled training samples.
"""
# The star import is kept only so that names historically re-exported by this
# module keep resolving for existing callers; the code below uses ``np.``.
from numpy import *  # noqa: F401,F403
import numpy as np
import operator  # noqa: F401  (kept: file-level import)
import pdb  # noqa: F401  (kept: file-level import)
from collections import Counter
from os import listdir

# matplotlib is needed by data_plt() only; keep the classifiers importable on
# machines where it is not installed.
try:
    import matplotlib
    import matplotlib.pyplot as plt
except ImportError:
    matplotlib = None
    plt = None


def createDataSet():
    """Return a tiny 2-D toy data set as ``(group, labels)``."""
    group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Args:
        inX: the point to classify (1-D sequence, or an array of shape (1, d)).
        dataSet: array with one training sample per row.
        labels: label of each row of ``dataSet``.
        k: number of neighbours that vote.

    Returns:
        The most frequent label among the ``k`` nearest samples (Euclidean
        distance). On a tie the label met first in distance order wins.

    Raises:
        ValueError: if ``k`` is not between 1 and the number of samples.
    """
    samples = np.asarray(dataSet)
    dataSetSize = samples.shape[0]
    if not 1 <= k <= dataSetSize:
        raise ValueError(
            'k must be between 1 and the number of samples ({}), got {}'
            .format(dataSetSize, k))

    # Broadcasting replaces tile(inX, (dataSetSize, 1)) - dataSet.
    diffMat = np.asarray(inX) - samples
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    nearest = distances.argsort()[:k]  # indices of the k closest samples

    votes = Counter(labels[index] for index in nearest)
    return votes.most_common(1)[0][0]


def file2matrix(filename):
    """Load a tab-separated data file into a feature matrix and a label list.

    Every non-blank line holds three numeric features followed by an integer
    class label in the last column.

    Returns:
        ``(returnMat, classLabelVector)``: an (n, 3) float array and a list of
        n ints.
    """
    with open(filename) as fr:
        # Blank lines (e.g. a trailing newline) carry no sample; skip them.
        rows = [line.strip().split('\t') for line in fr if line.strip()]

    returnMat = np.zeros((len(rows), 3))
    classLabelVector = []
    for index, fields in enumerate(rows):
        returnMat[index, :] = [float(value) for value in fields[0:3]]
        classLabelVector.append(int(fields[-1]))
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """Min-max normalise every column of ``dataSet`` to the range [0, 1].

    Returns:
        ``(normDataSet, ranges, minVals)`` where ``ranges`` and ``minVals``
        are the per-column range and minimum, so that new points can be
        scaled with ``(x - minVals) / ranges``.

    A constant column has a range of 0; its range is reported as 1 so the
    column normalises to 0 instead of producing NaN from a division by zero.
    """
    data = np.asarray(dataSet)
    minVals = data.min(0)  # axis 0: minimum of each column
    maxVals = data.max(0)
    ranges = maxVals - minVals
    ranges = np.where(ranges == 0, 1, ranges)
    normDataSet = (data - minVals) / ranges
    return normDataSet, ranges, minVals


def data_plt(datingDataMat, datingLabels):
    """Scatter-plot columns 1 and 2, with marker size and colour set by label.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError('matplotlib is required for data_plt')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    scaled = 15.0 * np.array(datingLabels)
    ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], scaled, scaled)
    plt.show()


def datingClassText(datingDataMat, datingLabels, hoRatio=0.10, k=3):
    """Hold-out test of the classifier on the dating-site data.

    The first ``hoRatio`` fraction of the rows is used as the test set and
    the remaining rows as the training set. Prints the prediction and the
    true label for every test row, then the error rate.

    Returns:
        The error rate as a float in [0, 1].

    Raises:
        ValueError: if the hold-out set would be empty.
    """
    m = datingDataMat.shape[0]
    numTestVecs = int(m * hoRatio)  # size of the test set
    if numTestVecs == 0:
        raise ValueError(
            'hold-out set is empty: {} rows with hoRatio={}'.format(m, hoRatio))

    errorCount = 0.0
    for i in range(numTestVecs):
        # classify0(point to predict, training data, training labels, k)
        classifierResult = classify0(datingDataMat[i, :],
                                     datingDataMat[numTestVecs:m, :],
                                     datingLabels[numTestVecs:m], k)
        print(classifierResult, datingLabels[i])
        if classifierResult != datingLabels[i]:
            errorCount += 1.0

    errorRate = errorCount / float(numTestVecs)
    print(errorRate)
    return errorRate


def classifyPerson(datingDataMat, datingLabels, ranges, minVals):
    """Interactively classify one person from three values read from stdin.

    ``datingDataMat`` must already be normalised; ``ranges`` and ``minVals``
    are the values returned by ``autoNorm`` and are used to scale the input
    the same way.
    """
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(input("每年玩游戏的时间"))
    ffMiles = float(input("每年飞行公里数"))
    iceCream = float(input("每年冰激凌消耗数"))
    inArr = np.array([ffMiles, percentTats, iceCream])
    classifierResult = classify0((inArr - minVals) / ranges,
                                 datingDataMat, datingLabels, 3)
    # Labels are 1-based, resultList is 0-based.
    print('the result is :', resultList[classifierResult - 1])


# Handwriting recognition: trainingDigits holds the training samples and
# testDigits the test samples.

def img2vector(filename):
    """Read one 32x32 text image of 0/1 characters into a (1, 1024) array."""
    returnVect = np.zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            returnVect[0, 32 * i:32 * (i + 1)] = [int(ch) for ch in lineStr[:32]]
    return returnVect


def _digit_from_filename(fileNameStr):
    """Return the digit encoded in a file name such as ``"0_7.txt"`` -> 0."""
    fileStr = fileNameStr.split('.')[0]  # drop the extension
    return int(fileStr.split('_')[0])


def handwritingClassTest(traindir, testdir, k=3):
    """Classify every image in ``testdir`` using the images in ``traindir``.

    File names look like ``"<digit>_<n>.txt"``; the digit is the true label.
    Prints the error rate.

    Returns:
        The error rate as a float in [0, 1].

    Raises:
        ValueError: if ``testdir`` contains no files.
    """
    # Build the training matrix: one 1024-element row per image.
    trainingFileList = listdir(traindir)
    m = len(trainingFileList)
    hwLabels = []
    trainingMat = np.zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_digit_from_filename(fileNameStr))
        path = '{}/{}'.format(traindir, fileNameStr)
        # Write row i only; ``trainingMat[i:]`` would overwrite every
        # remaining row on each iteration.
        trainingMat[i, :] = img2vector(path)

    # Classify each test image and count the mistakes.
    testFileList = listdir(testdir)
    mTest = len(testFileList)
    if mTest == 0:
        raise ValueError('no test files found in {!r}'.format(testdir))

    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = _digit_from_filename(fileNameStr)
        vectorUnderTest = img2vector('{}/{}'.format(testdir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        if classifierResult != classNumStr:
            errorCount += 1
            # NOTE(review): the original indentation was lost; this is assumed
            # to report misclassified samples only - confirm.
            print(" {}, {}".format(classifierResult, classNumStr))

    errorRate = errorCount / float(mTest)
    print('the error rate is:', errorRate)
    return errorRate


if __name__ == '__main__':
    # Dating-site demo (requires datingTestSet2.txt): load the data with
    # file2matrix(), scale it with autoNorm(), then call datingClassText()
    # or classifyPerson().
    path_train = 'digits/trainingDigits'
    path_test = 'digits/testDigits'
    handwritingClassTest(path_train, path_test)
转载地址:http://aenvb.baihongyu.com/