1#!/usr/bin/env python
2
3'''
4SVM and KNearest digit recognition.
5
This sample loads a dataset of handwritten digits from '../data/digits.png'.
It then trains SVM and KNearest classifiers on it and evaluates
their accuracy.
9
The following preprocessing is applied to the dataset:
 - Moment-based image deskew (see deskew())
 - Each digit image is split into four 10x10 cells and a 16-bin
   histogram of oriented gradients is computed for each
   cell
 - The histograms are transformed to a space with the Hellinger metric (see [1] (RootSIFT))
16
17
18[1] R. Arandjelovic, A. Zisserman
19    "Three things everyone should know to improve object retrieval"
20    http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf
21
22Usage:
23   digits.py
24'''
25
26# built-in modules
27from multiprocessing.pool import ThreadPool
28
29import cv2
30
31import numpy as np
32from numpy.linalg import norm
33
34# local modules
35from common import clock, mosaic
36
37
38
39SZ = 20 # size of each digit is SZ x SZ
40CLASS_N = 10
41DIGITS_FN = '../data/digits.png'
42
def split2d(img, cell_size, flatten=True):
    '''Cut `img` into a grid of equally sized cells.

    img       -- array whose first two axes are height and width; any
                 further axes (e.g. colour channels) are kept in each cell
    cell_size -- (width, height) of a single cell
    flatten   -- if true, return the cells as one row-major sequence of
                 shape (rows*cols, height, width, ...); otherwise keep the
                 grid layout (rows, cols, height, width, ...)

    The image height and width are expected to be exact multiples of the
    cell height and width.
    '''
    h, w = img.shape[:2]
    sx, sy = cell_size
    cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
    cells = np.array(cells)
    if flatten:
        # keep trailing axes intact so multi-channel cells are not scrambled
        cells = cells.reshape((-1, sy, sx) + img.shape[2:])
    return cells
51
def load_digits(fn):
    '''Load the digit sheet image `fn` and cut it into labelled tiles.

    The image is read as grayscale and split into SZ x SZ cells in
    row-major order. Labels are assigned purely by position: the tiles
    are divided into CLASS_N equal consecutive runs and every tile in
    run i gets label i.

    Returns a (digits, labels) tuple of arrays.
    Raises IOError if the image cannot be read.
    '''
    print('loading "%s" ...' % fn)
    digits_img = cv2.imread(fn, 0)
    if digits_img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise IOError('failed to load image file "%s"' % fn)
    digits = split2d(digits_img, (SZ, SZ))
    # floor division keeps the repeat count an integer under Python 3 as well
    labels = np.repeat(np.arange(CLASS_N), len(digits)//CLASS_N)
    return digits, labels
58
def deskew(img):
    '''Return a copy of `img` with its slant removed, based on image moments.

    The slant is estimated as mu11/mu02. When mu02 is (almost) zero the
    estimate is undefined and an unmodified copy is returned. Otherwise
    the image is sheared horizontally around row SZ/2 and resampled to
    SZ x SZ.
    '''
    moments = cv2.moments(img)
    mu02 = moments['mu02']
    if abs(mu02) < 1e-2:
        return img.copy()

    skew = moments['mu11']/mu02
    # x' = x + skew*(y - SZ/2), y' = y; used as the inverse mapping below
    shear = np.float32([[1, skew, -0.5*SZ*skew],
                        [0, 1, 0]])
    warp_flags = cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR
    return cv2.warpAffine(img, shear, (SZ, SZ), flags=warp_flags)
67
class StatModel(object):
    '''Shared persistence helpers for the classifier wrappers.

    Subclasses keep the wrapped model object in `self.model`; both
    methods simply forward the file name to it and return nothing.
    '''

    def save(self, fn):
        '''Delegate to the wrapped model's save() with file name `fn`.'''
        wrapped = self.model
        wrapped.save(fn)

    def load(self, fn):
        '''Delegate to the wrapped model's load() with file name `fn`.'''
        wrapped = self.model
        wrapped.load(fn)
73
class KNearest(StatModel):
    '''k-nearest-neighbours classifier wrapping cv2.ml.KNearest.'''

    def __init__(self, k=3):
        # k is the number of neighbours consulted by predict()
        self.model = cv2.ml.KNearest_create()
        self.k = k

    def train(self, samples, responses):
        '''Fit a freshly created model; each row of `samples` is one sample.'''
        knn = self.model = cv2.ml.KNearest_create()
        knn.train(samples, cv2.ml.ROW_SAMPLE, responses)

    def predict(self, samples):
        '''Return the flattened findNearest() results for `samples`.'''
        _retval, predicted, _neighbours, _distances = self.model.findNearest(samples, self.k)
        return predicted.ravel()
86
class SVM(StatModel):
    '''C-SVC classifier with an RBF kernel, wrapping cv2.ml.SVM.

    C     -- SVM regularisation parameter passed to setC()
    gamma -- RBF kernel parameter passed to setGamma()
    '''

    def __init__(self, C=1, gamma=0.5):
        self.model = cv2.ml.SVM_create()
        self.model.setGamma(gamma)
        self.model.setC(C)
        self.model.setKernel(cv2.ml.SVM_RBF)
        self.model.setType(cv2.ml.SVM_C_SVC)

    def train(self, samples, responses):
        '''Train on `samples` (one sample per row) with labels `responses`.'''
        # Train the model configured in __init__. Creating a new SVM here
        # would silently discard C, gamma, the kernel and the SVM type.
        self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)

    def predict(self, samples):
        '''Return a flat array with one predicted response per sample.'''
        # predict() returns (retval, results); flatten the whole results
        # array, not just its first row, to get one value per sample.
        return self.model.predict(samples)[1].ravel()
