机器学习实战--决策树

决策树

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特诊数据
缺点:可能会产生过度匹配问题
使用数据类型:数值型和标称型
专家系统中,经常使用决策树

trees.py

1
2
from math import log	
import operator

createDataSet()

创建数据集

trees.py

1
2
3
4
5
6
7
8
9
10
def createDataSet():
# 数据集中两个特征'no surfacing','flippers', 数据的两个类标签'yes','no
#dataSet是个list
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing','flippers']
return dataSet, labels

calcShannonEnt(dataSet)

计算给定数据集的熵

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def calcShannonEnt(dataSet):
numEntries = len(dataSet) #计算数据集中实例的总数
labelCounts = {} #创建空字典
for featVec in dataSet: #提取数据集每一行的特征向量
currentLabel = featVec[-1] #获取特征向量最后一列的标签
# 检测字典的关键字key中是否存在该标签,如果不存在keys()关键字,将当前标签/0键值对存入字典中,并赋值为0
#print(labelCounts.keys())
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
#print(labelCounts)
labelCounts[currentLabel] += 1 #否则将当前标签对应的键值加1
#print("%s="%currentLabel,labelCounts[currentLabel])
shannonEnt = 0.0 #初始化熵为0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries #计算各值出现的频率
shannonEnt -= prob * log(prob,2) #以2为底求对数再乘以出现的频率,即信息期望值
#print("%s="%labelCounts[key],shannonEnt)
return shannonEnt

splitDataSet(dataSet, axis, value)

按照给定特征划分数据集
得到熵之后,还需划分数据集,以便判断当前是否正确地划分了数据集,三个输入参数分别为:带划分的数据集,划分数据集的特征,需要返回的特征得值,挑选出dataSet中axis位置值为value的剩余部分。

1
2
3
4
5
6
7
8
9
10
11
12
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value: #筛选出dataSet中axis位置值为value
#列表的索引中冒号的作用,a[1: ]表示该列表中的第1个元素到最后一个元素,而a[ : n]表示从第0歌元素到第n个元素(不包括n)
reducedFeatVec = featVec[:axis] #取出特定位置前面部分并赋值给reducedFeatVec
#print(featVec[axis+1:])
#print(reducedFeatVec)
reducedFeatVec.extend(featVec[axis+1:]) #取出特定位置后面部分并赋值给reducedFeatVec
retDataSet.append(reducedFeatVec)
#print(retDataSet)
return retDataSet

chooseBestFeatureToSplit(dataSet)

选择最好的数据集划分方式
选取特征,划分数据集,计算得出最好的划分数据集的特征

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #计算特征数量,即每一列表元素具有的列数,再减去最后一列为标签,故需减去1
baseEntropy = calcShannonEnt(dataSet) #计算信息熵,此处值为0.9709505944546686,此值将与划分之后的数据集计算的信息熵进行比较
bestInfoGain = 0.0;bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet] #创建标签列表
#print(featList)
uniqueVals = set(featList) #确定某一特征下所有可能的取值,set集合类型中的每个值互不相同
#print(uniqueVals)
newEntropy = 0.0
for value in uniqueVals: #计算每种划分方式的信息熵
subDataSet = splitDataSet(dataSet, i, value) #抽取该特征的每个取值下其他特征的值组成新的子数据集
prob = len(subDataSet)/float(len(dataSet)) #计算该特征下的每一个取值对应的概率(或者说所占的比重)
newEntropy += prob * calcShannonEnt(subDataSet) #计算该特征下每一个取值的子数据集的信息熵,并求和
infoGain = baseEntropy - newEntropy #计算每个特征的信息增益
#print("第%d个特征是的取值是%s,对应的信息增益值是%f"%((i+1),uniqueVals,infoGain))
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
#print("第%d个特征的信息增益最大,所以选择它作为划分的依据,其特征的取值为%s,对应的信息增益值是%f"%((i+1),uniqueVals,infoGain))
return bestFeature

majorityCnt(classList)

递归构建决策树,返回出现次数最多的分类名称

1
2
3
4
5
6
7
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]

createTree(dataSet,labels)

