人工智能(高中版)
上QQ阅读APP看书,第一时间看更新

2.3 泛化

在上面的损失函数的介绍中,假如定义一个极为复杂的函数f,使得当输入为xi时,f输出yi,否则输出0,那么很容易可以发现,在训练数据集上一定会有损失函数LfXY)=0!可是,这个函数除了将训练数据记忆下来之外,并没有做任何其他事情。也就是说,这个函数f包含的信息和训练数据集包含的信息是一样的,它没有学习训练数据的任何特征,也没有任何智能可言。

这个例子提醒我们,一个好的函数f不仅需要在训练数据上表现很好,可以得到一个很小的损失函数(即拟合能力),同时它需要有很强的举一反三、归纳推广的能力。换句话说,对于在训练的时候没有见过的数据,它也需要有比较好的表现。这样的能力叫做泛化(generalization)。一个具有泛化能力保证的f,才是一个真正有意义的目标函数。

事实上,如何才能够确保具有泛化能力是机器学习领域一个非常核心的问题,科研领域也有大量的理论成果,但目前并没有放之四海而皆准的方法。在实际应用中,一个比较有效的方法是调优(validation)。具体来说,调优是把数据集分成两块,一块(通常占90%~95%左右)叫做训练数据集,而另一块(一般占5%~10%左右)叫做调优集(validation set,又称验证集)。接下来,在训练的时候只使用训练数据集进行训练;然后在使用测试数据集之前,先在调优集上面看看算法的泛化效果。由于训练时算法并没有见过调优集,训练结束之后它在调优集上的表现可以视为一个比较好的泛化能力的估测。至少,单纯对训练集合进行死记硬背,难以在调优集上得到比较好的表现。

在调优方法的基础上,人们还进一步提出了交叉调优(cross validation,又称交叉验证)的思路。其具体做法如图2.3所示,其中白色表示训练使用的数据,灰色表示剔除的数据。每次训练剔除不同的数据,并根据得到的函数在剔除数据上的表现得到泛化能力的估计。

如图所示,交叉调优将训练数据集分成k份,然后相应地训练出k个不同的函数f1f2,…,fk。这里,在训练函数fi时,剔除了第i份数据(将其当做调优集),只用其他的k-1份数据。由于每个函数都会剔除不同的数据进行训练,最后也使用不同的数据进行验证,我们得到了一个目标函数训练方法的稳健泛化能力分析。最后,可以根据f1f2,…,fk在对应验证集合上的表现,确定最好的参数方案。使用这个参数方案对整个训练数据进行训练之后,就可以得到最后的f函数。

举个例子,假设我们将数据分成了4份,如图2.3所示(K=4),分别叫做(X1Y1),(X2Y2),(X3Y3),(X4Y4)。

图2.3 交叉调优示意图

第一轮:在(X2Y2),(X3Y3),(X4Y4)进行训练,得到函数在(X1Y1)上的损失函数值

第二轮:在(X1Y1),(X3Y3),(X4Y4)进行训练,得到函数在(X2Y2)上的损失函数值

第三轮:在(X1Y1),(X2Y2),(X4Y4)进行训练,得到函数在(X3Y3)上的损失函数值

第四轮:在(X1Y1),(X2Y2),(X3Y3)进行训练,得到函数在(X4Y4)上的损失函数值

则最后得到的交叉调优的结果为。这是对我们的训练方法比较综合的估计。

为什么要构造调优集呢?这是因为在实际生产生活过程中,人们通常无法接触到测试数据,也不能等到测试的时候再修改训练算法和参数。于是人们从训练数据中挑选一部分当做模拟测试数据,并根据它们来决定训练算法和参数(调优)。这是实际中常用的技巧。

监督学习的几个步骤总结如下:

(1)确认目标问题;

(2)创建数据集,包含成千上万的数据点xiyi,其中xi为输入,yi为输出;

(3)针对问题选择一个好的机器学习模型f

(4)定义一个合适的损失函数L度量fX)和Y的距离;

(5)以损失函数为指标,使用优化算法寻找f的参数组合;

(6)确定f具有非常强的泛化能力。

下面用一个简单的例子来具体描述这个流程。假设我们希望学习判断图片中的物体是猫还是狗。首先需要找到一个训练数据集,它里面的图片不是猫就是狗,并且已经标注好,如图2.4所示。

接下来,指定一个具体的机器学习模型,用函数f表示。这个模型可以是线性模型、决策树模型或神经网络等(在后续的章节中会详细介绍)。根据输入图片x,模型f可以得到一个预测fx)∈{猫,狗}。然后,设计一个损失函数L来表示这个预测与真实答案的距离(后面会看到,对于这样的分类问题,交叉熵是一个比较好的损失函数)。

图2.4 狗和猫的图片

确定损失函数之后,选择优化算法(如常用的梯度下降法,会在第3章中详细介绍)对模型进行优化。为了确保模型的泛化性能,一般会在训练之前从训练数据集中随机选出一部分图片组成调优集,在训练完成之后测试模型f在调优集上面的表现,作为模型泛化能力的一个估计。

最后,再次强调,在训练过程中,算法不应以任何方式触碰测试数据集,无论是只看测试数据集的输入x,还是用部分的(xy)进行训练。这样会对测试数据造成污染,导致无法测试出算法的真实表现。这是初学者常犯的错误,请大家一定牢记在心。