AI源码解读:数字图像处理案例(Python版)
上QQ阅读APP看书,第一时间看更新
 AI源码解读.数字图像处理案例:Python版

5.3.2 创建模型并编译

数据加载进模型之后,需要定义模型结构,并优化损失函数。

1.定义模型结构

使用ImageNet数据集上预训练的ResNet-18作为源模型。当目标数据集小于源数据集时,微调有助于提升模型的泛化能力。指定pretrained=True自动下载并加载预训练的模型参数,在第一次使用时需要联网下载模型参数。

     pretrained_net = model_zoo.vision.resnet18_v2(pretrained=True)

预训练的源模型实例含有两个成员变量,即features和output。前者包含模型除输出层以外的所有层,后者为模型的输出层。这样划分主要是为了方便微调除输出层以外所有层的模型参数。下面打印源模型的成员变量输出。作为一个全连接层,它将ResNet最终的全局平均池化层输出变换成ImageNet数据集上1000类的输出。

新建一个神经网络作为目标模型。它的定义与预训练的源模型一样,但最后的输出个数等于目标数据集的类别数。在下面的代码中,目标模型实例finetune_net成员变量features中的模型参数被初始化为源模型相应层的模型参数。由于features中的模型参数是在ImageNet数据集上预训练得到的,比较成熟,因此,只需使用较小的学习率微调参数。而成员变量output中的模型参数采用了随机初始化,需要更大的学习率从头训练。假设Trainer实例中的学习率为η,设成员变量output中的模型参数在迭代中使用的学习率为10η。相关代码如下:

2.优化损失函数

确定模型架构并进行编译,这是多类别的分类问题,因此,需要使用交叉熵作为损失函数。由于所有的标签都带有相似的权重,经常使用精确度作为性能指标。Adam是常用的梯度下降方法,使用它来优化模型参数。