创建树,参数为数据集和标签列表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet] #提取dataset中的最后一列——种类标签
#print(classList)
if classList.count(classList[0]) == len(classList): #计算classlist[0]出现的次数,如果相等,说明都是属于一类,不用继续往下划分
return classList[0] #递归结束的第一个条件是所有的类标签完全相同,则直接返回该类标签
#print(dataSet[0])
if len(dataSet[0]) == 1: #看还剩下多少个属性,如果只有一个属性,但是类别标签有多个,就直接用majoritycnt()进行整理,选取类别最多的作为返回值
return majorityCnt(classList) #递归结束的第二个条件是使用完了所有的特征,仍然不能将数据集划分成仅包含唯一类别的分组,则返回出现次数最多的类别
bestFeat = chooseBestFeatureToSplit(dataSet) #选取信息增益最大的特征作为下一次分类的依据
bestFeatLabel = labels[bestFeat] #选取特征对应的标签
#print(bestFeatLabel)
myTree = {bestFeatLabel:{}} #创建tree字典,下一个特征位于第二个大括号内,循环递归
del(labels[bestFeat]) #删除使用过的特征
featValues = [example[bestFeat] for example in dataSet] #特征值对应的该栏数据
#print(featValues)
uniqueVals = set(featValues) #找到featvalues所包含的所有元素,去重复
for value in uniqueVals:
subLabels = labels[:] #将使用过的标签删除更新后,赋值给新的列表,进行迭代
#print(subLabels)
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat,value),subLabels) #循环递归生成树
return myTree

classify(inputTree,featLabels,testVec):

测试算法,使用决策树执行分类

1
2
3
4
5
6
7
8
9
10
11
12
13
def classify(inputTree,featLabels,testVec):
firstStr = list(inputTree.keys())[0] #找到树的第一个分类特征,或者说根节点'no surfacing'
#print(firstStr)
secondDict = inputTree[firstStr] #从树中得到该分类特征的分支,有0和1
#print(secondDict)
featIndex = featLabels.index(firstStr) #根据分类特征的索引找到对应的标称型数据值,'no surfacing'对应的索引为0
#print(featIndex)
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else: classLabel = valueOfFeat
return classLabel

storeTree(inputTree,filename)

决策树的存储,使用pickle序列化对象,可在磁盘中保存对象。

1
2
3
4
5
6
7
8
9
10
def storeTree(inputTree,filename):
import pickle
fw = open(filename,'wb') #二进制写入'wb'
pickle.dump(inputTree,fw) #pickle的dump函数将决策树写入文件中
fw.close()

def grabTree(filename):
import pickle
fr = open(filename,'rb') #对应于二进制方式写入数据,'rb'采用二进制形式读出数据
return pickle.load(fr)

trees_main.py

1
2
3
import trees
from imp import reload
import treePlotter

创建数据集

1
2
3
4
myDat,labels=trees.createDataSet()
#print(myDat)
#print(labels)
#print(trees.calcShannonEnt(myDat))

熵增大的原因

熵越高,混合的数据就越多,如果我们在数据集中添加更多的分类,会导致熵结果增大

1
2
3
#myDat[1][-1]='maybe'#更改list中某一元素的值(除yes和no外的值),即为添加更多的分类,中括号中为对应元素行列的位置
#print(myDat)
#print(trees.calcShannonEnt(myDat)) #分类变多,熵增大

append()和extend()两类方法的区别

1
2
3
4
5
6
a=[1,2,3]
b=[4,5,6]
a.append(b)
#print(a)#[1, 2, 3, [4, 5, 6]]
a.extend(b)
#print(a)#[1, 2, 3, [4, 5, 6], 4, 5, 6]

按照给定特征划分数据集

1
2
3
#print(myDat)
#print(trees.splitDataSet(myDat,0,1))
#print(trees.splitDataSet(myDat,0,0))

选择最好的数据集划分方式

1
2
#print(myDat)
#print(trees.chooseBestFeatureToSplit(myDat))

创建树,参数为数据集和标签列表

1
2
3
4
5
6
7
8
myTree=trees.createTree(myDat,labels)
#print(myTree)

myDat,labels=trees.createDataSet()
myTree1=treePlotter.retrieveTree(0)
#print(myTree1)
#print(trees.classify(myTree1,labels,[1,0]))
#print(trees.classify(myTree,labels,[1,1]))

