attention_net_supervised.py (forked from ray-project/ray)
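
"""Supervised-learning sanity check for RLlib's TrXL attention net.

Trains a one-unit TrXLNet on a toy "bit shift" task: the target at each
timestep is the input bit observed `shift` steps earlier (0 for the first
`shift` steps), so the model can only fit the data by attending back in time.
"""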
from gymnasium.spaces import Box, Discrete
import numpy as np

from ray.rllib.models.tf.attention_net import TrXLNet
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

def bit_shift_generator(seq_length, shift, batch_size):
    """Yields (sequence, targets) batches for the bit-shift task.

    Targets are the input bits rolled `shift` steps to the right; the first
    `shift` positions, which have no valid source bit, are zeroed out.
    """
    while True:
        values = np.array([0.0, 1.0], dtype=np.float32)
        seq = np.random.choice(values, (batch_size, seq_length, 1))
        targets = np.squeeze(np.roll(seq, shift, axis=1).astype(np.int32))
        targets[:, :shift] = 0
        yield seq, targets

def train_loss(targets, outputs):
    """Mean sparse softmax cross-entropy between int targets and logits."""
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=outputs
    )
    return tf.reduce_mean(loss)

def train_bit_shift(seq_length, num_iterations, print_every_n):
    optimizer = tf.keras.optimizers.Adam(1e-3)

    # A single-unit TrXL mapping each binary observation to 2 logits
    # (one per possible target bit).
    model = TrXLNet(
        observation_space=Box(low=0, high=1, shape=(1,), dtype=np.int32),
        action_space=Discrete(2),
        num_outputs=2,
        model_config={"max_seq_len": seq_length},
        name="trxl",
        num_transformer_units=1,
        attention_dim=10,
        num_heads=5,
        head_dim=20,
        position_wise_mlp_dim=20,
    )

    shift = 10
    train_batch = 10
    test_batch = 100
    data_gen = bit_shift_generator(seq_length, shift=shift, batch_size=train_batch)
    test_gen = bit_shift_generator(seq_length, shift=shift, batch_size=test_batch)

    @tf.function
    def update_step(inputs, targets):
        # Run the forward pass inside the loss closure so that
        # `optimizer.minimize` records gradients through the model.
        # ModelV2.__call__ returns (outputs, state_out); only the logits
        # are needed for the loss.
        def loss_fn():
            model_out, _ = model(
                {"obs": inputs},
                state=[tf.reshape(inputs, [-1, seq_length, 1])],
                seq_lens=np.full(shape=(train_batch,), fill_value=seq_length),
            )
            return train_loss(targets, model_out)

        # `trainable_variables` is a method on RLlib's ModelV2, so it must be
        # called; pass a callable since the variables only exist after the
        # first forward pass.
        optimizer.minimize(loss_fn, lambda: model.trainable_variables())

    for i, (inputs, targets) in zip(range(num_iterations), data_gen):
        # Flatten to (batch * time, 1) observations and (batch * time,) labels.
        inputs_in = np.reshape(inputs, [-1, 1])
        targets_in = np.reshape(targets, [-1])
        update_step(tf.convert_to_tensor(inputs_in), tf.convert_to_tensor(targets_in))
        if i % print_every_n == 0:
            # Report loss on a held-out batch, fed like the training data.
            test_inputs, test_targets = next(test_gen)
            test_in = tf.convert_to_tensor(np.reshape(test_inputs, [-1, 1]))
            test_out, _ = model(
                {"obs": test_in},
                state=[tf.reshape(test_in, [-1, seq_length, 1])],
                seq_lens=np.full(shape=(test_batch,), fill_value=seq_length),
            )
            print(i, train_loss(np.reshape(test_targets, [-1]), test_out).numpy())

if __name__ == "__main__":
    # Eager execution is the default in TF2; the explicit enable call only
    # exists (and is only needed) under TF1.
    if tfv == 1:
        tf1.enable_eager_execution()
    train_bit_shift(
        seq_length=20,
        num_iterations=2000,
        print_every_n=200,
    )
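
# A minimal sanity check of the generator (a sketch, not part of the original
# example) confirms the shapes the model is fed above:
#
#   gen = bit_shift_generator(seq_length=20, shift=10, batch_size=4)
#   seq, targets = next(gen)
#   # seq.shape     == (4, 20, 1)  -- float32 bits in {0.0, 1.0}
#   # targets.shape == (4, 20)     -- int32 labels; targets[:, :10] are all 0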