TensorFlow知识图谱实战
上QQ阅读APP看书,第一时间看更新

2.3.2 使用ResNet作为特征提取层建立模型

下面使用ResNet作为特征提取层建立一个特定的目标分类器,这里简单地进行二分类的分类,代码及其讲解如下:

【程序2-13】

一般来说,预训练的特征提取器是放在自定义的模型第一层,主要用作对数据集的特征进行提取,之后的全局池化层是对数据维度进行压缩,将4维的数据特征重新定义成2维,从而将特征从[batch_size,7,7,2048]降维到[batch_size,2048],读者可以自行打印查看。

drop_out_layer是屏蔽掉某些层用作防止过拟合的层,而fc_layer是用作对特定目标的分类层,这里通过设置unit参数为2定义分类为2个类。

最后一步是对定义的各个层进行组合,

    binary_classes = tf.keras.Sequential([resnet_layer,flatten_layer,
drop_out_layr,fc_layer])

Sequential函数将各个层组合成一个完整的模型,打印的模型结构如图2.19所示。

图2.19 组合成一个完整的模型

可以看到经过预训练的ResNet50被作为一个自定义的特征层去使用,因此在打印结果上ResNet50是一个整体,而其他相关层依次排列在模型后方。这里仅仅将ResNet50当成一个自定义的层来使用,因此这里依次显示了各个层的名称和参数,最下方是模型参数的总数。

下面还有一个问题是关于参数的,这里可以看到,基本上所有的参数都是可训练的,也就是在模型的训练过程中所有的参数都参与了计算和更新。对于某些任务来说,预训练模型的参数是不需要进行更新的,因此可以对ResNet50模型进行设置,代码如下:

【程序2-14】

相对于上一个代码段,这里额外设置了resnet_layer.trainable = False,显式地标注了resnet为不可训练的层,因此resnet的参数在模型中不参与训练。

这里有一个小技巧:通过模型的大概描述可以比较参数训练的多少,显示结果如图2.20所示。

图2.20 模型展示

从图2.20可以看到,这里Non-trainable的参数占了大部分,也就是ResNet模型参数不参与训练。读者可以自行比较。

注意

在使用ResNet模型做特征提取器的时候,由于Keras中ResNet50模型是使用imagenet数据集做的预训练模型,输入的数据最低为[224,224,3],因此如果使用相同的方法进行预训练模型的自定义,那么输入的数据维度最小要为[224,224,3]。

其他模型的调用请有兴趣的读者自行完成。