决策树的存储

1
2
trees.storeTree(myTree,'classifierStorage.txt')
#print(trees.grabTree('classifierStorage.txt'))

使用决策树预测隐形眼镜类型

1
2
3
4
5
6
fr=open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()] #将文本数据的每一个数据行按照tab键分割,并依次存入lenses
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate'] # 创建并存入特征标签列表
lensesTree = trees.createTree(lenses, lensesLabels) # 根据继续文件得到的数据集和特征标签列表创建决策树
print(lensesTree)
treePlotter.createPlot(lensesTree)

treePlotter.py

python中使用Matplotlib注解绘制树形图

1
import matplotlib.pyplot as plt

定义文本框和箭头格式

1
2
3
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细,pad指的是外边框锯齿形(圆形等)的大小
leafNode = dict(boxstyle="round4", fc="0.8") #定义决策树的叶子结点的描述属性,round4表示圆形
arrow_args = dict(arrowstyle="<-") #定义箭头属性

plotNode(nodeTxt, centerPt, parentPt, nodeType)

绘制带箭头的注解
annotate是关于一个数据点的文本
nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点
annotate的作用是添加注释,nodetxt是注释的内容
nodetype指的是输入的节点(边框)的形状

1
2
3
4
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

def createPlot():

第一版构造树函数,后面会改进,所以这里要注释上

1
2
3
4
5
6
#fig = plt.figure(1, facecolor='white')
#fig.clf()
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
#plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#plt.show()

getNumLeafs(myTree)

计算叶子节点的个数
构造注解树,需要知道叶节点的个数,以便可以正确确定x轴的长度;要知道树的层数,可以确定y轴的高度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def getNumLeafs(myTree):    
numLeafs = 0
firstStr = list(myTree.keys())[0] #获得myTree的第一个键值,即第一个特征,分割的标签
#print(firstStr)
secondDict = myTree[firstStr] #根据键值得到对应的值,即根据第一个特征分类的结果
#print(secondDict)
for key in secondDict.keys(): #获取第二个小字典中的key
if type(secondDict[key]).__name__=='dict':
#判断是否小字典中是否还包含新的字典(即新的分支)
numLeafs += getNumLeafs(secondDict[key]) #包含的话进行递归从而继续循环获得新的分支所包含的叶节点的数量
else: numLeafs +=1 #不包含的话就停止迭代并把现在的小字典加一表示这边有一个分支
return numLeafs

def getTreeDepth(myTree): #计算判断节点的个数
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth

retrieveTree(i)

预先存储树信息

1
2
3
4
5
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]

plotMidText(cntrPt, parentPt, txtString)

作用是计算tree的中间位置,cntrPt起始位置,parentPt终止位置,txtString文本标签信息

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] #cntrPt起点坐标,子节点坐标,parentPt结束坐标,父节点坐标
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] #找到x和y的中间位置
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) #计算子节点的坐标
plotMidText(cntrPt, parentPt, nodeTxt) #绘制线上的文字
plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制节点
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #每绘制一次图,将y的坐标减少1.0/plottree.totald,间接保证y坐标上深度的
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

createPlot(inTree)

1
2
3
4
5
6
7
8
9
10
11
def createPlot(inTree):
fig = plt.figure(1, facecolor='white') #类似于Matlab的figure,定义一个画布,背景为白色
fig.clf() # 把画布清空
axprops = dict(xticks=[], yticks=[]) #subplot定义了一个绘图
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
#createPlot.ax1为全局变量,绘制图像的句柄,111表示figure中的图有1行1列,即1个,最后的1代表第一个图,frameon表示是否绘制坐标轴矩形
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()

treePlotter_main.py

1
2
3
4
5
6
7
8
import  treePlotter
#treePlotter.createPlot()
#print(treePlotter.retrieveTree(1))
myTree=treePlotter.retrieveTree(0)
#print(treePlotter.getNumLeafs(myTree))
#print(treePlotter.getTreeDepth(myTree))
myTree['no surfacing'][3]='maybe'
treePlotter.createPlot(myTree)
-------------本文结束感谢您的阅读-------------
AmberWu wechat
欢迎大家扫码交流!