# Section_10_reinforce_CartPole.py
# Forked from escape-velocity-labs/beginner_master_rl
## Import the necessary software libraries:
import os
import torch
import gym
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch import nn
from torch.optim import AdamW
from utils import test_policy_network, seed_everything, plot_stats, plot_action_probs
from parallel_env import ParallelEnv, ParallelWrapper
# Select the device: use the GPU if available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
if __name__ == '__main__':

    ## Create and preprocess the environment

    ### Create the environment
    _env = gym.make('CartPole-v1', render_mode='rgb_array')
    dims = _env.observation_space.shape[0]
    actions = _env.action_space.n

    print(f"State dimensions: {dims}. Actions: {actions}")
    print(f"Sample state: {_env.reset()}")
    # plt.imshow(_env.render())

    ### Parallelize the environment
    num_envs = os.cpu_count()

    def create_env(env_name, seed):
        env = gym.make(env_name, render_mode='rgb_array')
        seed_everything(env, seed=seed)
        return env

    # Bind `rank` as a default argument so each environment gets its own seed;
    # a bare `lambda: create_env(..., rank)` would capture the loop variable by
    # reference and give every worker the final value of `rank`.
    env_fns = [lambda rank=rank: create_env('CartPole-v1', rank) for rank in range(num_envs)]

    _penv = ParallelEnv(env_fns)
    _penv.reset()
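
    # Note (not in the original script): ParallelEnv is the course's custom
    # vectorized-environment class from parallel_env; it is assumed to run each
    # env_fn in its own worker and to stack observations, rewards and done
    # flags along a leading batch dimension of size num_envs.
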
    ### Prepare the environment to work with PyTorch
    class PreprocessEnv(ParallelWrapper):

        def __init__(self, penv):
            ParallelWrapper.__init__(self, penv)

        # Wraps penv.reset: return the batched state as a float tensor on `device`.
        def reset(self):
            state = self.venv.reset()
            return torch.from_numpy(state).float().to(device)

        # Wraps penv.step_async: convert the action tensor back to a NumPy array.
        def step_async(self, actions):
            actions = actions.squeeze().cpu().numpy()
            self.venv.step_async(actions)

        # Wraps penv.step_wait: convert the transition back to tensors on `device`.
        def step_wait(self):
            next_state, reward, done, info = self.venv.step_wait()
            next_state = torch.from_numpy(next_state).float().to(device)
            reward = torch.Tensor([reward]).view(-1, 1).float().to(device)
            done = torch.Tensor([done]).view(-1, 1).float().to(device)
            return next_state, reward, done, info

    penv = PreprocessEnv(_penv)

    state = penv.reset()
    _, reward, done, _ = penv.step(torch.zeros(num_envs, 1, dtype=torch.int32))
    print(f"State: {state}, Reward: {reward}, Done: {done}")
    ### Create the policy $\pi(s)$
    policy = nn.Sequential(
        nn.Linear(dims, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, actions),
        nn.Softmax(dim=-1)
    )
    policy = policy.to(device)
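
    # Quick sanity check (a minimal sketch, not part of the original script):
    # the softmax head should return one probability per action for every
    # parallel environment, with each row summing to 1.
    with torch.no_grad():
        probs = policy(state)  # expected shape: (num_envs, actions)
        print(f"Policy output for env 0: {probs[0]}, row sums: {probs.sum(dim=-1)}")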
    ## Implement the algorithm
    def reinforce(policy, episodes, alpha=1e-4, gamma=0.99):
        optim = AdamW(policy.parameters(), lr=alpha)
        stats = {'Loss': [], 'Returns': []}

        for episode in tqdm(range(1, episodes + 1)):
            state = penv.reset()
            done_b = torch.zeros((num_envs, 1)).to(device)
            transitions = []
            ep_return = torch.zeros((num_envs, 1)).to(device)

            # Roll out one batched episode: keep stepping until every parallel
            # environment has terminated at least once.
            while not (done_b == 1).all():
                action = policy(state).multinomial(1).detach()
                next_state, reward, done, _ = penv.step(action)
                transitions.append([state, action, (1 - done_b) * reward])
                ep_return += reward
                done_b = torch.maximum(done_b, done)
                state = next_state

            # Walk the episode backwards, accumulating the discounted return G
            # and applying one policy-gradient update per time step.
            G = torch.zeros((num_envs, 1)).to(device)
            for t, (state_t, action_t, reward_t) in reversed(list(enumerate(transitions))):
                G = reward_t + gamma * G
                probs_t = policy(state_t)
                log_probs_t = torch.log(probs_t + 1e-6)
                action_log_prob_t = log_probs_t.gather(1, action_t)
                entropy_t = -torch.sum(probs_t * log_probs_t, dim=-1, keepdim=True)
                gamma_t = gamma ** t
                pg_loss_t = -gamma_t * action_log_prob_t * G  # negative because we want gradient ascent
                total_loss_t = (pg_loss_t - 0.01 * entropy_t).mean()

                policy.zero_grad()
                total_loss_t.backward()
                optim.step()

            stats['Loss'].append(total_loss_t.item())
            stats['Returns'].append(ep_return.mean().item())

        return stats
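
    # The update implemented above is the REINFORCE estimator with an entropy
    # bonus: for each time step t the per-environment loss is
    #     L_t = - gamma^t * log pi(a_t | s_t) * G_t - 0.01 * H(pi(. | s_t)),
    # where G_t = r_t + gamma * G_{t+1} is the discounted return from step t
    # onwards and H is the entropy of the action distribution, which
    # discourages the policy from collapsing to a deterministic choice too early.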
    penv.reset()
    stats = reinforce(policy, 200)
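
    # Post-training diagnostics (a sketch): plot_stats and test_policy_network
    # are imported from the course's utils module, but their exact signatures
    # are assumptions here, so the calls are left commented out.
    # plot_stats(stats)
    # test_policy_network(_env, policy, episodes=2)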