现代决策树模型及其编程实践:从传统决策树到深度决策树
上QQ阅读APP看书,第一时间看更新

2.2.7 CART回归决策树的编程实践

在CART分类决策树中使用基尼系数作为寻找最优划分点的依据,在回归树中则采用均方误差最小化准则作为特征和分割点的选择方法,下面我们基于这种算法来实现回归树的模型。

2.2.7.1 整体流程

首先介绍一下整体流程,如代码段2.14所示。与CART分类树类似,主要由四部分组成:数据集加载、模型训练、模型预测和决策树可视化。

代码段2.14 CART回归树测试主程序(源码位于Chapter02/test_CartRegressor.py)

1. 数据集加载(第15~29行)

数据集使用的是表2.11中的“流行歌手喜好度”数据集。与2.2.3节的例子类似,首先在第18~19行利用“with open”语法和csv库读取数据集文件,并将其转换成list类型。与之不同的是,接下来第20~24行针对数据集的第2列执行数据预处理,将非数值型字符串(“性别”一列的属性值)转化成数字。最后在第25~29行进行数据集的划分和数据类型的转换。

2. 决策树模型的训练和生成(第31~38行)

在CART回归树的创建和训练过程中,与CART分类树相比,回归树使用CartRegressor代替CartClassifier,并且指定了浮点型切分点保留的小数点后有效数字位数。训练过程对外提供的接口与CART分类树相同,但是其内部的训练细节会有所差异,在下文中会对此做详细描述。我们先来看一下使用上述数据集训练得到的决策树模型。

该model变量实际上是由一组规则表示的。以上输出结果为决策树的字典(树形结构)数据结构形式,在这棵model树中,从根节点到每个叶子节点的每条路径都代表一条规则。为了更清晰地表示规则,我们可以将以上数据结构转换成“if-then”的格式,如下所示:

由此可以看出,CART回归决策树与一组“if-then”规则是等价的。

3. 决策树模型的使用(第40~45行)

在第42~44行的模型预测阶段,调用CartRegressor类的成员函数predict,传入测试集数据X_test,返回numpy.array类型的预测结果y_pred,并且打印输出测试集的真实值y_test和预测值y_pred。

在第45行的模型评估阶段,调用sklearn的r2_score函数计算R2指标。R2指标是用于评估回归问题预测性能的一种指标,R2指标越大,代表预测性能越好。在这里调用r2_score函数前,需要使用“from sklearn.metrics import r2_score”语句将其引入当前环境。实际执行结果如下:

4. 决策树可视化(第47~49行)

决策树可视化阶段导入了tree_plotter包的tree_plot函数,关于tree_plotter包的详细介绍可以回看2.2.3节。在tree_plot函数中传入训练好的模型,底层借助Matplotlib进行可视化,效果如图2.19所示。

图2.19 “流行歌手喜好度”数据集生成的CART回归树

2.2.7.2 训练和创建过程

下面展开介绍CART回归树的训练和创建过程。CartRegressor的创建和训练过程与CartClassifier类似。不同点在于CartRegressor去掉了建立字符串与数值映射的功能,并且将数据预处理部分移到类外部定制。另外,最重要的区别在于模型训练时切分点的选取。接下来,我们逐一分析这些不同点在CART回归树的训练和创建过程中的表现。

首先介绍一下CartRegressor类的构造。从代码段2.15中可以看到,CartRegressor类的实现依然依赖于torch和numpy。在构造函数__init__中,依然需要提供use_gpu和min_samples_split两个参数。与CartClassifier类不同的是,新增加了bit参数,bit用来表示连续属性离散化时精确的小数点位数,默认保留2位。另外,CartRegressor类删掉了CartClassifier类中用于建立字符串与数值映射的成员变量,读者可以参照CartClassifier类对比学习。

接下来介绍CART回归树与分类树在训练过程中的不同,如代码段2.15和2.16所示。在函数__create_tree中,主要有3处与分类树不同。第一处在第87~89行,当满足递归终止条件“节点样本数小于self.min_samples_split”时,返回的预测值是该集合中所有目标变量的平均值。第二处在第91~93行,差异集中在__choose_best_point_to_split函数中,在回归树中采用“平方误差最小”的原则来选择最优切分点,该部分内容稍后做重点讲解。第三处在第95~105行,使用最优属性和最优切分点划分数据集时相较分类树(处理匹配字符串“<=”和“>”的代码逻辑)做略微调整。

代码段2.15 CART分类树创建过程(源码位于Chapter02/CartRegressor.py)

代码段2.16 创建CART回归树的核心代码(源码位于Chapter02/CartRegressor.py)

__choose_best_point_to_split函数如代码段2.17所示。在第132~155行遍历所有属性值时,回归树中不再计算基尼不纯度和基尼增益,而是针对回归问题计算损失函数。其中,第144行和第147行分别计算了使用当前切分点划分的左右子树的残差平方和,第149行计算左右子树的总残差平方和。最后选出取得最小损失函数的切分点和属性索引,作为最优切分点和最优分裂属性。对比2.2.3节计算基尼增益的方法,此处计算最小平方误差的方法具有异曲同工之妙。

代码段2.17 选择最优切分点(源码位于Chapter02/CartRegressor.py)

最后是CART回归树的预测过程和可视化过程。由于CART回归树与分类树的预测过程和可视化过程几乎完全相同,在此不做赘述,请读者参考2.2.3节CART分类树的预测和可视化代码。

以上即为CART回归树针对2.2.6节的“流行歌手喜好度”数据集进行编程实践的全部过程。