纯干货 | 机器学习中梯度下降法的分类及对比分析(附源码)

更多深度文章,请关注:https://yq.aliyun.com/cloud

HackerEarth,一家来自印度的创业公司,旨在帮助开发者通过线上编程竞赛获得工作机会。和Github类似,它提供一个多种编程语言的代码交流平台。而HackerEarth blog 上多刊登一些跟大数据、人工智能、机器学习、算法及编程竞赛相关的博文。

引言

      梯度下降法 (Gradient Descent Algorithm,GD) 是为目标函数J(θ),如代价函数(cost function), 求解全局最小值(Global Minimum)的一种迭代算法。本文会详细讨论按照准确性和耗费时间(accuracy and time consuming factor)将梯度下降法进行分类。这个算法在机器学习中被广泛用来最小化目标函数,如下图所示。

为什么使用梯度下降法

      我们使用梯度下降法最小化目标函数J(θ)。在使用梯度下降法时,首先初始化参数值,然后一直改变这些值,直到得到全局最小值。其中,我们计算在每次迭代时计算代价函数的导数,然后使用如下公式同时更新参数值:

α表示学习速率(learning
rate
)。

在本文中,考虑使用线性回归linear
regression
)作为算法实例,当然梯度下降法也可以应用到其他算法,如逻辑斯蒂回归(Logistic
regression
)和
神经网络(Neural
networks
)。在线性回归中,我们使用如下拟合函数(hypothesis
function
):

其中, 是参数, 是输入特征。为了求解线性回归模型,需要找到合适的参数使拟合函数能够更好地适合模型,然后使用梯度下降最小化代价函数J(θ)

代价函数(普通的最小平方差,ordinary
least square error
)如下所示:

代价函数的梯度(Gradient of Cost function):

参数与代价函数关系如下图所示:

梯度下降法的工作原理

下面的伪代码能够解释其详细原理:

1. 初始化参数值

2. 迭代更新这些参数使目标函数J(θ)不断变小。

梯度下降法的类型

基于如何使用数据计算代价函数的导数,梯度下降法可以被定义为不同的形式(various
variants
)。确切地说,根据使用数据量的大小the amount of data),时间复杂度time complexity)和算法的准确率accuracy of the algorithm),梯度下降法可分为:

1.      
批量梯度下降法Batch
Gradient Descent, BGD
);

2.      
随机梯度下降法Stochastic Gradient Descent, SGD);

3.      
小批量梯度下降法Mini-Batch Gradient Descent, MBGD)。

批量梯度下降法原理

      这是梯度下降法的基本类型,这种方法使用整个数据集(the complete dataset)去计算代价函数的梯度。每次使用全部数据计算梯度去更新参数,批量梯度下降法会很慢,并且很难处理不能载入内存(don’t
fit in memory
)的数据集。在随机初始化参数后,按如下方式计算代价函数的梯度:

其中,m是训练样本(training
examples
)的数量。

Note:

     1. 如果训练集有3亿条数据,你需要从硬盘读取全部数据到内存中;

     2. 每次一次计算完求和后,就进行参数更新;

     3.  然后重复上面每一步;

     4. 这意味着需要较长的时间才能收敛

     5. 特别是因为磁盘输入/输出(disk
I/O
)是系统典型瓶颈,所以这种方法会不可避免地需要大量的读取。

上图是每次迭代后的等高线图,每个不同颜色的线表示代价函数不同的值。运用梯度下降会快速收敛到圆心,即唯一的一个全局最小值。

批量梯度下降法不适合大数据集。下面的Python代码实现了批量梯度下降法:

