#!/usr/bin/env python

'''
SVM and KNearest digit recognition.

Sample loads a dataset of handwritten digits from '../data/digits.png'.
Then it trains SVM and KNearest classifiers on it and evaluates
their accuracy.

The following preprocessing is applied to the dataset:
 - Moment-based image deskew (see deskew())
 - Digit images are split into 4 10x10 cells and a 16-bin
   histogram of oriented gradients is computed for each
   cell
 - Histograms are transformed to a space with the Hellinger metric
   (see [1] (RootSIFT))


[1] R. Arandjelovic, A. Zisserman
    "Three things everyone should know to improve object retrieval"
    http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf

Usage:
   digits.py
'''

# built-in modules
from multiprocessing.pool import ThreadPool

import cv2

import numpy as np
from numpy.linalg import norm

# local modules
from common import clock, mosaic



SZ = 20 # size of each digit is SZ x SZ
CLASS_N = 10
DIGITS_FN = '../data/digits.png'


def split2d(img, cell_size, flatten=True):
    '''Cut img into a regular grid of cells.

    cell_size is (cell_width, cell_height). The image height/width are
    expected to be multiples of the cell size (np.vsplit / np.hsplit need
    an equal division).

    For a 2-D img the result has shape (rows, cols, sy, sx), or
    (rows*cols, sy, sx) when flatten is true.
    '''
    h, w = img.shape[:2]
    sx, sy = cell_size
    cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
    cells = np.array(cells)
    if flatten:
        cells = cells.reshape(-1, sy, sx)
    return cells


def load_digits(fn):
    '''Load the digit sheet fn and return (digits, labels).

    The sheet is read as grayscale and cut into SZ x SZ cells. Labels are
    assigned in equal consecutive blocks: the first len(digits)//CLASS_N
    cells get label 0, the next block label 1, and so on.

    Raises IOError if the image cannot be read.
    '''
    print('loading "%s" ...' % fn)
    digits_img = cv2.imread(fn, 0)  # 0: load as single-channel grayscale
    if digits_img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('failed to load "%s"' % fn)
    digits = split2d(digits_img, (SZ, SZ))
    # Integer division so np.repeat receives an int count on Python 3 too.
    labels = np.repeat(np.arange(CLASS_N), len(digits) // CLASS_N)
    return digits, labels


def deskew(img):
    '''Undo the horizontal shear of an SZ x SZ digit using image moments.

    The skew is estimated as mu11 / mu02. When mu02 is close to zero the
    image is returned as an unmodified copy to avoid dividing by ~0.
    '''
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11']/m['mu02']
    # Shear along x, offset so that row y = SZ/2 maps onto itself.
    # WARP_INVERSE_MAP makes warpAffine treat M as the dst -> src mapping.
    M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
    img = cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
    return img


class StatModel(object):
    '''Base wrapper that forwards persistence to the wrapped self.model.'''

    def load(self, fn):
        # NOTE(review): the OpenCV 3 Python bindings usually load ml models via
        # a static factory (e.g. cv2.ml.SVM_load); confirm this instance call
        # actually replaces self.model's state.
        self.model.load(fn)

    def save(self, fn):
        self.model.save(fn)


class KNearest(StatModel):
    '''k-nearest-neighbours classifier backed by cv2.ml.KNearest.'''

    def __init__(self, k = 3):
        self.k = k
        self.model = cv2.ml.KNearest_create()

    def train(self, samples, responses):
        '''Fit on row-wise samples (one sample per row) and their labels.'''
        # A fresh model per call keeps repeated train() calls independent;
        # k lives on self, so no configuration is lost.
        self.model = cv2.ml.KNearest_create()
        self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)

    def predict(self, samples):
        '''Return a flat array with one predicted label per sample row.'''
        _retval, results, _neigh_resp, _dists = self.model.findNearest(samples, self.k)
        return results.ravel()


class SVM(StatModel):
    '''RBF-kernel C-SVC classifier backed by cv2.ml.SVM.'''

    def __init__(self, C = 1, gamma = 0.5):
        self.model = cv2.ml.SVM_create()
        self.model.setGamma(gamma)
        self.model.setC(C)
        self.model.setKernel(cv2.ml.SVM_RBF)
        self.model.setType(cv2.ml.SVM_C_SVC)

    def train(self, samples, responses):
        '''Fit on row-wise samples (one sample per row) and their labels.'''
        # Train the model configured in __init__: creating a new SVM here
        # would silently drop C, gamma, kernel and type.
        self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)

    def predict(self, samples):
        '''Return a flat array with one predicted label per sample row.'''
        # predict() returns (retval, results) with one result row per sample;
        # ravel the whole matrix, not just its first row.
        return self.model.predict(samples)[1].ravel()


