1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
# Text-box style for internal (decision) nodes: square box, black edge,
# yellow fill.  Passed as the ``bbox`` argument of ``annotate``.
decisionNode = {
    "boxstyle": "square",
    "ec": 'k',
    "fc": 'yellow',
}

# Text-box style for leaf nodes: rounded box with a light red edge/fill
# given as RGB tuples.
leafNode = {
    "boxstyle": "round",
    "ec": (1., 0.5, 0.5),
    "fc": (1., 0.8, 0.8),
}

# Arrow properties shared by every node annotation.
arrow_args = {"arrowstyle": '<-'}
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Annotate the shared axes with one boxed node linked to its parent.

    Args:
        nodeTxt: Label shown inside the node's box.
        centerPt: ``(x, y)`` position of the label, in axes-fraction units.
        parentPt: ``(x, y)`` point being annotated (the parent end of the
            connecting arrow), in axes-fraction units.
        nodeType: Box-style dict handed to ``annotate`` as ``bbox``
            (``decisionNode`` or ``leafNode`` in this file).

    The target axes are read from ``createPlot.ax1``, so ``createPlot`` must
    have set that attribute before this function is called.
    """
    # Settings that are identical for every node drawn by this module.
    shared_style = {
        "xycoords": 'axes fraction',
        "textcoords": 'axes fraction',
        "va": 'center',
        "ha": 'center',
        "bbox": nodeType,
        "arrowprops": arrow_args,
    }
    axes = createPlot.ax1
    axes.annotate(nodeTxt, xy=parentPt, xytext=centerPt, **shared_style)
def createPlot():
    """Draw a two-node demo (one decision node, one leaf node) and show it.

    Clears figure 1, creates a frameless subplot stored on ``createPlot.ax1``
    (the attribute ``plotNode`` draws on), plots the two sample nodes and
    calls ``plt.show()``.

    NOTE: this file later defines ``createPlot(inTree)`` under the same name;
    once the module has been fully executed that later definition replaces
    this one.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)

    # (label, label position, annotated point, box style) for each sample.
    # The labels are Chinese for "decision node" and "leaf node".
    demo_nodes = (
        ('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode),
        ('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for label, center, parent, box_style in demo_nodes:
        plotNode(label, center, parent, box_style)

    plt.show()
def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    The tree has the shape ``{label: {branch_value: subtree_or_leaf, ...}}``.
    Only the first key of ``myTree`` is used as the node label; every value
    in its branch dict that is itself a dict (including dict subclasses such
    as ``OrderedDict``) is a subtree, anything else counts as one leaf.

    Args:
        myTree: Non-empty dict describing the (sub)tree.

    Returns:
        Leaf count as an int; 0 when the root has no branches.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    root_label = list(myTree)[0]
    branches = myTree[root_label]
    # isinstance (rather than comparing the type's name) so that dict
    # subclasses are descended into instead of being miscounted as leaves.
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in branches.values()
    )
def getTreeDepth(myTree):
    """Return the depth (number of decision levels) of a nested-dict tree.

    The tree has the shape ``{label: {branch_value: subtree_or_leaf, ...}}``.
    A branch that ends in a leaf contributes depth 1; a branch that holds a
    dict (including dict subclasses such as ``OrderedDict``) contributes
    1 plus the depth of that subtree.  The result is the deepest branch.

    Args:
        myTree: Non-empty dict describing the (sub)tree.

    Returns:
        Depth as an int; 0 when the root has no branches.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    root_label = list(myTree)[0]
    branches = myTree[root_label]
    # isinstance (rather than comparing the type's name) so that dict
    # subclasses are measured as subtrees instead of being treated as leaves.
    branch_depths = (
        1 + getTreeDepth(child) if isinstance(child, dict) else 1
        for child in branches.values()
    )
    # default=0 keeps the result for a root without any branches at 0.
    return max(branch_depths, default=0)
def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* halfway between a node and its parent.

    Args:
        cntrPt: ``(x, y)`` of the child node.
        parentPt: ``(x, y)`` of the parent node.
        txtString: Text to place at the midpoint of the two points.

    The text is drawn on ``createPlot.ax1``, which ``createPlot`` must have
    set beforehand.
    """
    child_x, child_y = cntrPt[0], cntrPt[1]
    # Midpoint computed as child + half the offset towards the parent.
    mid_x = child_x + (parentPt[0] - child_x) / 2
    mid_y = child_y + (parentPt[1] - child_y) / 2
    createPlot.ax1.text(mid_x, mid_y, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* below *parentPt* on the shared axes.

    Layout state lives in attributes on this function, which ``createPlot``
    initialises before the first call:

    * ``plotTree.totalW`` / ``plotTree.totalD`` - leaf count and depth of the
      whole tree, used to scale x and y steps.
    * ``plotTree.xOff`` - x position of the most recently placed leaf; each
      new leaf advances it by ``1 / totalW``.
    * ``plotTree.yOff`` - y position of the level currently being drawn; it
      is lowered by ``1 / totalD`` while drawing children and restored
      afterwards.

    Args:
        myTree: Nested dict ``{label: {branch_value: subtree_or_leaf}}``.
        parentPt: ``(x, y)`` of the parent node, in axes-fraction units.
        nodeTxt: Text for the edge from the parent to this node (the branch
            value that led here).
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree)[0]

    # This subtree's leaves will occupy xOff + 1/W ... xOff + numLeafs/W, so
    # centring the decision node above them puts it at their midpoint.
    cntrPt = (
        plotTree.xOff + (1.0 + float(numLeafs)) / 2 / plotTree.totalW,
        plotTree.yOff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)

    secondDict = myTree[firstStr]
    levelStep = 1 / plotTree.totalD

    # Step down one level for the children of this node.
    plotTree.yOff = plotTree.yOff - levelStep
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # Subtree: recurse with this node as the parent.
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: take the next horizontal slot and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # Step back up so siblings of this node are drawn at the right level.
    plotTree.yOff = plotTree.yOff + levelStep
def createPlot(inTree):
    """Draw the decision tree *inTree* in figure 1 and display it.

    Clears figure 1, creates a frameless subplot without axis ticks (stored
    on ``createPlot.ax1`` for ``plotNode`` / ``plotMidText``), seeds the
    layout attributes that ``plotTree`` relies on, draws the tree from the
    top-centre point ``(0.5, 1.0)`` and calls ``plt.show()``.

    Args:
        inTree: Nested dict ``{label: {branch_value: subtree_or_leaf}}``.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    # Overall tree size, used by plotTree to scale horizontal/vertical steps.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))

    # Start half a leaf-slot left of the axes so the first leaf, placed one
    # full slot further right, lands at 0.5 / totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0

    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
|