Python機器學習筆記 K-近鄰演算法
一,本文概述
眾所周知,電影可以按照題材分類,然而題材本身是如何定義的?由誰來判斷某部電影屬於哪個題材?也就是說同一題材的電影具有哪些公共特徵?這些都是在進行電影分類時必須要考慮的問題。沒有哪個電影人會說自己製作的電影和以前的某部電影類似,但是我們確實知道每部電影在風格上的確有可能會和同題材的電影相近。那麼動作片具有哪些共有特徵,使得動作片之間非常類似,而與愛情片存在著明顯的差距呢?動作片中也會存在接吻鏡頭,愛情片中也會存在打鬥場景,我們不能單純依靠是否存在打鬥或者親吻來判斷影片的型別,但是愛情片中的接吻鏡頭更多,動作片中的打鬥場景也更為頻繁,基於此類場景在某部電影中出現的次數可以用來進行電影分類。本文基於電影中出現的親吻,打鬥出現的次數,使用 K-近鄰演算法構造程式,自動劃分電影的題材型別。我們首先使用電影分類講解 K-近鄰演算法的基礎概念,然後學習如何在其他系統上使用 K-近鄰演算法。
那麼本次首先探討 K-近鄰演算法的基礎理論,以及如何使用距離測量的方法分類物品;其次我們使用Python從文字檔案中匯入並解析資料,再次,本文學習了當存在許多資料來源時,如何避免計算距離時可能會碰到的一些錯誤,最後利用實際的例子講解如何使用 K-近鄰演算法改進約會網站和手寫數字識別系統,而且實戰中將使用自己編寫程式碼和使用sklearn兩種方式編寫。
二,K-近鄰演算法概述
簡單來說 K-近鄰演算法採用測量不同特徵值之間的距離方法進行分類。
K最近鄰(k-Nearest Neighbor,KNN),是一種常用於分類的演算法,是有成熟理論支撐的、較為簡單的經典機器學習演算法之一。該方法的基本思路是:如果一個待分類樣本在特徵空間中的k個最相似(即特徵空間中K近鄰)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別,即近朱者赤,近墨者黑。顯然,對當前待分類樣本的分類,需要大量已知分類的樣本的支援,其中k通常是不大於20的整數。KNN演算法中,所選擇的鄰居都是已經正確分類的物件。該方法在定類決策上只依據最近鄰的一個或者幾個樣本的類別來決定待分樣本所屬的類別。因此KNN是一種有監督學習演算法。
最簡單最初級的分類器時將全部的訓練資料所對應的類別都記錄下來,當測試物件的屬性和某個訓練物件的屬性完全匹配時,便可以對其進行分類,但是怎麼可能所有測試物件都會找到與之完全匹配的訓練物件呢,其次就是存在一個測試物件同時與多個訓練物件匹配,導致一個訓練物件被分到了多個類的問題,基於這些問題呢,就產生了KNN。
下面通過一個簡單的例子說明一下:如下圖,綠色圓要被決定賦予哪個類,是紅色三角形?還是藍色四方形?如果K=3,由於紅色三角形所佔比例為2/3,綠色圓將被賦予紅色三角形所屬的類,如果K = 5 ,由於藍色四邊形比例為3/5,因此綠色圓被賦予藍色四邊形類。
由此也說明了KNN演算法的結果很大程式取決於K的選擇。
在KNN中,通過計算物件間距離來作為各個物件之間的非相似性指標,避免了物件之間的匹配問題,在這裡距離一般使用歐式距離或者曼哈頓距離:
同時,KNN通過依據k個物件中佔優的類別進行決策,而不是單一的物件類別決策。這兩點算是KNN演算法的優勢。
接下來對KNN演算法的思想總結一下:就是在訓練集中資料和標籤已知的情況下,輸入測試資料,將測試資料的特徵與訓練集中對應的特徵進行相互比較,找到訓練集中與之最為相似的前K個數據,則該冊數資料對應的類別就是K個數據中出現次數最多的那個分類,其演算法的描述如下:
(1):計算測試資料與各個訓練資料之間的距離
(2):按照距離的遞增關係進行排序
(3):選取距離最小的K個點
(4):確定前K個點所在類別的出現頻率
(5):返回前K個點中出現頻率最高的類別作為測試資料的預測分類
三,k-近鄰演算法的優缺點
優點
- 簡單,易於理解,易於實現,無需引數估計,無需訓練
- 對異常值不敏感(個別噪音資料對結果的影響不是很大)
- 適合對稀有事件進行分類
- 適合於多分類問題(multi-modal,物件具有多個類別標籤),KNN要比SVM表現好
缺點
- 對測試樣本分類時的計算量大,記憶體開銷大,因為對每個待分類的文字都要計算它到全體已知樣本的距離,才能求得它的K個最近鄰點,目前常用的解決方法是對已知的樣本點進行剪輯,事先要去除對分類作用不大的樣本
- 可解析性差,無法告訴你哪個變數更重要,無法給出決策樹那樣的規則
- k值的選擇:最大的缺點是當樣本不平衡時,如一個類的樣本容量很大,而其他樣本容量很小時候,有可能導致當輸入一個新樣本時,該樣本的K個鄰居中大容量類的樣本佔多數。該演算法只計算“最近的”鄰居樣本,某一類的樣本數量很大的時候,那麼或者這類樣本並不接近目標樣本,或者這類樣本很靠近目標樣本。無論如何,數量並不影響執行結果,可以採用權值的方法(和該樣本距離小的鄰居權重大)來改進
- KNN是一種消極學習方法,懶惰演算法
四,舉例說明k-近鄰演算法的過程
存在一個樣本資料集合,也稱作訓練樣本集,並且樣本集中每個資料都存在標籤,即我們知道樣本集中每一資料與所屬分類的對應關係。輸入每一標籤的新資料後,將新資料的每個特徵與樣本集中資料對應的特徵進行比較,然後演算法提取樣本集中特徵最相似資料(最近鄰)的分類標籤。一般來說,我們只選擇樣本資料集中前k個最相似的資料,這就是 K-近鄰演算法中k的出處,通常k是不大於20的整數。最後,選擇k個最相似資料中出現次數最多的分類,作為新資料的分類。
現在我們回到前面電影分類的例子,使用 K-近鄰演算法分類愛情片和動作片。有人曾經統計過很多電影的打鬥鏡頭和接吻鏡頭。假設有一部未看過的電影,如何確定它是愛情片還是動作片呢?我們可以使用KNN來解決這個問題。
首先我們需要知道這個未知電影存在多少個打鬥鏡頭和接吻鏡頭,具體數字參見表2-1。
表1.1 就是我們已有的資料集合,也就是訓練樣本集。這個資料集有兩個特徵,即打鬥鏡頭數和接吻鏡頭數。除此之外,我們也知道每個電影的所屬型別,即分類標籤。用肉眼粗略地觀察,接吻鏡頭多的,是愛情片。打鬥鏡頭多的,是動作片。以我們多年的看片經驗,這個分類還算合理。如果現在給我一部電影,你告訴我這個電影打鬥鏡頭數和接吻鏡頭數。不告訴我這個電影型別,我可以根據你給我的資訊進行判斷,這個電影是屬於愛情片還是動作片。而k-近鄰演算法也可以像我們人一樣做到這一點,不同的地方在於,我們的經驗更"牛逼",而k-近鄰演算法是靠已有的資料。比如,你告訴我這個電影打鬥鏡頭數為2,接吻鏡頭數為102,我的經驗會告訴你這個是愛情片,k-近鄰演算法也會告訴你這個是愛情片。你又告訴我另一個電影打鬥鏡頭數為49,接吻鏡頭數為51,我"邪惡"的經驗可能會告訴你,這有可能是個"愛情動作片",但是k-近鄰演算法不會告訴你這些,因為在它的眼裡,電影型別只有愛情片和動作片,它會提取樣本集中特徵最相似資料(最鄰近)的分類標籤,得到的結果可能是愛情片,也可能是動作片,但絕不會是"愛情動作片"。當然,這些取決於資料集的大小以及最近鄰的判斷標準等因素。
4.1 距離度量
我們已經知道k-近鄰演算法根據特徵比較,然後提取樣本集中特徵最相似資料(最鄰近)的分類標籤。那麼如何進行比較呢?比如我們以下表為例,怎麼判斷紅色圓點標記的電影所屬類別呢?
我們從散點圖大致推斷,這個紅色圓點標記的電影可能屬於動作片,因為距離已知的那兩個動作片的圓點更近,那麼k-近鄰演算法用什麼方法進行判斷呢?沒錯,就是距離度量,這個電影分類的例子有兩個特徵,也就是二維實數向量空間,可以使用我們學過的兩點距離公式計算距離(歐式距離),如下圖:
通過計算,我們可以得到如下結果:
- (101,20)->動作片(108,5)的距離約為16.55
- (101,20)->動作片(115,8)的距離約為18.44
- (101,20)->愛情片(5,89)的距離約為118.22
- (101,20)->愛情片(1,101)的距離約為128.69
通過計算可知,紅色圓點標記的電影到動作片 (108,5)的距離最近,為16.55。如果演算法直接根據這個結果,判斷該紅色圓點標記的電影為動作片,這個演算法就是最近鄰演算法,而非k-近鄰演算法。那麼k-近鄰演算法是什麼呢?下面我們學習一下k-近鄰演算法步驟
五,k-近鄰演算法的流程步驟
(1)收集資料:可以使用任何方法。包括爬蟲,或者第三方提供的免費或收費資料
(2)準備資料:距離計算所需要的數值,最好是結構化的資料格式(計算測試資料與各個訓練資料之間的距離)
(3)分析資料:可以使用任何方法。此處使用Python解析,預處理資料
(4)訓練演算法:此步驟不適用於k-近鄰演算法
(5)測試演算法:計算錯誤率
(6)使用演算法:首先需要輸入樣本資料和結構化的輸出結果,然後執行k-近鄰演算法判斷輸入資料分別屬於哪個類別,最後應用對計算出的分類執行後續的處理
比如我這裡取k值為3,那麼在電影例子中,按照距離依次排序的三個點分別是動作片(108,5)、動作片(115,8)、愛情片(5,89)。在這三個點中,動作片出現的頻率為三分之二,愛情片出現的頻率為三分之一,所以該紅色圓點標記的電影為動作片。這個判別過程就是k-近鄰演算法
六,Python程式碼實現k-近鄰演算法
我們已經知道了k-近鄰演算法的原理,下面使用Python實現演算法,前面學習了Python匯入資料,後面使用約會網站配對效果判斷的例子進行學習。
6.1 準備:使用Python匯入資料
首先,匯入資料
#_*_coding:utf-8_*_ from numpy import * # 獲取資料 import operator def creataDataSet(): ''' labels中A代表愛情片B代表動作片 :return: ''' group = array([[1,101],[5,89],[108,5],[115,8]]) labels = ['A','A','B','B'] return group,labels if __name__ == '__main__': group,labels = creataDataSet() print(group) print(labels)
測試輸出結果如下:
[[1 101] [589] [1085] [1158]] ['A', 'A', 'B', 'B']
6.2 使用k-近鄰演算法解析資料
首先我們給出了k-近鄰演算法的虛擬碼和實際的Python程式碼,然後詳細的解釋其含義,該函式的功能是使用k-近鄰演算法將每組資料劃分到某個類中,其虛擬碼如下:
對未知類別屬性的資料集中的每個點依次執行以下操作: (1) 計算已知類別資料集中的點與當前點之間的距離 (2) 按照距離遞增次序排序 (3) 選取與當前點距離最小的k個點 (4) 確定前k個點所在類別的出現頻率 (5) 返回前k個點出現頻率最高的類別作為當前點的預測分類
其Python函式classify0()的程式程式碼如下:
# k-近鄰演算法 def classify0(inX,dataSet,labels,k): # shape讀取資料矩陣第一維度的長度 dataSetSize = dataSet.shape[0] # tile重複陣列inX,有dataSet行 1個dataSet列,減法計算差值 diffMat = tile(inX,(dataSetSize,1)) - dataSet # **是冪運算的意思,這裡用的歐式距離 sqDiffMat = diffMat ** 2 # 普通sum預設引數為axis=0為普通相加,axis=1為一行的行向量相加 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances ** 0.5 # argsort返回數值從小到大的索引值(陣列索引0,1,2,3) sortedDistIndicies = distances.argsort() # 選擇距離最小的k個點 classCount = {} for i in range(k): # 根據排序結果的索引值返回靠近的前k個標籤 voteLabel = labels[sortedDistIndicies[i]] # 各個標籤出現頻率 classCount[voteLabel] = classCount.get(voteLabel,0) +1 ##!!!!!classCount.iteritems()修改為classCount.items() #sorted(iterable, cmp=None, key=None, reverse=False) --> new sorted list。 # reverse預設升序 key關鍵字排序itemgetter(1)按照第一維度排序(0,1,2,3) sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1),reverse=True) return sortedClassCount[0][0]
classify0()函式有4個輸入引數:用於分類的輸入向量是inX,輸入的訓練樣本集為dataSet,標籤向量為labels,最後的引數k表示用於選擇最近鄰居的數目,其中標籤向量的元素數目和矩陣dataSet的行數相同,上面程式使用歐式距離公式。
計算完所有點之間的距離後,可以對資料按照從小到大的次序排序。然後,確定前k個距離最小元素所在的主要分類,輸入k總是正整數;最後,將classCount字典分解為元組列表,然後使用程式第二行匯入運算子模組的itemgetter方法,按照第二個元素的次序對元組進行排序。此處的排序是逆序,即按照從最大到最小次序排序,最後返回發生頻率最高的元素標籤。
為了預測資料所在分類,我們輸入如下程式碼:
if __name__ == '__main__': group,labels = creataDataSet() knn = classify0([0,0],group,labels,3) print(knn) ''' B '''
輸出結果應該是B,也就是動作片。
到目前為止,我們已經構造了第一個分類器,使用這個分類器可以完成很多分類任務。從這個例項出發,構造使用分類演算法將會更加容易。
6.3 如何測試分類器
上面我們已經使用k-近鄰演算法構造了第一個分類器,也可以檢驗分類器給出的答案是否符合我們的預期,大家可能會問:“分類器何種情況下會出錯?”或者“答案是否總是正確的?” 答案是否定的,分類器並不會得到百分百正確的結果,我們可以使用多種方法檢測分類器的正確率。此外分類器的效能也受到多種因素的影響,如分類器設定和資料集等。不同的演算法在不同資料集上的表現可能完全不同。
為了測試分類器的效果,我們可以使用已知答案的資料,當然答案不能告訴分類器,檢驗分類器給出的結果是否符合預期結果。通過大量的測試資料,我們可以得到分類器的錯誤率——分類器給出的錯誤結果的次數除了測試執行的總數。錯誤率是常用的評估方法,主要用於評估分類器在某個資料集上的執行效果。完美分類器的錯誤率為0,最差分類器的錯誤率是1.0,在這種情況下,分類器根本就無法找到一個正確答案。所以我們不難發現,k-近鄰演算法沒有進行資料的訓練,直接使用未知的資料與已知的資料進行比較,得到結果。因此,可以說k-近鄰演算法不具有顯式公式。
上面介紹的例子可以正確運轉,但是並沒有太大的實際用處,下面我們將在現實世界中使用k-近鄰演算法。首先,我們將使用k-近鄰演算法改進約會網站的效果,然後使用k-近鄰演算法改進手寫識別系統。
6.4 常見問題
1,K值設定為多大?
K太小,分類結果易受噪聲點影響;K太大,近鄰中又可能包含太多的其他類別的點。(對距離加權,可以降低K值設定的影響)
K值通常是採用交叉檢驗來確定
經驗規則:K一般低於訓練樣本數的平方根
2,類別如何判斷最合適?
投票法沒有考慮近鄰的距離遠近,距離更近的近鄰也許更應該決定最終的分類,所以加權投票法更加恰當一些。
3,如何選擇合適的距離衡量?
高維度對距離衡量的影響:眾所周知當變數數越多,歐氏距離的區分能力就越差。
變數值域對距離的影響:值域越大的變數常常會在距離計算中佔據主導作用,因此應當對變數進行標準化。
4,訓練樣本是否要一視同仁?
在訓練集中 ,有些樣本可能是更值得依賴的
可以給不同的樣本施加不同的權重,加強依賴樣本的權重,降低不可信賴樣本的影響。
5,效能問題?
KNN是一種懶惰演算法,平時不好好學習,考試(對測試樣本分類)時才臨陣磨槍(臨時找k個近鄰)。
懶惰的後果:構造模型很簡單,但是對測試樣本分類的系統開銷很大,因為要掃描全部訓練樣本並計算距離。
已經有一些方法提高計算的效率,例如壓縮訓練樣本量等。
6,能否大幅度減少訓練樣本量,同時又保持分類精度?
濃縮技術(condensing)
編輯技術(editing)
七: 使用k-近鄰演算法改進約會網站的配對效果
James一直使用線上約會網站尋找適合自己的約會物件。儘管約會網站會推薦不同的人選,但是他從來沒有選中喜歡的人,經過一番總結,他發現曾交往過三種類型的人:
- 不喜歡的人
- 魅力一般的人
- 極具魅力的人
儘管發現了上述歸類,但是他依然無法將約會網站推薦的匹配物件歸入恰當的分類。他覺得可以在週一到週五約會哪些魅力一般的人,而週末則更喜歡與那些極具魅力的人為伴。James希望分類軟體可以更好地幫助他將匹配物件劃分到確切的分類中,此外他還收集了一些約會網站未曾記錄的資料資訊,他認為這些資料更有助於匹配物件的歸類。
7.1,在約會網站上使用k-近鄰演算法流程
(1)收集資料:提供文字檔案
(2)準備資料:使用Python解析文字檔案
(3)分析資料:使用Matplotlib畫二維擴散圖
(4)訓練演算法:此步驟不適用於k-近鄰演算法
(5)測試演算法:使用James提供的部分資料作為測試樣本
測試樣本和非測試樣本的區別在於:測試樣本是已經完成分類的資料,如果預測分類與實際類別不同,則標記為一個錯誤。
(6)使用演算法:產生簡單的命令程式,然後James可以輸入一些特徵資料以判斷對方是否為自己喜歡的型別。
7.2,準備資料:從文字檔案中解析資料
James收集的資料有一段時間了,她將這些資料放在文字檔案datingTestSet.txt中,每個樣本資料佔據一行,總共有1000行,James的樣本主要包括以下三種特徵:
- 每年獲得的飛行常客里程數
- 玩視訊遊戲所消耗時間百分比
- 每週消費的冰淇淋公升數
在將資料特徵資料輸入到分類器之前,必須將待處理資料的格式改變為分類器可以接受的格式,我們首先處理輸入格式問題,該函式的輸入問檔名字串,輸出為訓練樣本矩陣和類標籤向量。
將文字記錄到轉換Numpy的解析程式程式碼如下:
# 將文字記錄到轉換Numpy的解析程式 def file2Matrix(filename): fr = open(filename) arrayOLines = fr.readlines() # 得到檔案行數 numberOfLines = len(arrayOLines) # 建立返回的Numpy矩陣 returnMat = zeros((numberOfLines,3)) classLabelVector = [] index = 0 # 解析檔案資料到列表 for line in arrayOLines: # 刪除空白行 line = line.strip() listFromLine = line.split('\t') # 選取前3個元素(特徵)儲存在返回矩陣中 returnMat[index,:] = listFromLine[0:3] # -1索引表示最後一列元素,位label資訊儲存在classLabelVector classLabelVector.append(int(listFromLine[-1])) index += 1 return returnMat,classLabelVector
從上面程式碼可以看到,處理文字檔案非常容易,下面我們測試一下,Python輸出的結果。
if __name__ == '__main__': filename = 'datingTestSet2.txt' datingDataMat,datingLabels = file2Matrix(filename) print(datingDataMat) print(datingLabels) ''' [[4.0920000e+04 8.3269760e+00 9.5395200e-01] [1.4488000e+04 7.1534690e+00 1.6739040e+00] [2.6052000e+04 1.4418710e+00 8.0512400e-01] ... [2.6575000e+04 1.0650102e+01 8.6662700e-01] [4.8111000e+04 9.1345280e+00 7.2804500e-01] [4.3757000e+04 7.8826010e+00 1.3324460e+00]] [3, 2, 1, 1, 1, 1, 3, 3, 1, 3, 1, 1, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 2, 3, 2, 3, 2, 3, 2, 1, 3, 1, 3, 1, 2, 1, 1, 2, 3, 3, 1, 2, 3, 3, 3, 1, 1, 1, 1, 2, 2, 1, 3, 2, 2, 2, 2, 3, 1, 2, 1, 2, 2, 2, 2, 2, 3, 2, 3, 1, 2, 3, 2, 2, 1, 3, 1, 1, 3, 3, 1, 2, 3, 1, 3, 1, 2, 2, 1, 1, 3, 3, 1, 2, 1, 3, 3, 2, 1, 1, 3, 1, 2, 3, 3, 2, 3, 3, 1, 2, 3, 2, 1, 3, 1, 2, 1, 1, 2, 3, 2, 3, 2, 3, 2, 1, 3, 3, 3, 1, 3, 2, 2, 3, 1, 3, 3, 3, 1, 3, 1, 1, 3, 3, 2, 3, 3, 1, 2, 3, 2, 2, 3, 3, 3, 1, 2, 2, 1, 1, 3, 2, 3, 3, 1, 2, 1, 3, 1, 2, 3, 2, 3, 1, 1, 1, 3, 2, 3, 1, 3, 2, 1, 3, 2, 2, 3, 2, 3, 2, 1, 1, 3, 1, 3, 2, 2, 2, 3, 2, 2, 1, 2, 2, 3, 1, 3, 3, 2, 1, 1, 1, 2, 1, 3, 3, 3, 3, 2, 1, 1, 1, 2, 3, 2, 1, 3, 1, 3, 2, 2, 3, 1, 3, 1, 1, 2, 1, 2, 2, 1, 3, 1, 3, 2, 3, 1, 2, 3, 1, 1, 1, 1, 2, 3, 2, 2, 3, 1, 2, 1, 1, 1, 3, 3, 2, 1, 1, 1, 2, 2, 3, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 2, 3, 2, 3, 3, 3, 3, 1, 2, 3, 1, 1, 1, 3, 1, 3, 2, 2, 1, 3, 1, 3, 2, 2, 1, 2, 2, 3, 1, 3, 2, 1, 1, 3, 3, 2, 3, 3, 2, 3, 1, 3, 1, 3, 3, 1, 3, 2, 1, 3, 1, 3, 2, 1, 2, 2, 1, 3, 1, 1, 3, 3, 2, 2, 3, 1, 2, 3, 3, 2, 2, 1, 1, 1, 1, 3, 2, 1, 1, 3, 2, 1, 1, 3, 3, 3, 2, 3, 2, 1, 1, 1, 1, 1, 3, 2, 2, 1, 2, 1, 3, 2, 1, 3, 2, 1, 3, 1, 1, 3, 3, 3, 3, 2, 1, 1, 2, 1, 3, 3, 2, 1, 2, 3, 2, 1, 2, 2, 2, 1, 1, 3, 1, 1, 2, 3, 1, 1, 2, 3, 1, 3, 1, 1, 2, 2, 1, 2, 2, 2, 3, 1, 1, 1, 3, 1, 3, 1, 3, 3, 1, 1, 1, 3, 2, 3, 3, 2, 2, 1, 1, 1, 2, 1, 2, 2, 3, 3, 3, 1, 1, 3, 3, 2, 3, 3, 2, 3, 3, 3, 2, 3, 3, 1, 2, 3, 2, 1, 1, 1, 1, 3, 3, 3, 3, 2, 1, 1, 1, 1, 3, 1, 1, 2, 1, 1, 2, 3, 2, 1, 2, 2, 2, 3, 2, 1, 3, 2, 3, 2, 3, 2, 1, 1, 2, 3, 1, 3, 3, 3, 1, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 3, 2, 1, 3, 3, 2, 2, 2, 3, 1, 2, 1, 1, 3, 2, 3, 2, 3, 2, 3, 3, 2, 2, 1, 3, 1, 2, 1, 3, 1, 1, 1, 3, 1, 1, 3, 3, 2, 2, 1, 3, 1, 1, 3, 2, 3, 1, 1, 3, 1, 3, 3, 1, 2, 3, 1, 3, 1, 1, 2, 1, 3, 1, 1, 1, 1, 2, 1, 3, 1, 2, 1, 3, 1, 3, 1, 1, 2, 2, 2, 3, 2, 2, 1, 2, 3, 3, 2, 3, 3, 3, 2, 3, 3, 1, 3, 2, 3, 2, 1, 2, 1, 1, 1, 2, 3, 2, 2, 1, 2, 2, 1, 3, 1, 3, 3, 3, 2, 2, 3, 3, 1, 2, 2, 2, 3, 1, 2, 1, 3, 1, 2, 3, 1, 1, 1, 2, 2, 3, 1, 3, 1, 1, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 2, 2, 3, 1, 3, 1, 2, 3, 2, 2, 3, 1, 2, 3, 2, 3, 1, 2, 2, 3, 1, 1, 1, 2, 2, 1, 1, 2, 1, 2, 1, 2, 3, 2, 1, 3, 3, 3, 1, 1, 3, 1, 2, 3, 3, 2, 2, 2, 1, 2, 3, 2, 2, 3, 2, 2, 2, 3, 3, 2, 1, 3, 2, 1, 3, 3, 1, 2, 3, 2, 1, 3, 3, 3, 1, 2, 2, 2, 3, 2, 3, 3, 1, 2, 1, 1, 2, 1, 3, 1, 2, 2, 1, 3, 2, 1, 3, 3, 2, 2, 2, 1, 2, 2, 1, 3, 1, 3, 1, 3, 3, 1, 1, 2, 3, 2, 2, 3, 1, 1, 1, 1, 3, 2, 2, 1, 3, 1, 2, 3, 1, 3, 1, 3, 1, 1, 3, 2, 3, 1, 1, 3, 3, 3, 3, 1, 3, 2, 2, 1, 1, 3, 3, 2, 2, 2, 1, 2, 1, 2, 1, 3, 2, 1, 2, 2, 3, 1, 2, 2, 2, 3, 2, 1, 2, 1, 2, 3, 3, 2, 3, 1, 1, 3, 3, 1, 2, 2, 2, 2, 2, 2, 1, 3, 3, 3, 3, 3, 1, 1, 3, 2, 1, 2, 1, 2, 2, 3, 2, 2, 2, 3, 1, 2, 1, 2, 2, 1, 1, 2, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3, 2, 3, 2, 3, 3, 2, 2, 1, 1, 1, 3, 3, 1, 1, 1, 3, 3, 2, 1, 2, 1, 1, 2, 2, 1, 1, 1, 3, 1, 1, 2, 3, 2, 2, 1, 3, 1, 2, 3, 1, 2, 2, 2, 2, 3, 2, 3, 3, 1, 2, 1, 2, 3, 1, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 1, 3, 3, 3] '''
現在我們已經從文字檔案中匯入了資料,並將其格式化為想要的格式,接著我們需要了解資料的真實含義,當然我們可以直接瀏覽文字檔案,但是這種方式並不友好,一般來說,我們會採用圖形化的方式直觀的展示資料,下面使用Python工具來圖形化展示資料內容,以便辨識出一些資料模式。
7.3 分析資料:使用Matplotlib建立散點圖
首先我們使用Matplotlib製作原始資料的散點圖,程式碼如下:
由於沒有使用樣本分類的特徵值,我們很難從上面看到任何有用的資料模式資訊。一般來說,我們會採用色彩或其他的記號來標記不同樣本分類,以便更好地理解資料資訊,Matplotlib庫提供的scatter函式支援個性化標記散點圖上的點,重新輸入上面程式碼,呼叫scatter函式時使用下列引數。
if __name__ == '__main__': filename = 'datingTestSet2.txt' datingDataMat,datingLabels = file2Matrix(filename) print(datingDataMat) print(datingLabels) # 使用Matplotlib建立散點圖 import matplotlib.pyplot as plt from pylab import mpl # 指定預設字型 mpl.rcParams['font.sans-serif'] = ['FangSong'] # 解決儲存影象是負號- 顯示為方塊的問題 mpl.rcParams['axes.unicode_minus'] = False fig = plt.figure() ax = fig.add_subplot(111) ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], 15.0*array(datingLabels),15.0*array(datingLabels)) plt.xlabel("玩視訊遊戲所耗時間百分比") plt.ylabel("每週消費的冰淇淋公升數") plt.show()
上面程式碼利用了變數datingLabels儲存的類標籤屬性,在散點圖繪製了色彩不等,尺寸不等的點,你可以看到與上面類似的散點圖,從上圖中我們很難看到有用的資訊,然而由下圖的顏色及尺寸標識了資料點的屬性類別,因此我們可以看到所屬三個樣本分類的區域輪廓。
上圖可以區別,但是下圖採用不同的屬性值可以得到更好的結果,圖中清晰的標識了三個不同的樣本分類區域,具有不同愛好的人其類別區域也不同。
7.4 準備資料:歸一化數值
表給出了提取的四組資料,如果想要計算樣本3和樣本4之間的距離,可以使用下面的方法:
計算的公式如下:
我們很容易發現,上面方程中數字差值最大的屬性對計算結果的影響最大,也就是說,每年獲取的飛行常客里程數對於計算結果的影響遠遠大於下標中其他兩個特徵——玩視訊遊戲的和每週消費冰淇淋公升數的影響。而產生這種現象的唯一原因僅僅是飛行常客里程數遠大於其他特徵數值。但是James認為這三種特徵是同等重要的,因此作為三個等權重的特徵之一,飛行常客里程數並不應該如此嚴重的影響到計算結果。
在處理這種不同取值範圍的特徵值時,我們通常採用的方法是將數值歸一化,如將取值範圍處理為0
到1或者-1到1之間。下面的公式可以將任意取值範圍的特徵值轉化到0到1區間內的值:
newValue = (oldValue - min)/(max-min)
其中min和max分別是資料集中的最小特徵值和最大特徵值。雖然改變數值取值範圍增加了分類器的複雜度,但是為了得到準確結果,我們必須這樣做,所以我們新增一個函式autoNorm() ,該函式可以自動將數字特徵轉化為0到1的區間。
歸一化特徵值程式碼如下:
# 歸一化特徵值 # 歸一化公式 : (當前值 - 最小值) / range def autoNorm(dataSet): # 存放每列最小值,引數0使得可以從列中選取最小值,而不是當前行 minVals = dataSet.min(0) # 存放每列最大值 maxVals = dataSet.max(0) ranges = maxVals - minVals # 初始化歸一化矩陣為讀取的dataSet normDataSet = zeros(shape(dataSet)) # 保留第一行 m = dataSet.shape[0] # 特徵值相除,特徵矩陣是3*1000 minmax range是1*3 # 因此採用tile將變數內容複製成輸入矩陣同大小 normDataSet = dataSet - tile(minVals , (m,1)) normDataSet = normDataSet/tile(ranges,(m,1)) return normDataSet,ranges,minVals
那麼執行命令列,得到結果如下:
if __name__ == '__main__': filename = 'datingTestSet2.txt' datingDataMat,datingLabels = file2Matrix(filename) print(datingDataMat) print(datingLabels) normMat,ranges,minVals = autoNorm(datingDataMat) print(normMat) print(ranges) print(minVals) ''' [[0.44832535 0.39805139 0.56233353] [0.15873259 0.34195467 0.98724416] [0.28542943 0.06892523 0.47449629] ... [0.29115949 0.50910294 0.51079493] [0.52711097 0.43665451 0.4290048 ] [0.47940793 0.37680910.78571804]] [9.1273000e+04 2.0919349e+01 1.6943610e+00] [0.0.0.001156] '''
7.5 測試演算法:作為完整程式驗證分類器
上面我們已經將資料按照需求做了處理,本節我們將測試分類器的效果,如果分類器的正確率滿足要求,James就可以使用這個軟體來處理約會網站提供的約會名單了。機器學習演算法一個很重要的工作就是評估演算法的正確率,通常我們只提供已有資料的90%作為訓練樣本來訓練分類器,而使用其餘的10%資料來測試分類器,檢測分類器的正確率。值得注意的是,10%的測試資料應該是隨機選擇的,由於James提供的資料並沒有按照特定目的來排序,所以我們可以隨意選擇10%資料而不影響其隨機性。
前面我們已經提到可以使用錯誤率來檢測分類器的效能。對於分類器來說,錯誤率就是分類器給出的錯誤結果的次數除以測試資料的總數。程式碼裡我們定義一個計算器變數,每次分類器錯誤地分類資料,計數器就加1,程式執行完成之後計數器的結果除以資料點總數即是錯誤率。
為了測試分類器效果,我們建立了datingClassTest,該函式時自包含的,程式碼如下:
分類器針對約會網站的測試程式碼:
# 分類器針對約會網站的測試程式碼 def datingClassTest(): hoRatio = 0.10 datingDataMat,datingLabels = file2Matrix('datingTestSet2.txt') normMat,ranges,minVals = autoNorm(datingDataMat) m = normMat.shape[0] numTestVecs = int(m * hoRatio) errorCount = 0.0 for i in range(numTestVecs): classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:], datingLabels[numTestVecs:m],3) print('the classifier came back with:%d,the read answer is :%d ' %(classifierResult,datingLabels[i])) if (classifierResult != datingLabels[i]): errorCount +=1.0 print("the total error rate is :%f"%(errorCount/float(numTestVecs)))
測試結果如下:
if __name__ == '__main__': datingClassTest() ''' the classifier came back with:3,the read answer is :3 the classifier came back with:2,the read answer is :2 the classifier came back with:1,the read answer is :1 ... the classifier came back with:1,the read answer is :1 the classifier came back with:3,the read answer is :1 the total error rate is :0.050000 '''
分類器處理越活資料集的錯誤率是5%,這是一個相當不錯的結果,我們可以改變函式datingClassTest內變數hoRatio和變數k的值,檢測錯誤率是否隨著變數值的變化而增加,依賴於分類演算法,資料集和程式設定,分類器的輸出結果可能有很大的不同。
這個例子表明我們可以正確的預測分類,錯誤率僅僅是5%。James完全可以輸入未知物件的屬性資訊,由分類軟體來幫助他判定某一物件的可交往程度:討厭,一般喜歡,非常喜歡。
7.6 使用演算法:構建完整可用系統
上面我們已經在資料上對分類器進行了測試,現在我們可以使用這個分類器為James來對人們分類,我們會給James給一小段程式,通過該程式會在約會網站上找到某個人並輸入他的資訊,程式會給出他對對方喜歡程式的預測值。
約會網站預測函式程式碼:
# 約會網站預測函式 def classifyPerson(): resultList = ['not at all','in small','in large doses'] percentTats = float(input("percentage of time spent playing video games ?")) ffMiles = float(input("frequent flier miles earned per year?")) iceCream = float(input("liters od ice cream consumed per year?")) dataingDataMat,datingLabels = file2Matrix('datingTestSet2.txt') normMat,ranges,minVals = autoNorm(datingDataMat) inArr = array([ffMiles,percentTats,iceCream]) classifierResult = classify0((inArr-minVals)/ranges,normMat,datingLabels,3) print("You will probably like this person ",resultList[classifierResult-1])
測試結果如下:
if __name__ == '__main__': # datingClassTest() filename = 'datingTestSet2.txt' datingDataMat,datingLabels = file2Matrix(filename) classifyPerson() ''' percentage of time spent playing video games ?10 frequent flier miles earned per year?10000 liters od ice cream consumed per year?0.5 You will probably like this personin small '''
八,手寫數字識別系統(兩種方法實戰)
一,基於K-近鄰演算法的手寫識別系統
這裡我們一步步的構造使用k-近鄰分類器的手寫識別系統,為了簡單起見,這裡構造的系統只能識別數字0到9,參考圖2-6,需要識別的數字已經使用圖形處理軟體,處理成具有相同的色彩和大小:寬高是32畫素*32畫素的黑白圖形。儘管採用文字格式儲存圖形不能有效地利用記憶體空間,但是為了方便理解,我們還是將影象轉化為文字格式。
1.1,使用k-近鄰演算法的手寫識別系統步驟
(1)收集資料:提供文字檔案
(2)準備資料:編寫函式classify0(),將影象格式轉換為分類器使用的list格式
(3)分析資料:在Python命令提示符中檢查資料,確保它符合要求
(4)訓練演算法:此步驟不適用與k-近鄰演算法
(5)測試演算法:編寫函式使用提供的部分資料集作為測試樣本,測試樣本與非測試樣本的區別在於測試樣本是已經完成分類的資料,如果預測分類與實際類別不同,則標記為一個錯誤
(6)使用演算法:本例沒有完成此步驟,如果感興趣的話,可以構建完整的應用程式,從圖形中提取數字,並完成數字識別
1.2,準備資料:將影象轉換為測試向量
實際圖形儲存在原始碼的兩個子目錄中:目標trainingDigits中包含了大約2000個例子,每個例子如圖2-6所示,每個數字大約有200個樣本;目錄testDigits中包含了大約900個測試資料,我們使用目錄trainingDigits中的資料訓練分類器,使用目錄testDigits中的資料測試分類器的效果,兩組資料沒有覆蓋,你可以檢查一下這些資料夾的檔案是否符合要求。
為了使用前面兩個例子的分類器,我們必須將影象格式化處理為一個向量,我們將把一個32*32的二進位制影象矩陣轉化為1*1024的向量,這樣前兩節使用的分類器就可以處理數字影象資訊了。
我們首先編寫一段函式img2vector,將影象轉化為向量:該函式建立1*1024的Numpy陣列,然後開啟給定的問價,迴圈讀出檔案的前32行,並將每行的頭32個字串儲存在Numpy陣列中,最後返回陣列。
def img2vector(filename): # 每個手寫識別為32*32大小的二進位制影象矩陣,轉換為1*1024 numpy向量陣列returenVect returnVect = zeros((1,1024)) fr = open(filename) # 迴圈讀出前32行 for i in range(32): lineStr = fr.readline() for j in range(32): # 將每行的32個字元值儲存在numpy陣列中 returnVect[0,32*i+j] = int(lineStr[j]) return returnVect
測試程式碼,結果如下:
if __name__ == '__main__': res = img2vector('testDigits/0_13.txt') print(res[0,0:31]) print(res[0,32:63]) ''' [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] '''
1.3,測試演算法:使用k-近鄰演算法識別手寫數字
上面我們已經將資料處理成分類器可以識別的格式,本小節我們將這些資料輸入到分類器,檢測分類器的執行效果,下面程式是測試分類器的程式碼。
手寫數字識別系統的測試程式碼:
# 測試演算法 def handwritingClassTest(): hwLabels = [] trainingFileList = os.listdir('trainingDigits') m = len(trainingFileList) # 定義檔案數 x 每個向量的訓練集 for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] # 解析檔名 classNumStr = int(fileStr.split('_')[0]) # 儲存類別 hwLabels.append(classNumStr) # 訪問第i個檔案內的資料 trainingMat[i,:] = img2vector('trainingDigits/%s'%fileNameStr) # 測試資料集 testFileList = os.listdir('testDigits') errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] # 從檔名中分離出數字作為基準 classNumStr = int(fileStr.split('_')[0]) # 訪問第i個檔案內的測試資料,不儲存類 直接測試 vectorUnderTest = img2vector('testDigits/%s'%fileNameStr) classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3) print("the classifier came back with: %d,the real answer is: %d" %(classifierResult,classNumStr)) if(classifierResult!=classNumStr): errorCount+=1.0 print("\nthe total number of errors is: %d" % errorCount) print("\nthe total rate is:%f"% (errorCount/float(mTest)))
測試程式碼,結果如下:
the classifier came back with: 0,the real answer is: 0 the total number of errors is: 0 the total rate is:0.000000 ...... the total number of errors is: 4 the total rate is:0.004228 ... ... the classifier came back with: 9,the real answer is: 9 the total number of errors is: 9 the total rate is:0.009514 the classifier came back with: 9,the real answer is: 9 the total number of errors is: 10 the total rate is:0.010571
k-近鄰演算法識別手寫數字集,錯誤率為1.2%,改變變數k的值,修改函式handwritingClassTest隨機選取訓練樣本,改變訓練樣本的數目,都會對k-近鄰演算法的錯誤率產生影響,感興趣的話可以改變這些變數值,觀察錯誤率的變化。
實際使用這個演算法的時候,演算法的執行效率並不高,因為演算法需要為每個測試向量做2000次距離計算,每個距離計算包括了1024個維度浮點計算,總計要執行900次。此外,我們還需要為測試向量準備2MB的儲存空間。是否存在一種演算法減少儲存空間和計算時間的開銷呢?k決策樹就是k-近鄰演算法的優化版,可以節省大量的計算開銷。
1.4 使用k-近鄰演算法識別手寫數字完整程式碼
import os ,sys from numpy import * import operator # k-近鄰演算法 def classify0(inX,dataSet,labels,k): # shape讀取資料矩陣第一維度的長度 dataSetSize = dataSet.shape[0] # tile重複陣列inX,有dataSet行 1個dataSet列,減法計算差值 diffMat = tile(inX,(dataSetSize,1)) - dataSet # **是冪運算的意思,這裡用的歐式距離 sqDiffMat = diffMat ** 2 # 普通sum預設引數為axis=0為普通相加,axis=1為一行的行向量相加 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances ** 0.5 # argsort返回數值從小到大的索引值(陣列索引0,1,2,3) sortedDistIndicies = distances.argsort() # 選擇距離最小的k個點 classCount = {} for i in range(k): # 根據排序結果的索引值返回靠近的前k個標籤 voteLabel = labels[sortedDistIndicies[i]] # 各個標籤出現頻率 classCount[voteLabel] = classCount.get(voteLabel,0) +1 ##!!!!!classCount.iteritems()修改為classCount.items() #sorted(iterable, cmp=None, key=None, reverse=False) --> new sorted list。 # reverse預設升序 key關鍵字排序itemgetter(1)按照第一維度排序(0,1,2,3) sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1),reverse=True) return sortedClassCount[0][0] def img2vector(filename): # 每個手寫識別為32*32大小的二進位制影象矩陣,轉換為1*1024 numpy向量陣列returenVect returnVect = zeros((1,1024)) fr = open(filename) # 迴圈讀出前32行 for i in range(32): lineStr = fr.readline() for j in range(32): # 將每行的32個字元值儲存在numpy陣列中 returnVect[0,32*i+j] = int(lineStr[j]) return returnVect # 測試演算法 def handwritingClassTest(): hwLabels = [] trainingFileList = os.listdir('trainingDigits') m = len(trainingFileList) # 定義檔案數 x 每個向量的訓練集 trainingMat = zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] # 解析檔名 classNumStr = int(fileStr.split('_')[0]) # 儲存類別 hwLabels.append(classNumStr) # 訪問第i個檔案內的資料 trainingMat[i,:] = img2vector('trainingDigits/%s'%fileNameStr) # 測試資料集 testFileList = os.listdir('testDigits') errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] # 從檔名中分離出數字作為基準 classNumStr = int(fileStr.split('_')[0]) # 訪問第i個檔案內的測試資料,不儲存類 直接測試 vectorUnderTest = img2vector('testDigits/%s'%fileNameStr) classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3) print("the classifier came back with: %d,the real answer is: %d" %(classifierResult,classNumStr)) if(classifierResult!=classNumStr): errorCount+=1.0 print("\nthe total number of errors is: %d" % errorCount) print("\nthe total rate is:%f"% (errorCount/float(mTest))) if __name__ == '__main__': handwritingClassTest() # res = img2vector('testDigits/0_13.txt') # print(res[0,0:31]) # print(res[0,32:63])
1.5 小結
k-近鄰演算法是分類資料最簡單最有效的演算法,本文通過兩個例子講述瞭如何使用k-近鄰演算法構造分類器,k-近鄰演算法是基於例項的學習,使用演算法時我們必須有接近實際資料的訓練樣本資料,k-近鄰演算法必須儲存全部的資料集,如果訓練資料集的很大,必須使用大量的儲存空間。此外,由於必須對資料集中的每個資料計算距離值,實際使用時可能非常耗時。
k-近鄰演算法的另一個缺陷是它無法給出任何資料的基礎結構資訊,因此我們也無法知曉平均例項樣本和典型例項樣本具有什麼特徵。
二,基於Sklearn的k-近鄰演算法實戰(手寫數字識別)
2.1,實戰背景
上面已經說明了,將圖片轉換為文字格式,這裡不再累贅,我們使用同樣的方法處理。接下來我們使用強大的第三方Python科學計算庫Sklearn構建手寫數字系統。
2.2,Sklearn中k-近鄰演算法引數說明
關於sklearn的英文官方文件地址: 點我檢視
sklearn.neighbors模組實現了k-近鄰演算法,內容如下圖:
我們使用sklearn.neighbors.KNeighborsClassifier就可以是實現上小結,我們實現的k-近鄰演算法。KNeighborsClassifier函式一共有8個引數,如圖所示:
KNneighborsClassifier引數說明:
- n_neighbors:預設為5,就是k-NN的k的值,選取最近的k個點。
- weights:預設是uniform,引數可以是uniform、distance,也可以是使用者自己定義的函式。uniform是均等的權重,就說所有的鄰近點的權重都是相等的。distance是不均等的權重,距離近的點比距離遠的點的影響大。使用者自定義的函式,接收距離的陣列,返回一組維數相同的權重。
- algorithm:快速k近鄰搜尋演算法,預設引數為auto,可以理解為演算法自己決定合適的搜尋演算法。除此之外,使用者也可以自己指定搜尋演算法ball_tree、kd_tree、brute方法進行搜尋,brute是蠻力搜尋,也就是線性掃描,當訓練集很大時,計算非常耗時。kd_tree,構造kd樹儲存資料以便對其進行快速檢索的樹形資料結構,kd樹也就是資料結構中的二叉樹。以中值切分構造的樹,每個結點是一個超矩形,在維數小於20時效率高。ball tree是為了克服kd樹高緯失效而發明的,其構造過程是以質心C和半徑r分割樣本空間,每個節點是一個超球體。
- leaf_size:預設是30,這個是構造的kd樹和ball樹的大小。這個值的設定會影響樹構建的速度和搜尋速度,同樣也影響著儲存樹所需的記憶體大小。需要根據問題的性質選擇最優的大小。
- metric:用於距離度量,預設度量是minkowski,也就是p=2的歐氏距離(歐幾里德度量)。
- p:距離度量公式。在上小結,我們使用歐氏距離公式進行距離度量。除此之外,還有其他的度量方法,例如曼哈頓距離。這個引數預設為2,也就是預設使用歐式距離公式進行距離度量。也可以設定為1,使用曼哈頓距離公式進行距離度量。
- metric_params:距離公式的其他關鍵引數,這個可以不管,使用預設的None即可。
- n_jobs:並行處理設定。預設為1,臨近點搜尋並行工作數。如果為-1,那麼CPU的所有cores都用於並行工作。
2.3,sklearn實戰手寫數字識別系統
我們知道數字圖片是32x32的二進位制影象,為了方便計算,我們可以將32x32的二進位制影象轉換為1x1024的向量。對於sklearn的KNeighborsClassifier輸入可以是矩陣,不用一定轉換為向量,不過為了跟自己寫的k-近鄰演算法分類器對應上,這裡也做了向量化處理。然後構建kNN分類器,利用分類器做預測。建立kNN_test04.py檔案,編寫程式碼如下:
#_*_coding:utf-8_*_ import numpy as np import operator import os from sklearn.neighbors import KNeighborsClassifier as KNN def img2vector(filename): # 建立1*1024零向量 returnVect = np.zeros((1,1024)) # 開啟檔案 fr = open(filename) # 按照行讀取 for i in range(32): # 讀一行資料 lineStr = fr.readline() # 每一行的前32個元素依次新增到returnVect中 for j in range(32): returnVect[0,32*i+j] = int(lineStr[j]) # 返回轉換後的1*1024向量 return returnVect # 手寫數字分類測試 def handwritingClassTest(): # 測試集的Labels hwLabels = [] # 返回trainingDigts目錄下的檔名 trainingFileList = os.listdir('trainingDigits') # 返回資料夾下檔案的個數 m = len(trainingFileList) # 初始化訓練的Mat矩陣,測試集 trainingMat = np.zeros((m,1024)) # 從檔名中解析出訓練集的類別 for i in range(m): # 獲取檔案的名字 fileNameStr = trainingFileList[i] # 獲得分類的數字 classNumber = int(fileNameStr.split('_')[0]) # 將獲得的類別新增到hwLabels中 hwLabels.append(classNumber) # 將每一個檔案的1*1024資料儲存到trainingMat矩陣中 trainingMat[i,:] = img2vector('trainingDigits/%s' % (fileNameStr)) # 構建KNN分類器 neigh = KNN(n_neighbors=3,algorithm='auto') # 擬合模型,trainingMat為訓練矩陣,hwLabels為對應的標籤 neigh.fit(trainingMat,hwLabels) # 返回testDigits目錄下的檔案列表 testFileList = os.listdir('testDigits') # 錯誤檢測計數 errorCount = 0.0 # 測試資料的數量 mTest = len(testFileList) # 從檔案中解析出測試集的類別並進行分類測試 for i in range(mTest): # 獲取檔案的名字 fileNameStr = testFileList[i] # 獲取分類的數字 classNumber = int(fileNameStr.split('_')[0]) # 獲得測試集的1*1024向量,用於訓練 vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr)) # 獲得預測結果 # classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) classifierResult = neigh.predict(vectorUnderTest) print("分類返回結果為%d\t真實結果為%d" % (classifierResult, classNumber)) if (classifierResult != classNumber): errorCount += 1.0 print("總共錯了%d個數據\n錯誤率為%f%%" % (errorCount, errorCount / mTest * 100)) if __name__ == '__main__': handwritingClassTest()
測試結果如下:
分類返回結果為0真實結果為0 分類返回結果為0真實結果為0 分類返回結果為0真實結果為0 分類返回結果為0真實結果為0 分類返回結果為0真實結果為0 分類返回結果為0真實結果為0 分類返回結果為0真實結果為0 分類返回結果為0真實結果為0 分類返回結果為0真實結果為0 分類返回結果為0真實結果為0 ...... 分類返回結果為9真實結果為9 總共錯了12個數據 錯誤率為1.268499%
上述程式碼使用的algorithm引數是auto,更改algorithm引數為brute,使用暴力搜尋,你會發現,執行時間變長了,變為10s+。更改n_neighbors引數,你會發現,不同的值,檢測精度也是不同的。自己可以嘗試更改這些引數的設定,加深對其函式的理解。
參考文獻:https://www.jianshu.com/p/1ded78d92e68