1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
| import matplotlib.pyplot as plt
# Shared matplotlib styling used by the plotting helpers in this module.
# decisionNode and leafNode are passed to plotNode as its ``nodeType``
# argument (forwarded as the annotation's ``bbox``); arrow_args is forwarded
# as the annotation's ``arrowprops``.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
arrow_args = {"arrowstyle": "<-"}
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a single boxed node with an arrow linking it to its parent.

    Args:
        nodeTxt: Text shown inside the node box.
        centerPt: (x, y) position of the node text, in axes-fraction
            coordinates.
        parentPt: (x, y) position of the parent end of the arrow, in
            axes-fraction coordinates.
        nodeType: Box style dict passed as the annotation's ``bbox``
            (e.g. ``decisionNode`` or ``leafNode``).

    The annotation is drawn on ``createPlot.ax1``, so ``createPlot`` must
    have been called (and have set that attribute) beforehand.
    """
    axes = createPlot.ax1
    coord_system = 'axes fraction'
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords=coord_system,
        xytext=centerPt,
        textcoords=coord_system,
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args
    )
def createPlot():
    """Show a small demo figure with one decision node and one leaf node.

    Clears figure 1, creates a frameless subplot stored on
    ``createPlot.ax1`` (where ``plotNode`` looks for it), draws the two
    sample nodes and displays the figure.

    NOTE: this module later defines ``createPlot(inTree)``, which rebinds
    the name; once the module has been fully imported, that later version
    is the one callers get.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)

    # (label, node position, parent position, box style); the labels are
    # Chinese for "decision node" and "leaf node".
    demo_nodes = (
        (u'决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode),
        (u'叶节点', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for label, center, parent, style in demo_nodes:
        plotNode(label, center, parent, style)

    plt.show()
def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    The tree is read as ``{feature: {value: child, ...}}``: the first key
    of ``myTree`` is the root feature, and each ``child`` is either another
    dict of the same form (a subtree) or any non-dict object (a leaf).

    Args:
        myTree: Non-empty dict describing the tree.

    Returns:
        int: Total count of leaves reachable from the root.

    Raises:
        IndexError: If ``myTree`` (or a nested subtree) is an empty dict.
    """
    # list(d)[0] works on both Python 2 and 3; d.keys()[0] fails on
    # Python 3 because dict views are not subscriptable.
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]

    numLeafs = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree.

    The tree is read as ``{feature: {value: child, ...}}``. A branch that
    ends in a leaf (any non-dict child) contributes depth 1; a branch that
    leads to a subtree contributes 1 plus that subtree's depth. The result
    is the maximum over all branches of the root.

    Args:
        myTree: Non-empty dict describing the tree.

    Returns:
        int: Number of decision levels on the longest root-to-leaf path
        (0 if the root feature has no branches).

    Raises:
        IndexError: If ``myTree`` (or a nested subtree) is an empty dict.
    """
    # list(d)[0] works on both Python 2 and 3; d.keys()[0] fails on
    # Python 3 because dict views are not subscriptable.
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]

    maxDepth = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
def retrieveTree(i):
    """Return one of the hard-coded sample decision trees.

    Args:
        i: Index into the list of sample trees (0 or 1).

    Returns:
        dict: A nested-dict tree of the form
        ``{feature: {value: subtree_or_label}}``.

    Raises:
        IndexError: If ``i`` is outside the list of samples.
    """
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers':
            {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]


def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` at the midpoint between a node and its parent.

    Args:
        cntrPt: (x, y) position of the child node.
        parentPt: (x, y) position of the parent node.
        txtString: Label to draw (the callers in this module pass the
            branch's feature value).

    The text is drawn on ``createPlot.ax1``, so ``createPlot`` must have
    set that attribute beforehand.
    """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw a nested-dict decision tree.

    Layout state is kept in function attributes that the caller must
    initialise before the first call (``createPlot(inTree)`` does this):
    ``plotTree.totalW`` (total leaf count), ``plotTree.totalD`` (tree
    depth), and the running cursor ``plotTree.xOff`` / ``plotTree.yOff``.

    Args:
        myTree: Non-empty dict of the form
            ``{feature: {value: subtree_or_label}}``.
        parentPt: (x, y) position of the parent node; the arrow for this
            subtree's root is attached to it.
        nodeTxt: Label drawn on the edge between ``parentPt`` and this
            subtree's root.
    """
    numLeafs = getNumLeafs(myTree)
    # list(d)[0] works on both Python 2 and 3; d.keys()[0] fails on
    # Python 3 because dict views are not subscriptable.
    firstStr = list(myTree)[0]

    # Centre this decision node horizontally above the leaves it spans.
    cntrPt = (
        plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.yOff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)

    secondDict = myTree[firstStr]
    # Step one level down before drawing the children ...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # ... and restore the level so sibling subtrees are drawn at the
    # right height.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Render a nested-dict decision tree in a matplotlib window.

    Clears figure 1, creates a frameless subplot with no axis ticks
    (stored on ``createPlot.ax1`` for the drawing helpers), initialises the
    layout state that ``plotTree`` keeps in its function attributes, draws
    the tree and shows the figure.

    Args:
        inTree: Non-empty dict of the form
            ``{feature: {value: subtree_or_label}}``.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    # Overall tree dimensions, used by plotTree to scale node positions.
    leaf_count = float(getNumLeafs(inTree))
    plotTree.totalW = leaf_count
    plotTree.totalD = float(getTreeDepth(inTree))

    # Start the horizontal cursor half a leaf-slot to the left of the axes
    # and the vertical cursor at the top.
    plotTree.xOff = -0.5 / leaf_count
    plotTree.yOff = 1.0

    root_anchor = (0.5, 1.0)
    plotTree(inTree, root_anchor, '')
    plt.show()
|