1.	import numpy as np
2.	import random
3.	def gradient_descent(alpha, x, y, ep=0.0001, max_iter=10000):
4.	    converged = False
5.	    iter = 0
6.	    m = x.shape[0] # number of samples
7.
8.	    # initial theta
9.	    t0 = np.random.random(x.shape[1])
10.	    t1 = np.random.random(x.shape[1])
11.
12.	    # total error, J(theta)
13.	    J = sum([(t0 + t1*x[i] - y[i])**2 for i in range(m)])
14.
15.	    # Iterate Loop
16.	    while not converged:
17.	        # for each training sample, compute the gradient (d/d_theta j(theta))
18.	        grad0 = 1.0/m * sum([(t0 + t1*x[i] - y[i]) for i in range(m)])
19.	        grad1 = 1.0/m * sum([(t0 + t1*x[i] - y[i])*x[i] for i in range(m)])
20.	        # update the theta_temp
21.	        temp0 = t0 - alpha * grad0
22.	        temp1 = t1 - alpha * grad1
23.
24.	        # update theta
25.	        t0 = temp0
26.	        t1 = temp1
27.
28.	        # mean squared error
29.	        e = sum( [ (t0 + t1*x[i] - y[i])**2 for i in range(m)] )
30.
31.	        if abs(J-e) <= ep:
32.	            print 'Converged, iterations: ', iter, '!!!'
33.	            converged = True
34.
35.	        J = e   # update error
36.	        iter += 1  # update iter
37.
38.	        if iter == max_iter:
39.	            print 'Max interactions exceeded!'
40.	            converged = True
41.
42.	    return t0,t1 

随机梯度下降法原理

   批量梯度下降法被证明是一个较慢的算法,所以,我们可以选择随机梯度下降法达到更快的计算。随机梯度下降法的第一步是随机化整个数据集。在每次迭代仅选择一个训练样本去计算代价函数的梯度,然后更新参数。即使是大规模数据集,随机梯度下降法也会很快收敛。随机梯度下降法得到结果的准确性可能不会是最好的,但是计算结果的速度很快。在随机化初始参数之后,使用如下方法计算代价函数的梯度:

这里m表示训练样本的数量。

如下为随机梯度下降法的伪码:

       1. 进入内循环(inner loop);

       2. 第一步:挑选第一个训练样本并更新参数,然后使用第二个实例;

       3. 第二步:选第二个训练样本,继续更新参数;

       4. 然后进行第三步…直到第n步;

       5. 直到达到全局最小值

如下图所示,随机梯度下降法不像批量梯度下降法那样收敛,而是游走到接近全局最小值的区域终止

小批量梯度下降法原理

 小批量梯度下降法是最广泛使用的一种算法,该算法每次使用m个训练样本(称之为一批)进行训练,能够更快得出准确的答案。小批量梯度下降法不是使用完整数据集,在每次迭代中仅使用m个训练样本去计算代价函数的梯度。一般小批量梯度下降法所选取的样本数量在50到256个之间,视具体应用而定。

1.这种方法减少了参数更新时的变化,能够更加稳定地收敛。

2.同时,也能利用高度优化的矩阵,进行高效的梯度计算。

随机初始化参数后,按如下伪码计算代价函数的梯度:

这里b表示一批训练样本的个数,m是训练样本的总数。

Notes:

1. 实现该算法时,同时更新参数

2. 学习速率α(也称之为步长)如果α过大,算法可能不会收敛;如果α比较小,就会很容易收敛。

3. 检查梯度下降法的工作过程。画出迭代次数与每次迭代后代价函数值的关系图,这能够帮助你了解梯度下降法是否取得了好的效果。每次迭代后J(θ)应该降低,多次迭代后应该趋于收敛。

4. 不同的学习速率在梯度下降法中的效果

总结

本文详细介绍了不同类型的梯度下降法。这些算法已经被广泛应用于神经网络。下面的图详细展示了3种梯度下降法的比较。

以上为译文

本文由北邮@爱可可-爱生活 老师推荐,阿里云组织翻译。

文章原标题《3 Types of Gradient Descent
Algorithms for Small & Large Data Sets》,由HackerEarth blog发布。

译者:李烽 ;审校:段志成-海棠

文章为简译,更为详细的内容,请查看原文。中文译制文档下载见此。

时间: 2024-10-29 00:31:52

纯干货 | 机器学习中梯度下降法的分类及对比分析(附源码)的相关文章

ASP.NET中ListView(列表视图)的使用前台绑定附源码_实用技巧

1.A,运行效果图   1.B,源代码 复制代码 代码如下: <%@ Page Language="C#" AutoEventWireup="true" CodeFile="DropLvw.aspx.cs" Inherits="DropLvw" %> <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "h

(转)干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码)

