scikit-learn机器学习常用算法原理及编程实战》—2.6 scikit-learn简介

网友投稿 644 2022-05-30

2.6  scikit-learn简介

scikit-learn是一个开源的Python语言机器学习工具包,它涵盖了几乎所有主流机器学习算法的实现,并且提供了一致的调用接口。它基于Numpy和scipy等Python数值计算库,提供了高效的算法实现。总结起来,scikit-learn工具包有以下几个优点。

* 文档齐全:官方文档齐全,更新及时。

* 接口易用:针对所有的算法提供了一致的接口调用规则,不管是KNN、K-Mean还是PCA。

* 算法全面:涵盖主流机器学习任务的算法,包括回归算法、分类算法、聚类分析、数据降维处理等。

当然,scikit-learn不支持分布式计算,不适合用来处理超大型数据。但这并不影响 scikit-learn作为一个优秀的机器学习工具库这个事实。许多知名的公司,包括Evernote和Spotify都使用scikit-learn来开发他们的机器学习应用。

2.6.1  scikit-learn示例

回顾前面章节介绍的机器学习应用开发的典型步骤,我们使用scikit-learn来完成一个手写数字识别的例子。这是一个有监督的学习,数据是标记过的手写数字的图片。即通过采集足够多的手写数字样本数据,选择合适的模型,并使用采集到的数据进行模型训练,最后验证手写识别程序的正确性。

1.数据采集和标记

如果我们从头实现一个数字手写识别的程序,需要先采集数据,即让尽量多不同书写习惯的用户,写出从0~9的所有数字,然后把用户写出来的数据进行标记,即用户每写出一个数字,就标记他写出的是哪个数字。

为什么要采集尽量多不同书写习惯的用户写的数字呢?因为只有这样,采集到的数据才有代表性,才能保证最终训练出来的模型的准确性。极端的例子,我们采集的都是习惯写出瘦高形数字的人,那么针对习惯写出矮胖形数字的人写出来的数字,模型的识别成功率就会很低。

所幸我们不需要从头开始这项工作,scikit-learn自带了一些数据集,其中一个是手写数字识别图片的数据,使用以下代码来加载数据。

from sklearn import datasets

digits = datasets.load_digits()

可以在ipython notebook环境下把数据所表示的图片用Mathplotlib显示出来:

# 把数据所代表的图片显示出来

images_and_labels = list(zip(digits.images, digits.target))

plt.figure(figsize=(8, 6), dpi=200)

for index, (image, label) in enumerate(images_and_labels[:8]):

plt.subplot(2, 4, index + 1)

plt.axis('off')

plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')

plt.title('Digit: %i' % label, fontsize=20)

其结果如图2-19所示。

图2-19  数字图片

从图2-19中可以看出,图片是一个个手写的数字。

2.特征选择

针对一个手写的图片数据,应该怎么样来选择特征呢?一个直观的方法是,直接使用图片的每个像素点作为一个特征。比如一个图片是200 ? 200的分辨率,那么我们就有 40000个特征,即特征向量的长度是40000。

实际上,scikit-learn使用Numpy的array对象来表示数据,所有的图片数据保存在 digits.images里,每个元素都是一个8?8尺寸的灰阶图片。我们在进行机器学习时,需要把数据保存为样本个数?特征个数格式的array对象,针对手写数字识别这个案例,scikit-learn已经为我们转换好了,它就保存在digits.data数据里,可以通过digits.data.shape来查看它的数据格式为:

print("shape of raw image data: {0}".format(digits.images.shape))

print("shape of data: {0}".format(digits.data.shape))

输出为:

shape of raw image data: (1797, 8, 8)

shape of data: (1797, 64)

可以看到,总共有1797个训练样本,其中原始的数据是8?8的图片,而用来训练的数据是把图片的64个象素点都转换为特征。下面将直接使用digits.data作为训练数据。

《scikit-learn机器学习常用算法原理及编程实战》—2.6 scikit-learn简介

3.数据清洗

人们不可能在8?8这么小的分辨率的图片上写出数字,在采集数据的时候,是让用户在一个大图片上写出这些数字,如果图片是200 ? 200分辨率,那么一个训练样例就有40000个特征,计算量将是巨大的。为了减少计算量,也为了模型的稳定性,我们需要把200 ? 200的图片缩小为8?8的图片。这个过程就是数据清洗,即把采集到的、不适合用来做机器学习训练的数据进行预处理,从而转换为适合机器学习的数据。

4.模型选择

不同的机器学习算法模型针对特定的机器学习应用有不同的效率,模型的选择和验证留到后面章节详细介绍。此处,我们使用支持向量机来作为手写识别算法的模型。关于支持向量机,后面章节也会详细介绍。

5.模型训练

在开始训练我们的模型之前,需要先把数据集分成训练数据集和测试数据集。为什么要这样做呢?第1章的模型训练和测试里有详细的介绍。我们可以使用下面代码把数据集分出20%作为测试数据集。

# 把数据分成训练数据集和测试数据集

from sklearn.cross_validation import train_test_split

Xtrain, Xtest, Ytrain, Ytest = train_test_split(digits.data,

digits.target, test_size=0.20, random_state=2);

接着,使用训练数据集Xtrain和Ytrain来训练模型。

# 使用支持向量机来训练模型

from sklearn import svm

clf = svm.SVC(gamma=0.001, C=100.)

clf.fit(Xtrain, Ytrain);

训练完成后,clf对象就会包含我们训练出来的模型参数,可以使用这个模型对象来进行预测。

6.模型测试

我们来测试一下训练出来的模型的准确度。一个直观的方法是,我们用训练出来的模型clf预测测试数据集,然后把预测结果Ypred和真正的结果Ytest比较,看有多少个是正确的,这样就能评估出模型的准确度了。所幸,scikit-learn提供了现成的方法来完成这项工作:

clf.score(Xtest, Ytest)

笔者计算机上的输出结果为:

0.9*********5

显示出模型有97.8%的准确率。读者如果运行这段代码的话,在准确率上可能会稍有差异。

除此之外,还可以直接把测试数据集里的部分图片显示出来,并且在图片的左下角显示预测值,右下角显示真实值。运行效果如图2-20所示。

# 查看预测的情况

fig, axes = plt.subplots(4, 4, figsize=(8, 8))

fig.subplots_adjust(hspace=0.1, wspace=0.1)

for i, ax in enumerate(axes.flat):

ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r,

interpolation='nearest')

ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32,

transform=ax.transAxes,

color='green' if Ypred[i] == Ytest[i] else 'red')

ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32,

transform=ax.transAxes,

color='black')

ax.set_xticks([])

ax.set_yticks([])

图2-20  预测值与真实值

从图2-20中可以看出来,第二行第一个图片预测出错了,真实的数字是4,但预测成了8。

7.模型保存与加载

当我们对模型的准确度感到满意后,就可以把模型保存下来。这样下次需要预测时,可以直接加载模型来进行预测,而不是重新训练一遍模型。可以使用下面的代码来保存模型:

# 保存模型参数

from sklearn.externals import joblib

joblib.dump(clf, 'digits_svm.pkl');

当我们需要这个模型来进行预测时,直接加载模型即可进行预测。

# 导入模型参数,直接进行预测

clf = joblib.load('digits_svm.pkl')

Ypred = clf.predict(Xtest);

clf.score(Ytest, Ypred)

笔者计算机上的输出结果是:

0.9*********5

这个例子包含在随书代码ch02.06.ipynb上,读者可以下载下来运行并参考。

机器学习 scikit-learn

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:机器人编程趣味实践20-版本课程(教学)
下一篇:CDM之ANTLR学习
相关文章