This repository has been archived by the owner on Jun 1, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 216
/
random.py
149 lines (109 loc) · 4.41 KB
/
random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import hashlib
import six
from .base import PlanOutOpSimple
class PlanOutOpRandom(PlanOutOpSimple):
    """Base class for PlanOut operators that make deterministic random draws.

    Draws come from a SHA-1 hash of a salt plus the experimental unit(s), so
    the same salt and unit always yield the same value.
    """

    # Divisor mapping the 15-hex-digit (60-bit) value returned by getHash()
    # onto the closed interval [0, 1].
    LONG_SCALE = float(0xFFFFFFFFFFFFFFF)

    def getUnit(self, appended_unit=None):
        """Return the 'unit' argument as a list.

        A non-list unit is wrapped in a one-element list. When
        ``appended_unit`` is not None it is appended to the end, so one unit
        can yield several distinct hashes (e.g. one per sample position).
        """
        unit = self.getArgMixed('unit')
        if type(unit) is not list:
            unit = [unit]
        if appended_unit is not None:
            # NOTE(review): if 'unit' was already a list, `+=` extends that
            # same list object in place (presumably the one held in
            # self.args['unit']), so repeated calls accumulate appended values.
            # Left unchanged because copying would alter the hashes of
            # existing assignments -- confirm this is intended.
            unit += [appended_unit]
        return unit

    def getHash(self, appended_unit=None):
        """Return a 60-bit integer derived from the salt and the unit(s).

        The salt is ``full_salt + '.'`` when 'full_salt' is given, otherwise
        ``'<experiment_salt>.<salt><salt_sep>'``. Unit values are converted
        with str() and joined with '.'.
        """
        if 'full_salt' in self.args:
            full_salt = self.getArgString('full_salt') + '.'  # do typechecking
        else:
            full_salt = '%s.%s%s' % (
                self.mapper.experiment_salt,
                self.getArgString('salt'),
                self.mapper.salt_sep)
        unit_str = '.'.join(map(str, self.getUnit(appended_unit)))
        hash_str = '%s%s' % (full_salt, unit_str)
        if not isinstance(hash_str, six.binary_type):
            # UTF-8 yields the same bytes as ASCII for ASCII-only text, so
            # existing hashes are unchanged; non-ASCII salts/units no longer
            # raise UnicodeEncodeError.
            hash_str = hash_str.encode('utf-8')
        # Keep only the first 15 hex digits (60 bits) of the SHA-1 digest.
        return int(hashlib.sha1(hash_str).hexdigest()[:15], 16)

    def getUniform(self, min_val=0.0, max_val=1.0, appended_unit=None):
        """Map the unit's hash linearly onto [min_val, max_val]."""
        zero_to_one = self.getHash(appended_unit) / PlanOutOpRandom.LONG_SCALE
        return min_val + (max_val - min_val) * zero_to_one
class RandomFloat(PlanOutOpRandom):
    """Deterministically draw a float spanning the 'min' and 'max' arguments."""

    def simpleExecute(self):
        # Read the bounds in the same order as before: 'min' first, then 'max'.
        bounds = (self.getArgFloat('min'), self.getArgFloat('max'))
        return self.getUniform(*bounds)
class RandomInteger(PlanOutOpRandom):
    """Deterministically draw an integer from the inclusive range [min, max]."""

    def simpleExecute(self):
        min_val = self.getArgInt('min')
        max_val = self.getArgInt('max')
        # With max < min the modulus below is zero (ZeroDivisionError) or
        # negative (results outside the requested range), so reject it up front.
        assert min_val <= max_val, \
            '%s: min (%s) must not be greater than max (%s)!' \
            % (self.__class__, min_val, max_val)
        return min_val + self.getHash() % (max_val - min_val + 1)
class BernoulliTrial(PlanOutOpRandom):
    """Return 1 when the unit's uniform draw is at most p, else 0."""

    def simpleExecute(self):
        prob = self.getArgNumeric('p')
        assert 0 <= prob <= 1.0, \
            '%s: p must be a number between 0.0 and 1.0, not %s!' \
            % (self.__class__, prob)
        # int(True) == 1 and int(False) == 0.
        return int(self.getUniform(0.0, 1.0) <= prob)
