digit classification with SVM
# Standard scientific Python imports
import matplotlib.pyplot as plt

# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics

# The digits dataset
digits = datasets.load_digits()

# The data that we are interested in is made of 8x8 images of digits. Let's
# have a look at the first 4 images, stored in the `images` attribute of the
# dataset. If we were working from image files, we could load them using
# matplotlib.pyplot.imread. Note that each image must have the same size.
# For these images, we know which digit they represent: it is given in the
# 'target' of the dataset.
images_and_labels = list(zip(digits.images, digits.target))
for index, (image, label) in enumerate(images_and_labels[:4]):
    plt.subplot(2, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Training: %i' % label)

# To apply a classifier on this data, we need to flatten each image to
# turn the data into a (samples, features) matrix:
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

# Create a classifier: a support vector classifier
classifier = svm.SVC(gamma=0.001)

# We learn the digits on the first half of the digits
classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])

# Now predict the value of the digit on the second half:
expected = digits.target[n_samples // 2:]
predicted = classifier.predict(data[n_samples // 2:])

print("Classification report for classifier %s:\n%s\n"
      % (classifier, metrics.classification_report(expected, predicted)))
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))

images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
    plt.subplot(2, 4, index + 5)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prediction: %i' % prediction)

plt.show()
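As a side note, the same half/half split and evaluation can also be expressed with scikit-learn's train_test_split helper and metrics.accuracy_score. The snippet below is a minimal sketch of that variant; it is not part of the original gist and simply assumes the same digits data and the same SVC(gamma=0.001) settings.

# Minimal sketch (not from the original gist): the same experiment using
# sklearn.model_selection.train_test_split instead of manual slicing.
from sklearn import datasets, svm, metrics
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), -1))

# shuffle=False keeps the first half for training and the second half for
# testing, mirroring the manual n_samples // 2 split above.
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.5, shuffle=False)

classifier = svm.SVC(gamma=0.001)
classifier.fit(X_train, y_train)
predicted = classifier.predict(X_test)

print("Accuracy: %.3f" % metrics.accuracy_score(y_test, predicted))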