kNN算法python实现和简单数字识别的方法_python

本文实例讲述了kNN算法python实现和简单数字识别的方法。分享给大家供大家参考。具体如下:

kNN算法算法优缺点:

优点:精度高、对异常值不敏感、无输入数据假定
缺点:时间复杂度和空间复杂度都很高
适用数据范围:数值型和标称型

算法的思路:

KNN算法(全称K最近邻算法),算法的思想很简单,简单的说就是物以类聚,也就是说我们从一堆已知的训练集中找出k个与目标最靠近的,然后看他们中最多的分类是哪个,就以这个为依据分类。

函数解析:

库函数:

tile()
如tile(A,n)就是将A重复n次

复制代码 代码如下:

a = np.array([0, 1, 2])
np.tile(a, 2)
array([0, 1, 2, 0, 1, 2])
np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],[0, 1, 2, 0, 1, 2]])
np.tile(a, (2, 1, 2))
array([[[0, 1, 2, 0, 1, 2]],[[0, 1, 2, 0, 1, 2]]])
b = np.array([[1, 2], [3, 4]])
np.tile(b, 2)
array([[1, 2, 1, 2],[3, 4, 3, 4]])
np.tile(b, (2, 1))
array([[1, 2],[3, 4],[1, 2],[3, 4]])`

自己实现的函数

createDataSet()生成测试数组
kNNclassify(inputX, dataSet, labels, k)分类函数

inputX 输入的参数
dataSet 训练集
labels 训练集的标号
k 最近邻的数目

复制代码 代码如下:

#coding=utf-8
from numpy import *
import operator

def createDataSet():
    group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A','A','B','B']
    return group,labels
#inputX表示输入向量(也就是我们要判断它属于哪一类的)
#dataSet表示训练样本
#label表示训练样本的标签
#k是最近邻的参数,选最近k个
def kNNclassify(inputX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]#计算有几个训练数据
    #开始计算欧几里得距离
    diffMat = tile(inputX, (dataSetSize,1)) - dataSet
   
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
    distances = sqDistances ** 0.5
    #欧几里得距离计算完毕
    sortedDistance = distances.argsort()
    classCount = {}
    for i in xrange(k):
        voteLabel = labels[sortedDistance[i]]
        classCount[voteLabel] = classCount.get(voteLabel,0) + 1
    res = max(classCount)
    return res

def main():
    group,labels = createDataSet()
    t = kNNclassify([0,0],group,labels,3)
    print t
   
if __name__=='__main__':
    main()

kNN应用实例

手写识别系统的实现

数据集:

两个数据集:training和test。分类的标号在文件名中。像素32*32的。数据大概这个样子:

方法:

kNN的使用,不过这个距离算起来比较复杂(1024个特征),主要是要处理如何读取数据这个问题的,比较方面直接调用就可以了。

速度:

速度还是比较慢的,这里数据集是:training 2000+,test 900+(i5的CPU)

k=3的时候要32s+

复制代码 代码如下:

#coding=utf-8
from numpy import *
import operator
import os
import time

def createDataSet():
    group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A','A','B','B']
    return group,labels
#inputX表示输入向量(也就是我们要判断它属于哪一类的)
#dataSet表示训练样本
#label表示训练样本的标签
#k是最近邻的参数,选最近k个
def kNNclassify(inputX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]#计算有几个训练数据
    #开始计算欧几里得距离
    diffMat = tile(inputX, (dataSetSize,1)) - dataSet
    #diffMat = inputX.repeat(dataSetSize, aixs=1) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
    distances = sqDistances ** 0.5
    #欧几里得距离计算完毕
    sortedDistance = distances.argsort()
    classCount = {}
    for i in xrange(k):
        voteLabel = labels[sortedDistance[i]]
        classCount[voteLabel] = classCount.get(voteLabel,0) + 1
    res = max(classCount)
    return res

def img2vec(filename):
    returnVec = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVec[0,32*i+j] = int(lineStr[j])
    return returnVec
   
def handwritingClassTest(trainingFloder,testFloder,K):
    hwLabels = []
    trainingFileList = os.listdir(trainingFloder)
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileName = trainingFileList[i]
        fileStr = fileName.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vec(trainingFloder+'/'+fileName)
    testFileList = os.listdir(testFloder)
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileName = testFileList[i]
        fileStr = fileName.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vec(testFloder+'/'+fileName)
        classifierResult = kNNclassify(vectorUnderTest, trainingMat, hwLabels, K)
        #print classifierResult,' ',classNumStr
        if classifierResult != classNumStr:
            errorCount +=1
    print 'tatal error ',errorCount
    print 'error rate',errorCount/mTest
       
def main():
    t1 = time.clock()
    handwritingClassTest('trainingDigits','testDigits',3)
    t2 = time.clock()
    print 'execute ',t2-t1
if __name__=='__main__':
    main()

希望本文所述对大家的Python程序设计有所帮助。

时间: 2024-11-01 10:57:54

kNN算法python实现和简单数字识别的方法_python的相关文章

python实现超简单端口转发的方法_python

本文实例讲述了python实现超简单端口转发的方法.分享给大家供大家参考.具体如下: 代码非常简单,实现了简单的端口数据转发功能,用于真实环境还需要再修改一下. 复制代码 代码如下: #tcp server import socket host = '127.0.0.1'          #Local Server IP host2 = '127.0.0.1'   #Real Server IP port = 6001 #Local Server Port port2 = 7001 #Real

python实现简单温度转换的方法_python

本文实例讲述了python实现简单温度转换的方法.分享给大家供大家参考.具体分析如下: 这是一段简单的python代码,用户转换不同单位的温度,适合初学者参考 复制代码 代码如下: def c2f(t):     return (t*9/5.0)+32 def c2k(t):     return t+273.15 def f2c(t):     return (t-32)*5.0/9 def f2k(t):     return (t+459.67)*5.0/9 def k2c(t):    

python变量不能以数字打头详解_python

在编写python函数时,无意中发现一个问题:python中的变量不能以数字打头,以下函数中定义了一个变量3_num_varchar,执行时报错. 函数如下: def database_feild_varchar_trans(in_feild): ''' transfer the feild if varchar then 3times lang else no transfer ''' feild_split = in_feild.split(' ') is_varchar = feild_s

python根据出生年份简单计算生肖的方法_python

本文实例讲述了python根据出生年份简单计算生肖的方法.分享给大家供大家参考.具体分析如下: 这里使用python根据出生年份计算生肖,看了代码会发现原来这么简单 #计算生肖 def ChineseZodiac(year): return u'猴鸡狗猪鼠牛虎兔龙蛇马羊'[year%12] ChineseZodiac(1990) 希望本文所述对大家的Python程序设计有所帮助. 以上是小编为您精心准备的的内容,在的博客.问答.公众号.人物.课程等栏目也有的相关内容,欢迎继续使用右上角搜索按钮进

python通过线程实现定时器timer的方法_python

本文实例讲述了python通过线程实现定时器timer的方法.分享给大家供大家参考.具体分析如下: 这个python类实现了一个定时器效果,调用非常简单,可以让系统定时执行指定的函数 下面介绍以threading模块来实现定时器的方法. 使用前先做一个简单试验: import threading def sayhello(): print "hello world" global t #Notice: use global variable! t = threading.Timer(5

python使用chardet判断字符串编码的方法_python

本文实例讲述了python使用chardet判断字符串编码的方法.分享给大家供大家参考.具体分析如下: 最近利用python抓取一些网上的数据,遇到了编码的问题.非常头痛,总结一下用到的解决方案. linux中vim下查看文件编码的命令 set fileencoding python中一个强力的编码检测包 chardet ,使用方法非常简单.linux下利用pip install chardet实现简单安装 import chardet f = open('file','r') fencodin

python批量生成本地ip地址的方法_python

本文实例讲述了python批量生成本地ip地址的方法.分享给大家供大家参考.具体分析如下: 这段代码用于在本地计算机上生成本地ip地址绑定到网卡,生成的是一个bat的批处理文件,运行此批处理文件,可以通过ipconfig查看 #!/usr/bin/python2.7 # -*- coding: utf-8 -*- # Filename: AddIPAliases.py import re,sys,socket,struct # 1. 判断IP地址是否合法: 2. 判断用户输入的IP是否在Clas

Python读取图片属性信息的实现方法_python

本文是利用Python脚本读取图片信息,有几个说明如下:      1.没有实现错误处理      2.没有读取所有信息,大概只有 GPS 信息.图片分辨率.图片像素.设备商.拍摄设备等      3.简单修改后应该能实现暴力修改图片的 GPS 信息      4.但对于本身没有 GPS 信息的图片,实现则非常复杂,需要仔细计算每个描述符的偏移量 脚本运行后,读取结果如下 脚本读取的信息 这里和 Windows 属性查看器读到的内容完全一致 图片信息1 图片信息2 源码如下 # -*- codi

Python读取mp3中ID3信息的方法_python

本文实例讲述了Python读取mp3中ID3信息的方法.分享给大家供大家参考.具体分析如下: pyid3不好用,常常有不认识的. mutagen不错,不过默认带的easyid3不会读取注释,需要手工hack一下 Python代码如下: from mutagen.mp3 import MP3 import mutagen.id3 from mutagen.easyid3 import EasyID3 EasyID3.valid_keys["comment"]="COMM::'X