How to implement .dat file for handwritten recognition using SVM in Python

Submitted by 萝らか妹 on 2021-01-28 06:35:27

Question


I've been trying to train an SVM on handwritten digits, based on the sample code from the OpenCV library. My training code is as follows:

import cv2
import numpy as np

SZ=20
bin_n = 16
svm_params = dict( kernel_type = cv2.SVM_LINEAR,
                   svm_type = cv2.SVM_C_SVC,
                   C=2.67, gamma=5.383 )
affine_flags = cv2.WARP_INVERSE_MAP|cv2.INTER_LINEAR

def deskew(img):
    m = cv2.moments(img)
    if abs(m['mu02']) < 1e-2:
        return img.copy()
    skew = m['mu11']/m['mu02']
    M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
    img = cv2.warpAffine(img,M,(SZ, SZ),flags=affine_flags)
    return img
def hog(img):
    gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
    gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
    mag, ang = cv2.cartToPolar(gx, gy)
    bins = np.int32(bin_n*ang/(2*np.pi))    # quantize angles into bin_n = 16 bins (0...15)
    bin_cells = bins[:10,:10], bins[10:,:10], bins[:10,10:], bins[10:,10:]
    mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:]
    hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
    hist = np.hstack(hists)     # 4 cells x 16 bins = 64-element vector
    return hist

img = cv2.imread('digits.png',0)
if img is None:
    raise Exception("we need the digits.png image from samples/data here !")


cells = [np.hsplit(row,100) for row in np.vsplit(img,50)]

train_cells = [ i[:50] for i in cells ]
test_cells = [ i[50:] for i in cells]

deskewed = [[deskew(cell) for cell in row] for row in train_cells]
hogdata = [[hog(cell) for cell in row] for row in deskewed]
trainData = np.float32(hogdata).reshape(-1,64)
responses = np.float32(np.repeat(np.arange(10),250)[:,np.newaxis])

svm = cv2.SVM()    # OpenCV 2.4 API; OpenCV 3+ uses cv2.ml.SVM_create() instead
svm.train(trainData,responses, params=svm_params)
svm.save('svm_data.dat')
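The `responses` line only matches `trainData` row for row because of how digits.png is laid out. The sketch below assumes the layout of OpenCV's samples/data/digits.png (50 rows x 100 columns of 20x20 cells, 5 cell rows per digit) and uses a synthetic stand-in image in which every cell is painted with its own digit value. That way the split can be checked against the labels with NumPy alone:

```python
import numpy as np

SZ = 20  # cell size in pixels, as in the training code

# Stand-in for digits.png (assumed layout: 50 x 100 cells of 20x20 pixels,
# 5 cell rows per digit). Each cell is filled with its digit value.
img = np.zeros((50 * SZ, 100 * SZ), dtype=np.uint8)
for digit in range(10):
    img[digit * 5 * SZ:(digit + 1) * 5 * SZ, :] = digit

# Same split as the training code: 50 horizontal strips, 100 cells each.
cells = [np.hsplit(row, 100) for row in np.vsplit(img, 50)]
train_cells = [row[:50] for row in cells]  # left half of every strip
test_cells = [row[50:] for row in cells]   # right half of every strip

# Flatten in the same row-major order that reshape(-1, 64) produces.
cell_digits = np.array([[cell[0, 0] for cell in row] for row in train_cells]).reshape(-1)

# 50 strips x 50 cells = 2500 training samples; each digit spans 5 strips,
# so it contributes 5 * 50 = 250 consecutive samples.
responses = np.repeat(np.arange(10), 250)[:, np.newaxis]

print((cell_digits == responses.ravel()).all())  # True
```

If the image were split differently (for example, by columns instead of the left/right halves), the 250-per-digit repeat would no longer line up with the feature rows.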

Here's digits.png: [image]

As a result, I got the svm_data.dat file. But now I don't know how to implement the model. Let's say I want to read this number: [image]

Can anyone help me out please?


Answer 1:


I am going to assume that by "how to implement the model" you mean "how to predict the label for a new image".

First off, note that this has nothing to do with the saved svm_data.dat per se, unless you want to make predictions in a different script or session, in which case you can reload your trained svm object from that file.

With that out of the way, making a prediction for new data requires three steps:

  1. If your new data is somehow different from the training data, preprocess it so it matches the training data (see inverting and resizing below).

  2. Extract features the same way as was done for the training data.

  3. Use the trained classifier to predict the label.
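Step 1 can be as small as flipping intensities. The training digits are light strokes on a dark background, so a dark digit on white paper has to be inverted. `np.invert` is a bitwise NOT, which for the uint8 arrays returned by `cv2.imread(path, 0)` is the same as `255 - x` (`cv2.bitwise_not` does the same for 8-bit images). A small sketch with made-up pixel values showing why the dtype matters:

```python
import numpy as np

# cv2.imread(path, 0) returns an 8-bit (uint8) grayscale array.
dark_on_light = np.array([[255, 250], [10, 0]], dtype=np.uint8)

# Bitwise NOT on uint8 equals 255 - x: dark strokes on white become
# light strokes on black, matching the training data.
light_on_dark = np.invert(dark_on_light)
print(light_on_dark.tolist())  # [[0, 5], [245, 255]]

# On a signed integer array the same call gives -x - 1 instead,
# so convert to uint8 before inverting if the dtype has changed.
signed = np.invert(np.array([0, 10], dtype=np.int32))
print(signed.tolist())  # [-1, -11]
```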

For the example image you uploaded, this can be done as follows:

# Load the image
img_predict = cv2.imread('predict.png', 0)

# Preprocessing: this image is inverted compared to the training data
# Here it is inverted back
img_predict = np.invert(img_predict)

# Preprocessing: it also has a completely different size
# This resizes it to the same size as the training data
img_predict = cv2.resize(img_predict, (20, 20), interpolation=cv2.INTER_CUBIC)

# Extract the features
img_predict_ready = np.float32(hog(deskew(img_predict)))

# Reload the trained svm
# Not necessary if predicting is done in the same session as training
svm = cv2.SVM()    # OpenCV 2.4 API; OpenCV 3+ uses cv2.ml.SVM_load(path) instead
svm.load("svm_data.dat")

# Run the prediction
prediction = svm.predict(img_predict_ready)
print(int(prediction))

The output is 0, as expected.

Note that it is very important to match the data you want to classify to the data you used for training. In this case, skipping the resizing step leads to the image being misclassified as 2.
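Skipping the resize does not raise an error, which makes the mismatch easy to miss: `hog()` slices fixed 10-pixel quadrants, so on a large image it still returns 64 values, but the quadrants no longer split the digit into four equal parts. The sketch below is a hypothetical NumPy-only analogue of `hog()` (with `np.gradient` swapped in for `cv2.Sobel`, so the feature values differ from the original) that also reports how many pixels each quadrant covers:

```python
import numpy as np

def quadrant_hog(img, bin_n=16):
    """Hypothetical NumPy-only analogue of hog(): np.gradient replaces
    cv2.Sobel, but the hard-coded 10-pixel quadrant split is kept.
    Returns the 64-value feature and the pixel count of each quadrant."""
    gy, gx = np.gradient(img.astype(np.float32))  # axis 0 = y, axis 1 = x
    mag = np.hypot(gx, gy)
    ang = np.mod(np.arctan2(gy, gx), 2 * np.pi)    # angles in [0, 2*pi]
    # Quantize to 0..bin_n-1; clamp in case rounding yields exactly 2*pi.
    bins = np.minimum((bin_n * ang / (2 * np.pi)).astype(np.int32), bin_n - 1)
    top, bottom = slice(None, 10), slice(10, None)
    left, right = slice(None, 10), slice(10, None)
    quads = [(top, left), (bottom, left), (top, right), (bottom, right)]
    hists = [np.bincount(bins[q].ravel(), weights=mag[q].ravel(), minlength=bin_n)
             for q in quads]
    return np.hstack(hists), [bins[q].size for q in quads]

# The same vertical stroke drawn at 20x20 and at 200x200.
digit_20 = np.zeros((20, 20)); digit_20[5:15, 9:11] = 255
digit_200 = np.zeros((200, 200)); digit_200[50:150, 90:110] = 255
feat_20, sizes_20 = quadrant_hog(digit_20)
feat_200, sizes_200 = quadrant_hog(digit_200)
print(feat_20.shape, sizes_20)    # (64,) [100, 100, 100, 100]
print(feat_200.shape, sizes_200)  # (64,) [100, 1900, 1900, 36100]
```

Both calls produce a 64-value vector, but on the large image the top-left "quadrant" is just a 10x10 corner. A check such as `assert img.shape == (SZ, SZ)` before calling `hog()` would catch this early.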

Furthermore, a closer look at the image reveals that it still differs somewhat from the training data (more background, a different mean intensity), so I would not be surprised if the classifier performed a bit worse on images like this than on the test data from the original tutorial, which was simply the other half of digits.png. How much worse depends on how sensitive the feature extraction is to the differences between the training images and your prediction images.
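One possible way to narrow the "more background" gap is to crop the digit to its bounding box and pad it to a square before resizing to 20x20. This is an assumption, not something the answer tested, and whether it helps depends on how much margin the training digits have. The sketch assumes a light digit on a dark background (i.e. after the `np.invert` step); `crop_to_ink()`, the threshold, and the margin are hypothetical choices:

```python
import numpy as np

def crop_to_ink(img, thresh=128, margin=2):
    """Hypothetical helper: crop a light-on-dark digit to the bounding box
    of pixels brighter than `thresh`, then zero-pad it to a square with a
    `margin`-pixel border so a later resize keeps the aspect ratio."""
    ys, xs = np.nonzero(img > thresh)
    if ys.size == 0:
        return img.copy()  # nothing above threshold: leave unchanged
    crop = img[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    h, w = crop.shape
    side = max(h, w) + 2 * margin
    out = np.zeros((side, side), dtype=img.dtype)
    top, left = (side - h) // 2, (side - w) // 2
    out[top:top + h, left:left + w] = crop
    return out

# A 20-pixel-tall, 10-pixel-wide stroke in a large, mostly empty image.
canvas = np.zeros((100, 80), dtype=np.uint8)
canvas[30:50, 40:50] = 255
squared = crop_to_ink(canvas)
print(squared.shape)  # (24, 24)
```

The result would then go through `cv2.resize(..., (20, 20))` and the same `deskew()`/`hog()` calls as in the answer.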



Source: https://stackoverflow.com/questions/45807351/how-to-implement-dat-file-for-handwritten-recognition-using-svm-in-python
