关于机器学习的三个阶段
629
2022-05-30
2.6 scikit-learn简介
scikit-learn是一个开源的Python语言机器学习工具包,它涵盖了几乎所有主流机器学习算法的实现,并且提供了一致的调用接口。它基于Numpy和scipy等Python数值计算库,提供了高效的算法实现。总结起来,scikit-learn工具包有以下几个优点。
* 文档齐全:官方文档齐全,更新及时。
* 接口易用:针对所有的算法提供了一致的接口调用规则,不管是KNN、K-Mean还是PCA。
* 算法全面:涵盖主流机器学习任务的算法,包括回归算法、分类算法、聚类分析、数据降维处理等。
当然,scikit-learn不支持分布式计算,不适合用来处理超大型数据。但这并不影响 scikit-learn作为一个优秀的机器学习工具库这个事实。许多知名的公司,包括Evernote和Spotify都使用scikit-learn来开发他们的机器学习应用。
2.6.1 scikit-learn示例
回顾前面章节介绍的机器学习应用开发的典型步骤,我们使用scikit-learn来完成一个手写数字识别的例子。这是一个有监督的学习,数据是标记过的手写数字的图片。即通过采集足够多的手写数字样本数据,选择合适的模型,并使用采集到的数据进行模型训练,最后验证手写识别程序的正确性。
1.数据采集和标记
如果我们从头实现一个数字手写识别的程序,需要先采集数据,即让尽量多不同书写习惯的用户,写出从0~9的所有数字,然后把用户写出来的数据进行标记,即用户每写出一个数字,就标记他写出的是哪个数字。
为什么要采集尽量多不同书写习惯的用户写的数字呢?因为只有这样,采集到的数据才有代表性,才能保证最终训练出来的模型的准确性。极端的例子,我们采集的都是习惯写出瘦高形数字的人,那么针对习惯写出矮胖形数字的人写出来的数字,模型的识别成功率就会很低。
所幸我们不需要从头开始这项工作,scikit-learn自带了一些数据集,其中一个是手写数字识别图片的数据,使用以下代码来加载数据。
from sklearn import datasets
digits = datasets.load_digits()
可以在ipython notebook环境下把数据所表示的图片用Mathplotlib显示出来:
# 把数据所代表的图片显示出来
images_and_labels = list(zip(digits.images, digits.target))
plt.figure(figsize=(8, 6), dpi=200)
for index, (image, label) in enumerate(images_and_labels[:8]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Digit: %i' % label, fontsize=20)
其结果如图2-19所示。
图2-19 数字图片
从图2-19中可以看出,图片是一个个手写的数字。
2.特征选择
针对一个手写的图片数据,应该怎么样来选择特征呢?一个直观的方法是,直接使用图片的每个像素点作为一个特征。比如一个图片是200 ? 200的分辨率,那么我们就有 40000个特征,即特征向量的长度是40000。
实际上,scikit-learn使用Numpy的array对象来表示数据,所有的图片数据保存在 digits.images里,每个元素都是一个8?8尺寸的灰阶图片。我们在进行机器学习时,需要把数据保存为样本个数?特征个数格式的array对象,针对手写数字识别这个案例,scikit-learn已经为我们转换好了,它就保存在digits.data数据里,可以通过digits.data.shape来查看它的数据格式为:
print("shape of raw image data: {0}".format(digits.images.shape))
print("shape of data: {0}".format(digits.data.shape))
输出为:
shape of raw image data: (1797, 8, 8)
shape of data: (1797, 64)
可以看到,总共有1797个训练样本,其中原始的数据是8?8的图片,而用来训练的数据是把图片的64个象素点都转换为特征。下面将直接使用digits.data作为训练数据。
3.数据清洗
人们不可能在8?8这么小的分辨率的图片上写出数字,在采集数据的时候,是让用户在一个大图片上写出这些数字,如果图片是200 ? 200分辨率,那么一个训练样例就有40000个特征,计算量将是巨大的。为了减少计算量,也为了模型的稳定性,我们需要把200 ? 200的图片缩小为8?8的图片。这个过程就是数据清洗,即把采集到的、不适合用来做机器学习训练的数据进行预处理,从而转换为适合机器学习的数据。
4.模型选择
不同的机器学习算法模型针对特定的机器学习应用有不同的效率,模型的选择和验证留到后面章节详细介绍。此处,我们使用支持向量机来作为手写识别算法的模型。关于支持向量机,后面章节也会详细介绍。
5.模型训练
在开始训练我们的模型之前,需要先把数据集分成训练数据集和测试数据集。为什么要这样做呢?第1章的模型训练和测试里有详细的介绍。我们可以使用下面代码把数据集分出20%作为测试数据集。
# 把数据分成训练数据集和测试数据集
from sklearn.cross_validation import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(digits.data,
digits.target, test_size=0.20, random_state=2);
接着,使用训练数据集Xtrain和Ytrain来训练模型。
# 使用支持向量机来训练模型
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100.)
clf.fit(Xtrain, Ytrain);
训练完成后,clf对象就会包含我们训练出来的模型参数,可以使用这个模型对象来进行预测。
6.模型测试
我们来测试一下训练出来的模型的准确度。一个直观的方法是,我们用训练出来的模型clf预测测试数据集,然后把预测结果Ypred和真正的结果Ytest比较,看有多少个是正确的,这样就能评估出模型的准确度了。所幸,scikit-learn提供了现成的方法来完成这项工作:
clf.score(Xtest, Ytest)
笔者计算机上的输出结果为:
0.9*********5
显示出模型有97.8%的准确率。读者如果运行这段代码的话,在准确率上可能会稍有差异。
除此之外,还可以直接把测试数据集里的部分图片显示出来,并且在图片的左下角显示预测值,右下角显示真实值。运行效果如图2-20所示。
# 查看预测的情况
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r,
interpolation='nearest')
ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32,
transform=ax.transAxes,
color='green' if Ypred[i] == Ytest[i] else 'red')
ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32,
transform=ax.transAxes,
color='black')
ax.set_xticks([])
ax.set_yticks([])
图2-20 预测值与真实值
从图2-20中可以看出来,第二行第一个图片预测出错了,真实的数字是4,但预测成了8。
7.模型保存与加载
当我们对模型的准确度感到满意后,就可以把模型保存下来。这样下次需要预测时,可以直接加载模型来进行预测,而不是重新训练一遍模型。可以使用下面的代码来保存模型:
# 保存模型参数
from sklearn.externals import joblib
joblib.dump(clf, 'digits_svm.pkl');
当我们需要这个模型来进行预测时,直接加载模型即可进行预测。
# 导入模型参数,直接进行预测
clf = joblib.load('digits_svm.pkl')
Ypred = clf.predict(Xtest);
clf.score(Ytest, Ypred)
笔者计算机上的输出结果是:
0.9*********5
这个例子包含在随书代码ch02.06.ipynb上,读者可以下载下来运行并参考。
机器学习 scikit-learn
版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。