3.4 决策树
3.4.1 原理概述
决策树可以看成一棵由“问询题”构成的树,对于识别和分类问题,它通过一系列问题的询问实现对物体的类型判断。这一思想在日常生活中经常用到。比如当你接到一个迷路的人的求助电话时,你会通过询问尽可能少的问题来确定他的位置,可能询问的问题包括“你在马路边还是居民区?”“附近有停车场吗?”“身后是邮局吗?”这一系列问题中前一个问题的回答决定了下一个问题的询问内容。
决策树分类过程如图3-8所示。
图3-8 决策树分类示意图
决策树的核心是如图3-8所示的由一系列简单数值判断构成的“树”,它作用于输入数据的特征向量上,其中特征向量的构成有各种方式,比如对音频数据序列,可用的特征包括方差、最大值和最小值之差、傅里叶变换的高频部分能量、特定卷积核的卷积输出序列平方和等。决策树的构建要求通过尽可能小的树形结构来实现尽可能精确的分类。
3.4.2 模型训练和推理
决策树构建的数学原理在这里不详细介绍,我们在代码清单3-3中给出构建决策树的Python代码示例。
代码清单3-3 决策树训练和测试示例
import numpy as np from sklearn.tree import DecisionTreeClassifier # 加载测试数据 import sklearn.datasets as datasets data=datasets.load_iris() TREE_DEP=5 x,y=data['data'],data['target'] # 训练/测试数据集分离 from sklearn.model_selection import train_test_split x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.5,shuffle=True) # 决策树训练 cls = DecisionTreeClassifier(random_state=0, max_depth=TREE_DEP) cls.fit(x_train, y_train) # 训练结果测试 y_pred=cls.predict(x_test) print('[INF] ACC: %.4f%%'%(100.0*(y_pred==y_test).astype(float).mean()))
代码清单3-3中使用了鸢尾花卉分类问题的数据进行分类。程序中决策树分类器的生成通过下面两句话实现:
cls = DecisionTreeClassifier(random_state=0, max_depth= TREE_DEP) cls.fit(x_train, y_train)
其中第一行代码构建DecisionTreeClassifier类的对象cls,参数TREE_DEP是决策树的深度,也就是每一次分类时最多允许问的问题数目。TREE_DEP和分类性能有关,它的值需要用户手动选择。决策树的“提问结构”的生成通过上面第2行代码实现。
训练得到的决策树的使用通过下面的语句实现:
y_pred=cls.predict(x_test)
其中x_test表示数据帧,用于存放多个测试数据,每一行存放一朵需要分类的花的4个测量数据。
3.4.3 决策树分类器的代码实现
决策树结构简单,运算量小,很适合嵌入式平台实现,上面基于Python下的Scikit-Learn软件包训练得到的分类器不能直接在嵌入式系统上运行,但我们可以从分类器中提取分类判断的树形结构并生成C程序,这样就能够在嵌入式环境下使用。代码清单3-4中的Python程序将之前讨论的决策树分类器cls内部的数据导出生成C语言程序,用于在嵌入式环境下运行数据分类任务。
代码清单3-4 从Scikit-Learn的决策树数据结构生成C程序的代码
# 自动代码生成(C语言) # 输入: # tree -- Scikit-Learn输出的决策树 def tree_to_c_code(tree): with open('tree.c', 'wt') as fout: def recurse(node=0, prefix=' '): if left[node]==-1 and right[node]==-1: # leaf node for i,v in enumerate(value[node][0]): if v==0: continue fout.write(prefix+'target['+str(i)+']='+\ str(int(v))+';\n') else: fout.write(prefix + 'if (feature['+\ str(feature[node])+']<=' +\ str(threshold[node])+')\n') fout.write(prefix + '{\n') if t.children_left[node] != -1: recurse(left[node], prefix+' ') fout.write(prefix + '}\n') fout.write(prefix + 'else\n') fout.write(prefix + '{\n') if t.children_right[node] != -1: recurse(right[node], prefix+' ') fout.write(prefix + '}\n') t=tree.tree_ left,right=t.children_left,t.children_right value,threshold,feature=t.value,t.threshold,t.feature fout.write('#include "tree.h"\n\n') fout.write('void tree(float *feature, int *target)\n') fout.write('{\n') fout.write(' for (int n=0; n<NUM_CLS; n++)\n') fout.write(' target[n]=0;\n') recurse() fout.write('}\n\n')
使用上述代码时通过调用tree_to_c_code(cls)生成C语言源代码,其中cls是训练得到的决策树分类器。生成的代码文件是tree.c和tree.h。其部分内容如下所示:
·tree.c
#include "tree.h" void tree(float *feature, int *target) { for (int n=0; n<NUM_CLS; n++) target[n]=0; if (feature[3]<=0.800000011920929) { target[0]=26; } else { if (feature[2]<=5.0) { if (feature[3]<=1.600000023841858) { target[1]=24; } else ……
·tree.h
#ifndef __TREE_H__ #define __TREE_H__ #define NUM_CLS 3 #define NUM_DIM 4 #endif
其中tree.c里面是一个复杂的if-else-if结构,它实现了决策树的判断逻辑。核心函数为void tree(float*feature,int*target),输入是数组feature,存放待分类的数据特征,比如在鸢尾花分类问题里就是存放鸢尾花的4个测量数值的数组指针。target数组存放分类得分,调用tree函数后在target数组内填入了对应各个类别的得分,一般选得分最高的那个元素的数组下标作为分类结果。