class BernoulliFilter(PlanOutOpRandom):
    """Keep each element of 'choices' whose own uniform draw is at most p."""

    def simpleExecute(self):
        prob = self.getArgNumeric('p')
        candidates = self.getArgList('choices')
        assert 0 <= prob <= 1.0, \
            '%s: p must be a number between 0.0 and 1.0, not %s!' \
            % (self.__class__, prob)
        if len(candidates) == 0:
            return []
        kept = []
        for item in candidates:
            # The element itself is appended to the unit, giving each element
            # its own hash; draws happen in list order.
            if self.getUniform(0.0, 1.0, item) <= prob:
                kept.append(item)
        return kept
class UniformChoice(PlanOutOpRandom):
    """Pick one element of 'choices' by the unit hash modulo its length.

    Returns an empty list when there are no choices.
    """

    def simpleExecute(self):
        options = self.getArgList('choices')
        count = len(options)
        if count == 0:
            return []
        return options[self.getHash() % count]
class WeightedChoice(PlanOutOpRandom):
    """Pick one element of 'choices' with probability proportional to 'weights'.

    Returns an empty list when there are no choices. Raises AssertionError
    when 'weights' and 'choices' differ in length.
    """

    def simpleExecute(self):
        choices = self.getArgList('choices')
        weights = self.getArgList('weights')
        if len(choices) == 0:
            return []
        # Without this check, extra weights raised IndexError only for units
        # whose draw landed past the last choice, and missing weights silently
        # made trailing choices unreachable (or returned None).
        assert len(weights) == len(choices), \
            '%s: got %s weights for %s choices!' \
            % (self.__class__, len(weights), len(choices))
        # cum_weights[i] is the upper edge of choice i's interval on [0, cum_sum].
        cum_weights = []
        cum_sum = 0.0
        for weight in weights:
            cum_sum += weight
            cum_weights.append(cum_sum)
        stop_value = self.getUniform(0.0, cum_sum)
        for index, upper in enumerate(cum_weights):
            if stop_value <= upper:
                return choices[index]
class BaseSample(PlanOutOpRandom):
    """Shared helpers for operators that sample 'choices' without replacement."""

    def copyChoices(self):
        """Return a shallow copy of 'choices' so shuffling never mutates the argument."""
        return list(self.getArgList('choices'))

    def getNumDraws(self, choices):
        """Return how many elements to draw: 'draws' if given, else len(choices).

        Raises AssertionError when 'draws' is negative or exceeds len(choices).
        """
        if 'draws' not in self.args:
            return len(choices)
        num_draws = self.getArgInt('draws')
        # A negative count is not a valid sample size; slicing with it would
        # silently drop elements from the end instead.
        assert num_draws >= 0, \
            '%s: cannot make a negative number of draws (%s)' \
            % (self.__class__, num_draws)
        assert num_draws <= len(choices), \
            "%s: cannot make %s draws when only %s choices are available" \
            % (self.__class__, num_draws, len(choices))
        return num_draws
class FastSample(BaseSample):
    """Sample without replacement, stopping the shuffle as soon as possible.

    Runs a Fisher-Yates shuffle from the back of a copy of 'choices' and
    returns the finalized tail once it holds num_draws elements.
    """

    def simpleExecute(self):
        pool = self.copyChoices()
        num_draws = self.getNumDraws(pool)
        cutoff = len(pool) - num_draws
        position = len(pool) - 1
        while position > 0:
            swap_with = self.getHash(position) % (position + 1)
            pool[position], pool[swap_with] = pool[swap_with], pool[position]
            if position == cutoff:
                # pool[position:] is already fully shuffled and has num_draws items.
                return pool[position:]
            position -= 1
        return pool[:num_draws]
class Sample(BaseSample):
    """Sample without replacement by fully shuffling a copy of 'choices'.

    Uses a Fisher-Yates shuffle driven by the unit hash and returns the
    first num_draws elements.
    """

    def simpleExecute(self):
        shuffled = self.copyChoices()
        num_draws = self.getNumDraws(shuffled)
        position = len(shuffled) - 1
        while position >= 1:
            other = self.getHash(position) % (position + 1)
            shuffled[other], shuffled[position] = shuffled[position], shuffled[other]
            position -= 1
        return shuffled[:num_draws]