def evaluate_model(model, digits, samples, labels):
    '''Report error rate and confusion matrix of model on (samples, labels).

    Prints the error percentage and a CLASS_N x CLASS_N confusion matrix
    (rows: true label, columns: predicted label). Returns a mosaic of
    digits in which misclassified images are tinted red (blue and green
    channels zeroed).
    '''
    resp = model.predict(samples)
    err = (labels != resp).mean()
    print('error: %.2f %%' % (err*100))

    confusion = np.zeros((CLASS_N, CLASS_N), np.int32)
    for i, j in zip(labels, resp):
        # Cast so float predictions are usable as array indices.
        confusion[int(i), int(j)] += 1
    print('confusion matrix:')
    print(confusion)
    print('')

    vis = []
    for img, flag in zip(digits, resp == labels):
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if not flag:
            img[...,:2] = 0
        vis.append(img)
    return mosaic(25, vis)


def preprocess_simple(digits):
    '''Flatten each SZ x SZ digit into a float32 row divided by 255.

    For 8-bit input this scales pixel values to [0, 1].
    '''
    return np.float32(digits).reshape(-1, SZ*SZ) / 255.0


def preprocess_hog(digits):
    '''Compute a Hellinger-normalised HOG descriptor for each digit.

    Each image is split into four quadrants. For each quadrant a 16-bin
    histogram of gradient orientations, weighted by gradient magnitude, is
    built. The concatenated 64-value histogram is L1-normalised,
    square-rooted and L2-normalised (RootSIFT-style). Returns a float32
    array with one row per digit.
    '''
    bin_n = 16
    eps = 1e-7
    samples = []
    for img in digits:
        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
        mag, ang = cv2.cartToPolar(gx, gy)
        # Quantise angles into bin_n bins; the modulo folds an angle that
        # rounds up to 2*pi back into bin 0 instead of creating an extra bin
        # (which would change the descriptor length).
        bins = np.int32(bin_n*ang/(2*np.pi)) % bin_n
        h, w = img.shape[:2]
        cy, cx = h // 2, w // 2  # quadrant split (10, 10 for SZ = 20)
        bin_cells = bins[:cy,:cx], bins[cy:,:cx], bins[:cy,cx:], bins[cy:,cx:]
        mag_cells = mag[:cy,:cx], mag[cy:,:cx], mag[:cy,cx:], mag[cy:,cx:]
        hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
        hist = np.hstack(hists)

        # transform to Hellinger kernel
        hist /= hist.sum() + eps
        hist = np.sqrt(hist)
        hist /= norm(hist) + eps

        samples.append(hist)
    return np.float32(samples)


if __name__ == '__main__':
    print(__doc__)

    digits, labels = load_digits(DIGITS_FN)

    print('preprocessing...')
    # shuffle digits (fixed seed, so the split is reproducible)
    rand = np.random.RandomState(321)
    shuffle = rand.permutation(len(digits))
    digits, labels = digits[shuffle], labels[shuffle]

    # Materialise as a list: on Python 3 map() is a one-shot iterator, and
    # digits2 is consumed twice (HOG features and the train/test split).
    digits2 = list(map(deskew, digits))
    samples = preprocess_hog(digits2)

    train_n = int(0.9*len(samples))
    cv2.imshow('test set', mosaic(25, digits[train_n:]))
    digits_train, digits_test = np.split(digits2, [train_n])
    samples_train, samples_test = np.split(samples, [train_n])
    labels_train, labels_test = np.split(labels, [train_n])


    print('training KNearest...')
    model = KNearest(k=4)
    model.train(samples_train, labels_train)
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
    cv2.imshow('KNearest test', vis)

    print('training SVM...')
    model = SVM(C=2.67, gamma=5.383)
    model.train(samples_train, labels_train)
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
    cv2.imshow('SVM test', vis)
    print('saving SVM as "digits_svm.dat"...')
    model.save('digits_svm.dat')

    cv2.waitKey(0)
    cv2.destroyAllWindows()