决策树
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特诊数据
缺点:可能会产生过度匹配问题
使用数据类型:数值型和标称型
专家系统中,经常使用决策树
trees.py
1 | from math import log |
createDataSet()
创建数据集
trees.py
1 | def createDataSet(): |
calcShannonEnt(dataSet)
计算给定数据集的熵
1 | def calcShannonEnt(dataSet): |
splitDataSet(dataSet, axis, value)
按照给定特征划分数据集
得到熵之后,还需划分数据集,以便判断当前是否正确地划分了数据集,三个输入参数分别为:带划分的数据集,划分数据集的特征,需要返回的特征得值,挑选出dataSet中axis位置值为value的剩余部分。
1 | def splitDataSet(dataSet, axis, value): |
chooseBestFeatureToSplit(dataSet)
选择最好的数据集划分方式
选取特征,划分数据集,计算得出最好的划分数据集的特征
1 | def chooseBestFeatureToSplit(dataSet): |
majorityCnt(classList)
递归构建决策树,返回出现次数最多的分类名称
1 | def majorityCnt(classList): |
createTree(dataSet,labels)
创建树,参数为数据集和标签列表
1 | def createTree(dataSet,labels): |
classify(inputTree,featLabels,testVec):
测试算法,使用决策树执行分类
1 | def classify(inputTree,featLabels,testVec): |
storeTree(inputTree,filename)
决策树的存储,使用pickle序列化对象,可在磁盘中保存对象。
1 | def storeTree(inputTree,filename): |
trees_main.py
1 | import trees |
创建数据集
1 | myDat,labels=trees.createDataSet() |
熵增大的原因
熵越高,混合的数据就越多,如果我们在数据集中添加更多的分类,会导致熵结果增大
1 | #myDat[1][-1]='maybe'#更改list中某一元素的值(除yes和no外的值),即为添加更多的分类,中括号中为对应元素行列的位置 |
append()和extend()两类方法的区别
1 | a=[1,2,3] |
按照给定特征划分数据集
1 | #print(myDat) |
选择最好的数据集划分方式
1 | #print(myDat) |
创建树,参数为数据集和标签列表
1 | myTree=trees.createTree(myDat,labels) |
决策树的存储
1 | trees.storeTree(myTree,'classifierStorage.txt') |
使用决策树预测隐形眼镜类型
1 | fr=open('lenses.txt') |
treePlotter.py
python中使用Matplotlib注解绘制树形图
1 | import matplotlib.pyplot as plt |
定义文本框和箭头格式
1 | decisionNode = dict(boxstyle="sawtooth", fc="0.8") # boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细,pad指的是外边框锯齿形(圆形等)的大小 |
plotNode(nodeTxt, centerPt, parentPt, nodeType)
绘制带箭头的注解
annotate是关于一个数据点的文本
nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点
annotate的作用是添加注释,nodetxt是注释的内容
nodetype指的是输入的节点(边框)的形状
1 | def plotNode(nodeTxt, centerPt, parentPt, nodeType): |
def createPlot():
第一版构造树函数,后面会改进,所以这里要注释上
1 | #fig = plt.figure(1, facecolor='white') |
getNumLeafs(myTree)
计算叶子节点的个数
构造注解树,需要知道叶节点的个数,以便可以正确确定x轴的长度;要知道树的层数,可以确定y轴的高度。
1 | def getNumLeafs(myTree): |
retrieveTree(i)
预先存储树信息
1 | def retrieveTree(i): |
plotMidText(cntrPt, parentPt, txtString)
作用是计算tree的中间位置,cntrPt起始位置,parentPt终止位置,txtString文本标签信息
1 | def plotMidText(cntrPt, parentPt, txtString): |
createPlot(inTree)
1 | def createPlot(inTree): |
treePlotter_main.py
1 | import treePlotter |