3.2 梯度下降法
现在我们来寻找更好的train()算法。train()函数的任务是找到具有最小损失的参数。让我们关注loss()函数自身:
考察这个函数的参数。X和Y分别代表输入变量和输入标签,因此它们在loss()函数调用的过程中不会发生变化。为了简化下面的讨论,我们在这里临时将b设置为0。因此,现在唯一的变量就是参数w。
那么,损失是如何随着w的变化而变化的呢?我编写了一个程序,为w绘制了从-1到4的loss()函数值,并在最小值上绘制了一个十字(和往常一样,这段代码也是本书的源代码)。
漂亮的曲线!我们通常将其称为损失曲线。train()的基本思想就是要找到曲线底部的标记,即w的值,由此得到最小的损失。模型在参数w处实现了对数据点的最佳逼近。
现在想象一下,上述损失曲线就像是一个山谷,有一名徒步旅行者站在山谷的某个地方。这个旅行者想要到达她的营地,就是标记的地方。问题是现在山谷很黑,她只能看到她脚下的地形。为了能够找到营地,她可以遵循一个非常简单的方法,即沿着最陡峭的斜坡往下走。如果地形中没有山洞或悬崖,此时我们的损失函数也没有山洞或悬崖,那么该旅行者每走一步都会更接近营地。
为了将上述方法变成一个具体的算法,我们需要度量出损失曲线的斜率。在数学领域,通常将这个斜率称为曲线的梯度。一般来说,曲线上某一点的梯度就是从该点出发的切线的斜率,如下图所示:
为了实现对梯度的度量,我们使用一种名为“损失函数对权重参数的导数”的数学工具,并将其记为∂L⁄∂w。更为正式地表述,就是某一点上的导数值刻画了参数w产生的微小变化,使得相应的损失函数值L在该点上产生变化的情况。想象一下,如果把权重参数w值增加一点点,那么损失函数值会发生什么样的变化?如果这个图中的导数值为负数,那就意味着损失函数值在减少;如果导数值为正数,那就意味着损失函数值在增加。在曲线的最小值点处,也就是用十字标记的那个点处的曲线的切线是水平的,该点的导数值为零。
这个图中的导数值为负数,就意味着损失函数会随着w值的增加而减少。如果这个徒步旅行者站在图的右侧,则相应的导数值为正,那就意味着随着w值的增加,损失函数值也会增加。在曲线的最小值点处,也就是用十字标记的那个点处的切线是水平的,该点的导数值为零。
需要注意的是,该旅行者应该朝着梯度的反方向走才能够到达最小值点,因此在上图所示的导数值为负的情况下,她应该朝着正方向迈步。旅行者步伐的大小也应该与坡度成正比。如果导数值是一个很大的数(无论是正的还是负的),这就意味着曲线十分陡峭,而且营地十分遥远。因此,该旅行者可以自信地迈出一大步。当该旅行者接近营地时,坡度会变小,她的步伐也会变小。
以上描述的算法叫作梯度下降法,简称为GD法。实现这个算法需要一点数学知识。