A decision tree can take an unfamiliar dataset and extract a series of rules from it. This process, in which the machine derives rules from a dataset, is exactly the process of machine learning. Let's work through a small example:
We want to decide whether a creature is a fish using two features: "no surfacing" (whether it can survive without coming to the surface of the water) and "flippers" (whether it has flippers). The table below lays out the training samples.
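The samples are taken directly from the createDataSet function further down (1 = yes, 0 = no):

no surfacing   flippers   is a fish?
1              1          yes
1              1          yes
1              0          no
0              1          no
0              1          no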
First we introduce the concepts of information entropy and information gain:
Information entropy: to compute the entropy, we take the expected value of the information contained in all possible values of all classes:

H(X) = -\sum_x p(x) \log_2 p(x)

where p(x) is the probability that class x occurs.
Conditional entropy (the uncertainty of random variable Y given that random variable X is known):

H(Y|X) = \sum_x p(x) H(Y|X=x)
Information gain (the change in information before and after splitting the dataset; informally, the information entropy minus the conditional entropy):

g(D, A) = H(D) - H(D|A)
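As a quick worked example, take the five fish samples above, with two "yes" and three "no" labels:

H(D) = -(2/5)\log_2(2/5) - (3/5)\log_2(3/5) \approx 0.971

Splitting on "no surfacing" leaves the subsets {yes, yes, no} and {no, no}, so H(D|A) = (3/5)(0.918) + (2/5)(0) \approx 0.551, and the information gain is roughly 0.971 - 0.551 = 0.420.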
Code implementation:
def createDataSet():
    # each sample: [no surfacing, flippers, class label]
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
Computing the base entropy:
from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # count how many samples fall under each class label
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # H = -sum(p * log2(p))
    return shannonEnt

Splitting the dataset:
def splitDataSet(dataSet, axis, value):
    # dataSet: data to split; axis: index of the splitting feature;
    # value: the feature value to keep
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]          # slice stops before axis
            reduceFeatVec.extend(featVec[axis+1:])  # skip the feature itself
            retDataSet.append(reduceFeatVec)
    return retDataSet
Test data and results:
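The original output screenshots are not reproduced here; a minimal sanity check, with expected outputs worked out by hand from the five samples and shown as comments, looks like this:

dataSet, labels = createDataSet()
print(calcShannonEnt(dataSet))      # 0.9709505944546686
print(splitDataSet(dataSet, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(dataSet, 0, 0))  # [[1, 'no'], [1, 'no']]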
Next we compute the conditional entropy for each feature and obtain its information gain. The feature with the largest information gain is the best feature on which to split the dataset:
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    # base Shannon entropy of the whole dataset
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # collect the unique values of feature i
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0  # conditional entropy for feature i
        for value in uniqueVals:
            # split on this value and weight the subset entropy by its probability
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # information gain
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain  # keep the largest information gain
            bestFeature = i          # best feature to split on
    return bestFeature
Test:
dataSet, labels = createDataSet()
print(dataSet)
print(chooseBestFeatureToSplit(dataSet))
Output (feature 0, "no surfacing", yields the largest information gain, about 0.420 versus 0.171 for "flippers"):

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
0
Majority voting:
import operator

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1  # tally each vote
    # sort class labels by frequency, most common first
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
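For example, majorityCnt(['yes', 'no', 'no']) tallies {'yes': 1, 'no': 2} and returns 'no'.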
Building the tree:
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # stop when all samples share the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # stop when no features are left; fall back to majority voting
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy so recursion does not corrupt labels
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
Result:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
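The tree is just a nested dictionary. The post stops at construction, but a minimal sketch of how such a tree could be used for prediction might look like this (the classify helper is an addition for illustration, not part of the original code; it assumes myTree holds the dictionary above):

def classify(inputTree, featLabels, testVec):
    # walk the nested dict until we hit a leaf (a class label)
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # map feature name -> column index
    subTree = secondDict[testVec[featIndex]]
    if isinstance(subTree, dict):
        return classify(subTree, featLabels, testVec)
    return subTree

classify(myTree, ['no surfacing', 'flippers'], [1, 1])  # -> 'yes'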
This method builds the tree using information gain; consulting other blogs shows its limits. The ID3 algorithm selects the feature with the largest information gain as the current node, and it has several shortcomings: first, it does not address overfitting or the handling of missing values; second, information gain is biased toward features with many distinct values; third, it cannot handle continuous features.
This motivates the C4.5 algorithm, which replaces information gain with the information gain ratio. To reduce overfitting, redundant branches are pruned: deciding whether to prune while the tree is being grown is called pre-pruning, and pruning after the tree has been built, validated by cross-validation, is called post-pruning.
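For reference, the gain ratio C4.5 maximizes is

g_R(D, A) = g(D, A) / H_A(D), where H_A(D) = -\sum_i (|D_i|/|D|) \log_2 (|D_i|/|D|)

is the entropy of dataset D with respect to the values of feature A (the D_i being the subsets induced by those values).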
There is also the CART tree, which uses the Gini index instead; it will be introduced in a later post.
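(For completeness, that index is Gini(D) = 1 - \sum_k p_k^2, where p_k is the proportion of class k in D; CART picks the split that minimizes it.)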
Plotting the tree:
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
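A quick sanity check (assuming myTree holds the dictionary shown in the result section):

print(getNumLeafs(myTree))   # 3
print(getTreeDepth(myTree))  # 2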
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # axprops must be unpacked as keyword arguments
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
dataSet, labels = createDataSet()
myTree = createTree(dataSet, labels)  # note: createTree consumes the labels list
createPlot(myTree)