干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码)   该博客来源自:https://mp.weixin.qq.com/s?__biz=MzA4NzE1NzYyMw==&mid=2247492203&idx=5&sn=3020c3a43bd4dd678782d8aa24996745&chksm=903f1c73a74895652ee688d070fd807771e3fe6a8947f77f3a15a44a65557da0313ac5ad59

winform中继承base实现屏蔽系统热键,求源码

问题描述 winform中继承base实现屏蔽系统热键,求源码 近期想做一个锁屏,采用键盘钩子在win7上一直不完美,听说继承base类可以实现,但是不知道具体怎么实现屏蔽系统热键,比如alt+f4,任务管理器等等啊,求源码或详细思路,谢谢 解决方案 继承base实现屏蔽系统热键 没这么神奇,只能吃掉本窗体的键盘消息的响应. 解决方案二: 任务管理器可以通过组策略禁用 阻止alt+f4只要在Closing事件中写e.cancel=true即可.

从零开始编写自己的C#框架(12)——T4模板在逻辑层中的应用(一)(附源码)

原文:从零开始编写自己的C#框架(12)--T4模板在逻辑层中的应用(一)(附源码) 对于T4模板很多朋友都不太熟悉,它在项目开发中,会帮我们减轻很大的工作量,提升我们的开发效率,减少出错概率.所以学好T4模板的应用,对于开发人员来说是非常重要的. 园子里对于T4模板的介绍与资料已经太多了,所以在这里我就不再详细讲述基础知识了,只是说说T4模板在本框架中的具体应用与实践.   一.创建逻辑层项目   二.添加引用 将之前添加的三个项目添加到引用   三.创建T4模板放置的文件夹,并命名为SubS

Javascript中的几种继承方式对比分析_基础知识

开篇从'严格'意义上说,javascript并不是一门真正的面向对象语言.这种说法原因一般都是觉得javascript作为一门弱类型语言与类似java或c#之类的强型语言的继承方式有很大的区别,因而默认它就是非主流的面向对象方式,甚至竟有很多书将其描述为'非完全面向对象'语言.其实个人觉得,什么方式并不重要,重要的是是否具有面向对象的思想,说javascript不是面向对象语言的,往往都可能没有深入研究过javascript的继承方式,故特撰此文以供交流. 为何需要利用javascript实现继

Win7系统中解除VS2008过期限制程序,附源码

Win7系统中解除VS2008过期限制程序,附源码 下载地址: http://files.cnblogs.com/waw/Win7_VS2008_Cracker.rar

一文详解如何用 TensorFlow 实现基于 LSTM 的文本分类(附源码)

 引言 学习一段时间的tensor flow之后,想找个项目试试手,然后想起了之前在看Theano教程中的一个文本分类的实例,这个星期就用tensorflow实现了一下,感觉和之前使用的theano还是有很大的区别,有必要总结mark一下.   模型说明 这个分类的模型其实也是很简单,主要就是一个单层的LSTM模型,当然也可以实现多层的模型,多层的模型使用Tensorflow尤其简单,下面是这个模型的图  简单解释一下这个图,每个word经过embedding之后,进入LSTM层,这里LSTM是

php中随机函数mt_rand()与rand()性能对比分析_php技巧

本文实例对比分析了php中随机函数mt_rand()与rand()性能问题.分享给大家供大家参考.具体分析如下: 在php中mt_rand()和rand()函数都是可以随机生成一个纯数字的,他们都是需要我们设置好种子数据然后生成,那么mt_rand()和rand()那个性能会好一些呢,下面我们带着疑问来测试一下. 例子1. mt_rand() 范例,代码如下: 复制代码 代码如下: <?php echo mt_rand() . "n"; echo mt_rand() . &quo

一起谈.NET技术,Silverlight 4中把DataGrid数据导出Excel—附源码下载

Silverlight中常常用到DataGrid来展示密集数据. 而常见应用系统中我们需要把这些数据导入导出到固定Office套件中例如常用的Excel表格. 那么在Silverlight 中如何加以实现? 在参考大量资料后 提供参考思路如下: A:纯客户端导出处理.利用Silverlight 与Javascript 进行交互实现导出Excel. B:服务器端导出.获得DataGrid数据源. 传递给WCF Service到服务器端. 然后把传回数据通过Asp.net中通用处理导出Excel方法