Skip to content

Commit

Permalink
Merge pull request apachecn#374 from jingwangfei/patch-2
Browse files Browse the repository at this point in the history
add get_tree_height method
  • Loading branch information
jiangzhonglian authored May 7, 2018
2 parents a764418 + 9be2e8d commit 145ba63
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions src/py2.x/3.DecisionTree/DecisionTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ def fishTest():
print myTree
# [1, 1]表示要取的分支上的节点位置,对应的结果值
print classify(myTree, labels, [1, 1])

# 获得树的高度
print get_tree_height(myTree)

# 画图可视化展现
dtPlot.createPlot(myTree)
Expand All @@ -353,6 +356,32 @@ def ContactLensesTest():
print lensesTree
# 画图可视化展现
dtPlot.createPlot(lensesTree)


def get_tree_height(tree):
"""
Desc:
递归获得决策树的高度
Args:
tree
Returns:
树高
"""

if not isinstance(tree, dict):
return 1

child_trees = tree.values()[0].values()

# 遍历子树, 获得子树的最大高度
max_height = 0
for child_tree in child_trees:
child_tree_height = get_tree_height(child_tree)

if child_tree_height > max_height:
max_height = child_tree_height

return max_height + 1


if __name__ == "__main__":
Expand Down

0 comments on commit 145ba63

Please sign in to comment.