上QQ阅读APP看书,第一时间看更新
2.2.6 运行代码
下面的代码将Roberto的样本数据输入给train()函数。使用Python的习惯命名参数(详见A.3.1节)会更加直观。调用train()函数之后,系统会输出直线权重的最终值,以及预订座位数为20时的比萨销量预测值。
在调用train()时,我们需要确定迭代次数和学习率lr的取值。这一点可以通过反复尝试的方式来实现。我可以要求最多迭代10 000次,这感觉是个不错的开始。至于lr,请记住它的作用:它决定了w在训练的每一步有多大的变化。我们可以将它设置为0.01,这对于计算比萨销量来说应该是足够精确了。
实际上,经过实际的程序运行,train()前200次迭代的结果如下所示:
成功了!每次迭代的损失都在减少,直到算法不再进一步减少这种损失,此时的权重是1.84。因此,这就是Roberto可以预期的每个座位预订数所对应的比萨销量。如果有20个座位预订,那么他就可以卖出36.80个比萨。(比萨并不会以分数形式出售,但是Roberto喜欢保留小数点的后两位)。
通过计算w,我们的代码就相当于“在图表上画出了一条直线”。我们将这条直线进行了可视化(你可以在本书的源代码中找到相关的绘图代码)。
很好。但是,我们还能做得更好。