2.1.2 使用Keras API实现鸢尾花分类的例子(顺序模型)
iris数据集是常用的分类实验数据集,由Fisher于1936年收集整理。iris也称鸢尾花卉数据集,是一类用于多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度、花萼宽度、花瓣长度、花瓣宽度4个属性预测鸢尾花卉(见图2.2)属于Setosa、Versicolour、Virginica这3个种类中的哪一类。
图2.2 鸢尾花
第一步:数据的准备
不需要读者下载这个数据集,一般常用的机器学习工具自带iris数据集,引入数据集的代码如下:
from sklearn.datasets import load_iris data = load_iris()
这里调用的是sklearn数据库中的iris数据集,直接载入即可。
而其中的数据又是以key-value值对应存放,key值如下:
由于本例中需要iris的特征与分类目标,因此这里只需要获取data和target,代码如下:
数据打印结果如图2.3所示。
图2.3 数据打印结果
这里是分别打印了前5条数据。可以看到iris数据集中分成了4个不同特征进行数据记录,而每条特征又对应于一个分类表示。
第二步:数据的处理
下面就是数据处理部分,对特征的表示不需要变动。而对于分类表示的结果,全部打印结果如图2.4所示。
图2.4 数据处理
这里按数字分成了3类,0、1和2分别代表3种类型。如果按直接计算的思路,可以将数据结果向固定的数字进行拟合,这是一个回归问题,即通过回归曲线去拟合出最终结果。但是本例实际上是一个分类任务,因此需要对其进行分类处理。
分类处理的一个非常简单的方法就是进行one-hot处理,即将一个序列化数据分到不同的数据领域空间进行表示,如图2.5所示。
图2.5 one-hot处理
具体在程序处理上,读者可以手动实现one-hot的编码表示,也可以使用Keras自带的分散工具对数据进行处理,代码如下:
iris_target = np.float32(tf.keras.utils.to_categorical(iris_target,num_classes=3))
这里的num_classes表示分成了3类,用一行三列对每个类别进行表示。
交叉熵函数与分散化表示的方法超出了本书的讲解范围,这里就不再过多介绍,读者只需要知道交叉熵函数需要和softmax配合,从分布上向离散空间靠拢即可。
iris_data = tf.data.Dataset.from_tensor_slices(iris_data).batch(50) iris_target = tf.data.Dataset.from_tensor_slices(iris_target).batch(50)
当生成的数据读取到内存中并准备以批量的形式打印,使用的是tf.data.Dataset.from_tensor_slices函数,并且可以根据具体情况对batch进行设置。关于tf.data.Dataset函数更多的细节和用法在后面章节中会专门介绍。
第三步:梯度更新函数的写法
梯度更新函数是根据误差的幅度对数据进行更新的方法,代码如下:
grads = tape.gradient(loss_value, model.trainable_variables) opt.apply_gradients(zip(grads, model.trainable_variables))
与前面线性回归例子的差别是,使用的模型直接获取参数的方式对数据自动进行更新而非人为指定,这一点请读者注意。至于人为的指定和排除某些参数的方法属于高级程序设计,在后面的章节会提到。
【程序2-1】
最终打印结果如图2.6所示。可以看到损失值在符合要求的条件下不停降低,达到了预期目标。
图2.6 打印结果