4.3.2 升级预测函数
由于现在包含多个输入变量,因此需要将预测公式从简单的直线方程转换为加权和形式的如下计算公式:
你可能已经注意到公式里少了什么。是的,为了简化公式,我暂时移除了偏置b,它很快就会回来的。
现在可以将上述加权和公式转换为predict()函数的多维版本。提醒一下,下面是以前的单维predict()函数(没有偏置):
新的predict()函数仍然应该使用符号X和w表示变量,但是这些变量现在有了更多的维度。X曾经表示包含m个元素的向量,其中m表示样本的个数。现在X表示一个(m,n)矩阵,其中n是输入变量的个数。对于我们当前考察的应用实例,有30个样本和3个输入变量,所以X是一个(30,3)矩阵。
那么w会如何呢?正如每个输入变量必须有一个x一样,每个输入变量也必须有一个w。与x不同的是,w对于每个样本来说都是相同的。因此,可以把权重设为(n,1)矩阵或(1,n)矩阵,原因等一会就清楚了。我们最好将其设为(n,1)矩阵:将每个输入变量作为一行,且只有一列。
下面来对这个(n,1)矩阵进行初始化。还记得我们曾经把w初始化为0吗?现在w是一个矩阵,所以必须把它的所有元素初始化为0。NumPy有个zeros()函数:
别忘了X.shape[0]表示X的行数,X.shape[1]表示X的列数。这里的代码表示w的行数与X的列数相等,在我们的例子中都是3。
这就是矩阵乘法最终能够派上用场的地方。回头看一下加权和:
如果X只有一行,这就与用一行X乘以w的效果完全一样:
在我们的应用示例中,X并不只有一行。X有很多行,比如(30,3)。当我们把它乘以w,也就是(3,1)矩阵时,就会得到一个(30,1)矩阵。在这个矩阵中,每个样本只有一行,并且一个列元素包含了这个样本的预测值。换言之,我们可以通过一次矩阵乘法运算得到关于所有样本的预测值!
因此,重写多个输入变量的预测结果就像用矩阵X乘以w一样简单。可以使用NumPy中的matmul()函数做到这一点:
我们花了一些时间来理解这个小型函数,而且通过暂时忽略偏置的方式抄了近路,我们将在后面重新引入偏置。但是获得的最终结果是值得的,predict()函数只用了短短几行代码就构造了多重线性回归模型。