Skip to content

Commit

Permalink
print() function in randomForest.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
cclauss authored Jun 6, 2018
1 parent 91bf612 commit b5225b0
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/py3.x/ML/7.RandomForest/randomForest.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_split(dataset, n_features):
# 左右两边的数量越一样,说明数据区分度不高,gini系数越大
if gini < b_score:
b_index, b_value, b_score, b_groups = index, row[index], gini, groups # 最后得到最优的分类特征 b_index,分类特征值 b_value,分类结果 b_groups。b_value 为分错的代价成本
- # print b_score
+ # print(b_score)
return {'index': b_index, 'value': b_value, 'groups': b_groups}


Expand Down Expand Up @@ -303,7 +303,7 @@ def evaluate_algorithm(dataset, algorithm, n_folds, *args):

# 加载数据
dataset = loadDataSet('input/7.RandomForest/sonar-all-data.txt')
- # print dataset
+ # print(dataset)

n_folds = 5 # 分成5份数据,进行交叉验证
max_depth = 20 # 调参(自己修改) #决策树深度不能太深,不然容易导致过拟合
Expand All @@ -315,7 +315,7 @@ def evaluate_algorithm(dataset, algorithm, n_folds, *args):
scores = evaluate_algorithm(dataset, random_forest, n_folds, max_depth, min_size, sample_size, n_trees, n_features)
# 每一次执行本文件时都能产生同一个随机数
seed(1)
- print 'random=', random()
- print 'Trees: %d' % n_trees
- print 'Scores: %s' % scores
- print 'Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores)))
+ print('random=', random())
+ print('Trees: %d' % n_trees)
+ print('Scores: %s' % scores)
+ print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))

0 comments on commit b5225b0

Please sign in to comment.