diff --git a/examples/svm/plot_separating_hyperplane.py b/examples/svm/plot_separating_hyperplane.py index ff6f3fc8f31ad..fafadb2d381d0 100644 --- a/examples/svm/plot_separating_hyperplane.py +++ b/examples/svm/plot_separating_hyperplane.py @@ -12,37 +12,33 @@ import numpy as np import matplotlib.pyplot as plt from sklearn import svm +from sklearn.datasets import make_blobs + # we create 40 separable points -np.random.seed(0) -X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]] -Y = [0] * 20 + [1] * 20 +X, y = make_blobs(n_samples=40, centers=2, random_state=12, cluster_std=0.35) # fit the model clf = svm.SVC(kernel='linear') -clf.fit(X, Y) - -# get the separating hyperplane -w = clf.coef_[0] -a = -w[0] / w[1] -xx = np.linspace(-5, 5) -yy = a * xx - (clf.intercept_[0]) / w[1] - -# plot the parallels to the separating hyperplane that pass through the -# support vectors -b = clf.support_vectors_[0] -yy_down = a * xx + (b[1] - a * b[0]) -b = clf.support_vectors_[-1] -yy_up = a * xx + (b[1] - a * b[0]) - -# plot the line, the points, and the nearest vectors to the plane -plt.plot(xx, yy, 'k-') -plt.plot(xx, yy_down, 'k--') -plt.plot(xx, yy_up, 'k--') - -plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], - s=80, facecolors='none') -plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired) - -plt.axis('tight') -plt.show() +clf.fit(X, y) + +plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired) + +# plot the decision function +ax = plt.gca() +xlim = ax.get_xlim() +ylim = ax.get_ylim() + +# create grid to evaluate model +xx = np.linspace(xlim[0], xlim[1], 30) +yy = np.linspace(ylim[0], ylim[1], 30) +YY, XX = np.meshgrid(yy, xx) +xy = np.vstack([XX.ravel(), YY.ravel()]).T +Z = clf.decision_function(xy).reshape(XX.shape) + +# plot decision boundary and margins +ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, + linestyles=['--', '-', '--']) +# plot support vectors +ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=100, + linewidth=1, facecolors='none') diff --git a/examples/svm/plot_separating_hyperplane_unbalanced.py b/examples/svm/plot_separating_hyperplane_unbalanced.py index 438291dc5538d..cf3130a6ae5c5 100644 --- a/examples/svm/plot_separating_hyperplane_unbalanced.py +++ b/examples/svm/plot_separating_hyperplane_unbalanced.py @@ -29,7 +29,6 @@ import numpy as np import matplotlib.pyplot as plt from sklearn import svm -#from sklearn.linear_model import SGDClassifier # we create 40 separable points rng = np.random.RandomState(0) @@ -43,25 +42,36 @@ clf = svm.SVC(kernel='linear', C=1.0) clf.fit(X, y) -w = clf.coef_[0] -a = -w[0] / w[1] -xx = np.linspace(-5, 5) -yy = a * xx - clf.intercept_[0] / w[1] - - -# get the separating hyperplane using weighted classes +# fit the model and get the separating hyperplane using weighted classes wclf = svm.SVC(kernel='linear', class_weight={1: 10}) wclf.fit(X, y) -ww = wclf.coef_[0] -wa = -ww[0] / ww[1] -wyy = wa * xx - wclf.intercept_[0] / ww[1] - # plot separating hyperplanes and samples -h0 = plt.plot(xx, yy, 'k-', label='no weights') -h1 = plt.plot(xx, wyy, 'k--', label='with weights') plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired, edgecolors='k') plt.legend() -plt.axis('tight') -plt.show() +# plot the decision functions for both classifiers +ax = plt.gca() +xlim = ax.get_xlim() +ylim = ax.get_ylim() + +# create grid to evaluate model +xx = np.linspace(xlim[0], xlim[1], 30) +yy = np.linspace(ylim[0], ylim[1], 30) +YY, XX = np.meshgrid(yy, xx) +xy = np.vstack([XX.ravel(), YY.ravel()]).T + +# get the separating hyperplane +Z = clf.decision_function(xy).reshape(XX.shape) + +# plot decision boundary and margins +a = ax.contour(XX, YY, Z, colors='k', levels=[0], alpha=0.5, linestyles=['-']) + +# get the separating hyperplane for weighted classes +Z = wclf.decision_function(xy).reshape(XX.shape) + +# plot decision boundary and margins for weighted classes +b = ax.contour(XX, YY, Z, colors='r', levels=[0], alpha=0.5, linestyles=['-']) + +plt.legend([a.collections[0], b.collections[0]], ["non weighted", "weighted"], + loc="upper right")