TensorFlow知识图谱实战
上QQ阅读APP看书,第一时间看更新

2.1.2 使用Keras API实现鸢尾花分类的例子(顺序模型)

iris数据集是常用的分类实验数据集,由Fisher于1936年收集整理。iris也称鸢尾花卉数据集,是一类用于多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度、花萼宽度、花瓣长度、花瓣宽度4个属性预测鸢尾花卉(见图2.2)属于Setosa、Versicolour、Virginica这3个种类中的哪一类。

图2.2 鸢尾花

第一步:数据的准备

不需要读者下载这个数据集,一般常用的机器学习工具自带iris数据集,引入数据集的代码如下:

     from sklearn.datasets import load_iris
     data = load_iris()

这里调用的是sklearn数据库中的iris数据集,直接载入即可。

而其中的数据又是以key-value值对应存放,key值如下:

由于本例中需要iris的特征与分类目标,因此这里只需要获取data和target,代码如下:

数据打印结果如图2.3所示。

图2.3 数据打印结果

这里是分别打印了前5条数据。可以看到iris数据集中分成了4个不同特征进行数据记录,而每条特征又对应于一个分类表示。

第二步:数据的处理

下面就是数据处理部分,对特征的表示不需要变动。而对于分类表示的结果,全部打印结果如图2.4所示。

图2.4 数据处理

这里按数字分成了3类,0、1和2分别代表3种类型。如果按直接计算的思路,可以将数据结果向固定的数字进行拟合,这是一个回归问题,即通过回归曲线去拟合出最终结果。但是本例实际上是一个分类任务,因此需要对其进行分类处理。

分类处理的一个非常简单的方法就是进行one-hot处理,即将一个序列化数据分到不同的数据领域空间进行表示,如图2.5所示。

图2.5 one-hot处理

具体在程序处理上,读者可以手动实现one-hot的编码表示,也可以使用Keras自带的分散工具对数据进行处理,代码如下:

    iris_target =
np.float32(tf.keras.utils.to_categorical(iris_target,num_classes=3))

这里的num_classes表示分成了3类,用一行三列对每个类别进行表示。

交叉熵函数与分散化表示的方法超出了本书的讲解范围,这里就不再过多介绍,读者只需要知道交叉熵函数需要和softmax配合,从分布上向离散空间靠拢即可。

     iris_data = tf.data.Dataset.from_tensor_slices(iris_data).batch(50)
     iris_target = tf.data.Dataset.from_tensor_slices(iris_target).batch(50)

当生成的数据读取到内存中并准备以批量的形式打印,使用的是tf.data.Dataset.from_tensor_slices函数,并且可以根据具体情况对batch进行设置。关于tf.data.Dataset函数更多的细节和用法在后面章节中会专门介绍。

第三步:梯度更新函数的写法

梯度更新函数是根据误差的幅度对数据进行更新的方法,代码如下:

     grads = tape.gradient(loss_value, model.trainable_variables)
     opt.apply_gradients(zip(grads, model.trainable_variables))

与前面线性回归例子的差别是,使用的模型直接获取参数的方式对数据自动进行更新而非人为指定,这一点请读者注意。至于人为的指定和排除某些参数的方法属于高级程序设计,在后面的章节会提到。

【程序2-1】

最终打印结果如图2.6所示。可以看到损失值在符合要求的条件下不停降低,达到了预期目标。

图2.6 打印结果