101
102
def evaluate_model(model, digits, samples, labels):
    '''Report how well `model` classifies `samples` and visualise the result.

    model   -- object with a predict(samples) method returning one
               predicted label per sample
    digits  -- digit images matching `samples`, used for the visualisation
    samples -- feature rows passed to model.predict()
    labels  -- ground-truth labels for `samples`

    Prints the error rate (in percent) and a 10x10 confusion matrix
    indexed as [true label, predicted label]. Returns the result of
    mosaic(25, ...) over the digit images, where misclassified digits
    have their first two BGR channels zeroed (leaving only red).
    '''
    resp = model.predict(samples)
    err = (labels != resp).mean()
    print('error: %.2f %%' % (err*100))

    confusion = np.zeros((10, 10), np.int32)
    for i, j in zip(labels, resp):
        # the cv2.ml wrappers return predictions as floats, and NumPy
        # rejects non-integer indices, so convert explicitly
        confusion[int(i), int(j)] += 1
    print('confusion matrix:')
    print(confusion)
    print('')

    vis = []
    for img, flag in zip(digits, resp == labels):
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if not flag:
            # zero blue and green so wrong predictions show up in red
            img[...,:2] = 0
        vis.append(img)
    return mosaic(25, vis)
122
def preprocess_simple(digits):
    '''Flatten each digit into a float32 row of SZ*SZ values divided by 255.'''
    flat = np.float32(digits).reshape(-1, SZ*SZ)
    return flat / 255.0
125
def preprocess_hog(digits):
    '''Compute a histogram-of-oriented-gradients descriptor per digit.

    Each image is cut at row/column 10 into four cells. Every cell
    contributes a 16-bin histogram of gradient orientation weighted by
    gradient magnitude. The concatenated histogram is L1-normalised,
    square-rooted and then L2-normalised (Hellinger kernel transform).

    Returns a float32 array with one descriptor row per input image.
    '''
    n_bins = 16
    eps = 1e-7  # guards the normalisations against division by zero
    head, tail = slice(None, 10), slice(10, None)
    # cell order: top-left, bottom-left, top-right, bottom-right
    quadrants = ((head, head), (tail, head), (head, tail), (tail, tail))

    descriptors = []
    for digit in digits:
        grad_x = cv2.Sobel(digit, cv2.CV_32F, 1, 0)
        grad_y = cv2.Sobel(digit, cv2.CV_32F, 0, 1)
        magnitude, angle = cv2.cartToPolar(grad_x, grad_y)
        # quantise the angle (radians) into n_bins orientation bins
        orientation = np.int32(n_bins*angle/(2*np.pi))

        cell_hists = []
        for quadrant in quadrants:
            cell_bins = orientation[quadrant].ravel()
            cell_weights = magnitude[quadrant].ravel()
            cell_hists.append(np.bincount(cell_bins, cell_weights, n_bins))
        descriptor = np.hstack(cell_hists)

        # transform to Hellinger kernel
        descriptor /= descriptor.sum() + eps
        descriptor = np.sqrt(descriptor)
        descriptor /= norm(descriptor) + eps

        descriptors.append(descriptor)
    return np.float32(descriptors)
147
148
if __name__ == '__main__':
    # Script entry point: load the digits, train and evaluate both
    # classifiers, show the results and save the trained SVM.
    print(__doc__)

    digits, labels = load_digits(DIGITS_FN)

    print('preprocessing...')
    # shuffle digits (fixed seed, so the train/test split is the same on every run)
    rand = np.random.RandomState(321)
    shuffle = rand.permutation(len(digits))
    digits, labels = digits[shuffle], labels[shuffle]

    # materialise the result: np.split below needs a sequence, and map()
    # only returns a lazy iterator under Python 3
    digits2 = list(map(deskew, digits))
    samples = preprocess_hog(digits2)

    # first 90% of the shuffled data is used for training, the rest for testing
    train_n = int(0.9*len(samples))
    cv2.imshow('test set', mosaic(25, digits[train_n:]))
    digits_train, digits_test = np.split(digits2, [train_n])
    samples_train, samples_test = np.split(samples, [train_n])
    labels_train, labels_test = np.split(labels, [train_n])


    print('training KNearest...')
    model = KNearest(k=4)
    model.train(samples_train, labels_train)
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
    cv2.imshow('KNearest test', vis)

    print('training SVM...')
    model = SVM(C=2.67, gamma=5.383)
    model.train(samples_train, labels_train)
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
    cv2.imshow('SVM test', vis)
    print('saving SVM as "digits_svm.dat"...')
    model.save('digits_svm.dat')

    # keep the result windows open until a key is pressed
    cv2.waitKey(0)
185