《Python大规模机器学习》 —2.3.3Scikit-learn的SGD实现

网友投稿 781 2022-05-30

2.3.3Scikit-learn的SGD实现

Scikit-learn软件包含许多在线学习算法。并不是所有机器学习算法都有在线学习算法,但是在线算法种类一直在稳步增长。在监督学习方面,我们将可用学习器分成分类器和回归器,并列举它们。

对于分类器有以下几点说明:

sklearn.naive_bayes.MultinomialNB

sklearn.naive_bayes.BernoulliNB

sklearn.linear_model.Perceptron

sklearn.linear_model.PassiveAggressiveClassifier

sklearn.linear_model.SGDClassifier

回归器有两种选择:

sklearn.linear_model.PassiveAggressiveRegressor

sklearn.linear_model.SGDRegressor

它们都可以增量学习,逐实例更新自己,但只有SGDClassifier和SGDRegressor是基于我们之前描述的随机梯度下降优化算法,本章重点介绍它们。对于所有大型问题,SGD学习器都为最优,因为其复杂度为O(k*n*p),其中k为数据遍历次数,n为实例数量,p为特征数(如果使用稀疏矩阵为非零特征):一个完全线性时间学习器,学习时间与所显示的实例数量成正比。

其他在线算法将被用作比较基准。此外,基于在线学习的partial_fit 算法和mini-batch(传输更大块非单个实例)的所有算法都使用相同API。共享相同API便于这些学习技术在你的学习框架中任意互换。

拟合方法能使用所有可用数据进行即时优化,与之相比,partial_fit基于传递的每个实例进行局部优化。即使数据集全传递给partial_fit,它也不会处理整批数据,而是处理每个元素,以保证学习操作的复杂度呈线性。此外,在partial_fit后,学习器可通过后续调用partial_fit来不断更新,这样的方法非常适合于从连续数据流进行在线学习。

分类时,唯一要注意的是,第一次初始化时需要知道要学习的种类数及其标记方法。可以使用类参数来完成,并指出标签数值的列表。这就需要事先进行探索,疏理数据流以记录问题的标签,并在不平衡情况下关注其分布——相对其他类,该类在数值上会太大或太小(但是Scikit-learn提供了一种自动处理问题的方法)。如果目标变量为数值变量,则了解其分布仍然很有用,但是这对于成功运行学习算法来说不是必需的。

Scikit-learn中有两个实现——一个用于分类问题(SGDClassfier),一个用于回归问题(SGDRegressor)。分类实现使用一对多(OVA)策略处理多类问题。这个策略就是,给定k个类,就建立k个模型,每个类针对其他类的所有实例都会建立一个模型,因此总共创建k个二进制分类。这就会产生k组系数和k个向量的预测及其概率。最后,与其他类比较每类的发生概率,将分类结果分配给概率最高的类。如果要求给出多项式分布的实际概率,只要简单地与其相除就能对结果归一化。(神经网络中的softmax层就会这么处理,下一章会看到详细介绍。)

Scikit-learn中实现分类和SGD回归时都会有不同的损失函数(成本函数,随机梯度下降优化法的核心)。

可以按照以下内容用损失参数表示分类:

loss='log':经典逻辑回归

loss='hinge':软边界,即线性支持向量机

《Python大规模机器学习》 —2.3.3Scikit-learn的SGD实现

loss='modified_huber':平滑hinge loss

回归有三个损失函数:

loss='squared_loss':普通最小二乘法线性回归(OLS)

loss='huber': 抗噪强的鲁棒回归抗Huberloss

loss='epsilon_insensitive':线性支持向量回归

我们将给出一些使用经典统计损失函数的实例,例如对数损失和OLS。下一章讨论hinge loss和支持向量机(SVMS),并详细介绍其功能。

作为提醒(这样读者就不必再查阅其他机器学习辅助书籍),如果回归函数定义为h,其预测由h(x)给出,因为X为特征矩阵,那么其公式如下:

因此,最小化的OLS成本函数如下:

在逻辑回归中,将二进制结果0/1变换为优势比,пy为正结果的概率,公式如下:

因此,对数损失函数定义如下:

Tensorflow python 机器学习

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:Python 协程之 Gevent 模块
下一篇:《数字化转型之路》 —1.1.4 做强实体经济
相关文章