4.3.3 升级损失函数
接下来我们升级loss()函数。别忘了,我们用于计算损失的均方误差如下所示:
再次声明,我们用的是不带偏置的损失函数,这样就少了一个需要担心的参数。我们会很快重新引入偏置。首先来看看,应该如何修改上面的loss()函数,使其能够处理多个维度。
我还记得在编写第一个机器学习程序时,对矩阵运算是多么失望。特别是矩阵的维度,似乎从来都不合理。随着时间的推移,我逐步认识到矩阵维度实际上可以成为一个好帮手,如果我能够对它们进行仔细分析,那么就可以借助它们将我的代码拼凑在一起。在这里我们要做同样的事情,使用矩阵维度来指导我们完成代码的编写。
我们从样本的标签开始。之前的代码里包含两个标签矩阵,分别是表示样本数据集中的实际真值的Y,以及由predict()函数计算出来的y_hat。y_hat和Y都是(m,1)矩阵,这表示每个样本有一行和一列数据。我们的例子中有30个样本,所以它们是(30,1)矩阵。如果要从y_hat中减去Y,NumPy则会检查这两个矩阵的大小是否相同,然后从y_hat的每个元素中减去Y的每个元素,结果仍然是(30,1)矩阵。
然后,计算(30,1)矩阵中所有元素的平方。由于该矩阵被构造成NumPy数组的形式,因此可以使用NumPy中一种名为广播的特性。我们在前面的章节中已经使用过这种特性,在对NumPy数组使用算术运算时,该运算在数组的每个元素上“广播”。换言之,我们可以对整个矩阵求平方,NumPy尽职地将“平方”运算应用到矩阵中的每个元素上。同样,这个运算的结果仍然是(30,1)矩阵。
最后,我们调用average()函数对矩阵中所有元素求出平均值,返回单个的标量数值:
根据NumPy的表示形式,空括号的含义是:“这是一个标量,所以它没有维度。”
最重要的是,我们根本不需要改变损失函数的计算方式。使用我们的均方误差代码可以像处理一个输入变量那样处理多个输入变量。