-
Notifications
You must be signed in to change notification settings - Fork 538
/
Copy pathsearch.py
105 lines (85 loc) · 3.18 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Code for sequence generation
"""
import numpy
import copy
def gen_sample(tparams, f_init, f_next, ctx, options, trng=None, k=1, maxlen=30,
               stochastic=True, argmax=False, use_unk=False):
    """
    Generate a sample, using either beam search or stochastic sampling.

    Word index 0 terminates a sequence; word index 1 is the entry that
    ``use_unk`` allows or forbids during beam search.

    Parameters
    ----------
    tparams, options, trng
        Accepted for interface compatibility; not referenced by this function.
    f_init : callable
        Called once as ``f_init(ctx)``; its result is the initial state.
    f_next : callable
        Called every step as ``f_next(next_w, next_state)``. The first three
        items of its result are used as ``(next_p, next_w, next_state)``:
        ``next_p`` is indexed as ``[hypothesis, word]``, ``next_w`` holds the
        words proposed by the model, and ``next_state`` is indexed by
        hypothesis. On the first call ``next_w`` is ``[-1]`` (int64).
    ctx
        Passed unchanged to ``f_init``.
    k : int
        Beam width. Values above 1 require ``stochastic=False``.
    maxlen : int
        Maximum number of decoding steps.
    stochastic : bool
        If true, follow a single sequence instead of running beam search.
    argmax : bool
        Stochastic mode only: take ``next_p[0].argmax()`` at every step
        instead of the word proposed in ``next_w[0]``.
    use_unk : bool
        Beam search only: when false, word index 1 is never selected.

    Returns
    -------
    (sample, sample_score)
        Stochastic mode: ``sample`` is a flat list of word indices and
        ``sample_score`` is the running sum of ``next_p[0, word]`` for the
        chosen words (raw values from ``f_next``, not log-probabilities).
        Beam search: ``sample`` is a list of hypotheses (lists of word
        indices) and ``sample_score`` the matching list of costs (accumulated
        ``-log(next_p)``, lower is better). Finished hypotheses come first, in
        the order they finished, followed by any still unfinished after
        ``maxlen`` steps.

    Raises
    ------
    AssertionError
        If ``k > 1`` while ``stochastic`` is true.
    """
    if k > 1:
        assert not stochastic, 'Beam search does not support stochastic sampling'

    sample = []
    # A scalar accumulator in stochastic mode, one cost per hypothesis otherwise.
    sample_score = 0 if stochastic else []

    live_k = 1
    dead_k = 0

    # Independent empty lists (``[[]] * n`` would alias a single list).
    hyp_samples = [[] for _ in range(live_k)]
    hyp_scores = numpy.zeros(live_k).astype('float32')

    next_state = f_init(ctx)
    next_w = -1 * numpy.ones((1,)).astype('int64')

    for _ in range(maxlen):
        ret = f_next(next_w, next_state)
        next_p, next_w, next_state = ret[0], ret[1], ret[2]

        if stochastic:
            nw = next_p[0].argmax() if argmax else next_w[0]
            sample.append(nw)
            sample_score += next_p[0, nw]
            if nw == 0:
                break
            continue

        # ---- beam search step ----
        voc_size = next_p.shape[1]
        # Cost of extending every live hypothesis with every word.
        cand_scores = hyp_scores[:, None] - numpy.log(next_p)
        if not use_unk:
            # Make word index 1 prohibitively expensive for every hypothesis.
            cand_scores[:, 1] = 1e20
        cand_flat = cand_scores.flatten()

        # Keep only as many candidates as there are open beam slots.
        ranks_flat = cand_flat.argsort()[:(k - dead_k)]
        # Floor division keeps these usable as indices on Python 2 and 3.
        trans_indices = ranks_flat // voc_size
        word_indices = ranks_flat % voc_size
        costs = cand_flat[ranks_flat]

        prev_samples = hyp_samples
        hyp_samples = []
        live_scores = []
        hyp_states = []
        for ti, wi, cost in zip(trans_indices, word_indices, costs):
            hyp = prev_samples[ti] + [wi]
            score = numpy.float32(cost)
            if wi == 0:
                # Hypothesis finished: move it to the output.
                sample.append(hyp)
                sample_score.append(score)
                dead_k += 1
            else:
                hyp_samples.append(hyp)
                live_scores.append(score)
                hyp_states.append(copy.copy(next_state[ti]))

        hyp_scores = numpy.array(live_scores, dtype='float32')
        live_k = len(hyp_samples)

        if live_k < 1:
            break
        if dead_k >= k:
            break

        # Feed the last word and state of each surviving hypothesis back in.
        next_w = numpy.array([w[-1] for w in hyp_samples])
        next_state = numpy.array(hyp_states)

    if not stochastic:
        # Dump every hypothesis that is still unfinished.
        for idx in range(live_k):
            sample.append(hyp_samples[idx])
            sample_score.append(hyp_scores[idx])

    return sample, sample_score