-
Notifications
You must be signed in to change notification settings - Fork 3
/
cart.py
202 lines (168 loc) · 5.54 KB
/
cart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
Cart for both regression and classifying
min[ min sum (y_i - c1)^2 + min sum (y_j - c2)^2]
二叉树
"""
import time

import numpy as np
import scipy
import scipy.optimize

# import pysnooper
class Node:
    """A node of the binary tree.

    A node is a leaf until both children are attached; internal nodes
    additionally record which feature and threshold they split on.
    """
    def __init__(self, is_root=False):
        self.is_root = is_root
        self.x_data = None       # feature rows routed to this node
        self.y_value = None      # target values for those rows
        self.left = None         # left child Node (None on a leaf)
        self.right = None        # right child Node (None on a leaf)
        self.split_var = None    # index of the feature used to split
        self.split_value = None  # threshold value for that feature

    @property
    def isend(self):
        """True when this node is a leaf (it has no children)."""
        return not (self.left or self.right)
class Tree:
    """Binary regression tree grown by recursive greedy splitting."""

    def __init__(self, max_depth=10, debug=False):
        """
        Args:
            max_depth: maximum recursion depth when growing the tree.
            debug: when True, print diagnostic output during splitting.
        """
        self.root = Node(True)
        self.max_depth = max_depth
        self.debug = debug

    def reg_lookup(self, x):
        """Route one sample from the root to a leaf and return the mean
        of the training targets stored at that leaf.

        Args:
            x: 1-D numpy array of features for a single sample.

        Returns:
            float: mean target value of the leaf the sample lands in.
        """
        assert not self.root.isend, "root未分裂"
        node = self.root
        while not node.isend:
            if x[node.split_var] <= node.split_value:
                node = node.left
            else:
                node = node.right
        # BUGFIX: previously this returned the mean of the leaf's *parent*
        # (leaves never had y_value populated). split() now stores data on
        # every node, so the leaf's own mean is returned.
        return node.y_value.mean()

    def predict(self, data):
        """Batch prediction hook; intentionally unimplemented here."""
        # BUGFIX: was `raiseNotImplementedError(...)` — a NameError at runtime.
        raise NotImplementedError("tree.predict: not implemented!")

    def _random_split(self, node):
        """Pick a random feature and a uniform random threshold in its range."""
        var = np.random.choice(node.x_data.shape[1])
        if self.debug:  # was an unconditional print
            print(var)
        column = node.x_data[:, var]
        return var, np.random.uniform(column.min(), column.max())

    def _greedy_split(self, node):
        """Greedily choose the (feature, threshold) pair that minimises the
        squared-error split criterion.

        Returns:
            (split_var, split_value) tuple, or None when no feature has more
            than one distinct value (no split is possible).
        """
        feature_num = node.x_data.shape[1]
        split_list, error_list, split_var_list = [], [], []
        best = None  # (error, split_var, split_value) of the best split so far
        for split_var in range(feature_num):
            column = node.x_data[:, split_var]
            lo, hi = column.min(), column.max()
            if lo == hi:
                continue  # constant feature: nothing to split on
            # 1-D bounded minimisation of the split error over the threshold.
            split, error, ierr, numf = scipy.optimize.fminbound(
                self._reg_error_func, lo, hi,
                args=(split_var, node.x_data, node.y_value),
                full_output=1)
            split_list.append(split)
            error_list.append(error)
            split_var_list.append(split_var)
            if best is None or error < best[0]:
                best = (error, split_var, split)
        if self.debug:
            time.sleep(0.1)
            print("Node", node)
            print("split_var: ", split_var_list, "\n")
            print("error_list: ", error_list, "\n")
            if best is not None:
                print("best_split_value:", best[2], "\n")
                print("best_split_var:", best[1], "\n")
            print("-"*10, "\n")
        if best is None:
            return None
        return best[1], best[2]

    def _reg_error_func(self, split_value, split_var, x_data, y):
        """Squared-error cost of splitting `split_var` at `split_value`:
        SSE(left) + SSE(right) around each side's mean.

        BUGFIX: sides now use `<=`/`>` to match split() (previously `>=`/`<`,
        the opposite assignment), and an empty side scores np.inf instead of
        silently contributing 0 (np.mean of an empty slice is nan, which the
        subsequent empty-array sum discarded).
        """
        left = y[x_data[:, split_var] <= split_value]
        right = y[x_data[:, split_var] > split_value]
        if left.size == 0 or right.size == 0:
            return np.inf  # degenerate split: reject it outright
        return (np.square(left - left.mean()).sum()
                + np.square(right - right.mean()).sum())

    def split(self, node, x_data, y, depth):
        """Recursively grow the tree below `node`.

        Args:
            node: leaf Node to (possibly) split.
            x_data: 2-D numpy feature matrix of the samples reaching `node`.
            y: 1-D numpy array of targets for those samples.
            depth: current depth (the root is at depth 0).
        """
        assert node.isend, "分裂节点必须是树的终结点"
        if self.debug:
            print("depth", depth)
            print("y", y)
        # BUGFIX: store the data BEFORE any early return so that leaves carry
        # their own targets (reg_lookup reads them).
        node.x_data = x_data
        node.y_value = y
        # Stopping criteria: depth limit reached or too few samples to split.
        if depth >= self.max_depth or x_data.shape[0] <= 1:
            return
        # BUGFIX: was a bare `except:` swallowing every error; handle the
        # "no split possible" case explicitly instead.
        found = self._greedy_split(node)
        if found is None:
            return
        split_var, split_value = found
        node.split_var = split_var
        node.split_value = split_value
        left_ind = x_data[:, split_var] <= split_value
        right_ind = ~left_ind
        node.left = Node()
        node.right = Node()
        self.split(node.left, x_data[left_ind], y[left_ind], depth + 1)
        self.split(node.right, x_data[right_ind], y[right_ind], depth + 1)
#----------------------------------------------------------------------
class CART:
    """Thin wrapper around Tree: training, prediction and evaluation."""

    def __init__(self):
        pass

    def train_reg(self, X, y, max_depth=5, regularization=False, debug=False):
        """Fit a regression tree.

        Args:
            X: 2-D numpy array of training features.
            y: 1-D numpy array of training targets.
            max_depth: maximum tree depth.
            regularization: unused placeholder, kept for API compatibility.
            debug: forwarded to Tree for verbose splitting output.
        """
        self.y = y
        self.X = X
        self.tree = Tree(max_depth=max_depth, debug=debug)
        self.tree.split(self.tree.root, x_data=X, y=y, depth=0)

    def train_clf(self):
        """Classification training — not implemented yet."""
        pass

    def predict(self, X):
        """Predict one value per row of X; result is cached in self.y_pred.

        Args:
            X: 2-D numpy array of samples to predict.

        Returns:
            1-D numpy array of predictions.
        """
        self.y_pred = np.array([self.tree.reg_lookup(row) for row in X])
        return self.y_pred

    def reg_loss(self, y_pred=None, y=None):
        """Mean absolute error between predictions and targets.

        BUGFIX: `if not y_pred:` raised ValueError for numpy arrays with more
        than one element, and the `y` argument was silently ignored (self.y
        was always used). Both defaults are now resolved with explicit
        `is None` checks.

        Args:
            y_pred: predictions; defaults to the cached self.y_pred.
            y: targets; defaults to the stored training targets self.y.

        Returns:
            float: mean of |y_pred - y|.
        """
        if y_pred is None:
            y_pred = self.y_pred
        if y is None:
            y = self.y
        y = np.asarray(y).reshape(-1)
        return np.mean(np.abs(y_pred - y))