smac&mpe data aug done
WentseChen committed Nov 20, 2023
1 parent 8f30dd3 commit a12e89a
Showing 6 changed files with 480 additions and 235 deletions.
examples/mpe/mpe_ppo.yaml (3 additions, 1 deletion)
@@ -5,8 +5,10 @@ episode_length: 25
run_dir: ./run_results/
experiment_name: train_mpe
log_interval: 10
+data_chunk_length: 25
+ppo_epoch: 10
use_recurrent_policy: true
use_joint_action_loss: false
use_valuenorm: true
use_adv_normalize: true
-wandb_entity: openrl-lab
+wandb_entity: cwz19
examples/smac/smac_ppo.yaml (3 additions, 3 deletions)
@@ -3,13 +3,13 @@ seed: 0
run_dir: ./run_results/
experiment_name: smac_mappo

-lr: 5e-4
+lr: 1e-3
critic_lr: 1e-3

-data_chunk_length: 8
+data_chunk_length: 400

episode_length: 400
-ppo_epoch: 5
+ppo_epoch: 20
log_interval: 10
actor_train_interval_step: 1
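In both updated configs, data_chunk_length now equals episode_length (25 for MPE, 400 for SMAC), so each recurrent training chunk spans a whole episode under the usual MAPPO-style chunking; the SMAC file also raises the actor lr to match critic_lr (1e-3) and ppo_epoch from 5 to 20. A tiny sketch of the chunking arithmetic (illustrative only, not repository code, assuming a rollout of episode_length steps is split into BPTT chunks of data_chunk_length steps):

# Illustrative only, not repository code.
episode_length = 400       # smac_ppo.yaml
data_chunk_length = 400    # smac_ppo.yaml, previously 8
chunks_per_rollout = episode_length // data_chunk_length
print(chunks_per_rollout)  # 1: each recurrent chunk now covers a full episode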
openrl/algorithms/ppo.py (77 additions, 32 deletions)
@@ -49,6 +49,7 @@ def ppo_update(self, sample, turn_on=True):
(
critic_obs_batch,
obs_batch,
original_obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
@@ -75,9 +76,12 @@
policy_loss,
dist_entropy,
ratio,
ratio_aug,
ratio_ppo,
) = self.prepare_loss(
critic_obs_batch,
obs_batch,
original_obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
@@ -93,9 +97,10 @@
for loss in loss_list:
self.algo_module.scaler.scale(loss).backward()
else:
loss_list, value_loss, policy_loss, dist_entropy, ratio = self.prepare_loss(
loss_list, value_loss, policy_loss, dist_entropy, ratio, ratio_aug, ratio_ppo = self.prepare_loss(
critic_obs_batch,
obs_batch,
original_obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
@@ -154,6 +159,8 @@ def ppo_update(self, sample, turn_on=True):
dist_entropy,
actor_grad_norm,
ratio,
ratio_aug,
ratio_ppo,
)

def cal_value_loss(
@@ -217,6 +224,7 @@ def prepare_loss(
self,
critic_obs_batch,
obs_batch,
original_obs_batch,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
@@ -257,36 +265,58 @@ def prepare_loss(
active_masks_batch,
critic_masks_batch=critic_masks_batch,
)

# logp = self.algo_module.models["policy"](
# "get_log_prob",
# original_obs_batch,
# rnn_states_batch,
# actions_batch,
# masks_batch,
# action_masks_batch,
# active_masks_batch,
# )

if self.last_logp is None:
self.last_logp = old_action_log_probs_batch

# if self.use_joint_action_loss:
# action_log_probs_copy = (
# action_log_probs.reshape(-1, self.agent_num, action_log_probs.shape[-1])
# .sum(dim=(1, -1), keepdim=True)
# .reshape(-1, 1)
# )
# old_action_log_probs_batch_copy = (
# old_action_log_probs_batch.reshape(
# -1, self.agent_num, old_action_log_probs_batch.shape[-1]
# )
# .sum(dim=(1, -1), keepdim=True)
# .reshape(-1, 1)
# )

# active_masks_batch = active_masks_batch.reshape(-1, self.agent_num, 1)
# active_masks_batch = active_masks_batch[:, 0, :]

# ratio = torch.exp(action_log_probs_copy - old_action_log_probs_batch_copy)
# else:
# ratio = torch.exp(action_log_probs - old_action_log_probs_batch)

if self.use_joint_action_loss:
action_log_probs_copy = (
action_log_probs.reshape(-1, self.agent_num, action_log_probs.shape[-1])
.sum(dim=(1, -1), keepdim=True)
.reshape(-1, 1)
)
old_action_log_probs_batch_copy = (
old_action_log_probs_batch.reshape(
-1, self.agent_num, old_action_log_probs_batch.shape[-1]
)
.sum(dim=(1, -1), keepdim=True)
.reshape(-1, 1)
)
# if self.dual_clip_ppo:
# ratio = torch.min(ratio, self.dual_clip_coeff)

active_masks_batch = active_masks_batch.reshape(-1, self.agent_num, 1)
active_masks_batch = active_masks_batch[:, 0, :]
# surr1 = ratio * adv_targ
# surr2 = (
# torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
# )

ratio = torch.exp(action_log_probs_copy - old_action_log_probs_batch_copy)
else:
ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
# surr_final = torch.min(surr1, surr2)

if self.dual_clip_ppo:
ratio = torch.min(ratio, self.dual_clip_coeff)
# clip_param = 100.
ratio_ppo = torch.exp(self.last_logp - old_action_log_probs_batch)
ratio_aug = torch.exp(action_log_probs - self.last_logp)
ratio = ratio_aug * ratio_ppo

surr1 = ratio * adv_targ
surr2 = (
torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
)

surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
surr_final = torch.min(surr1, surr2)

if self._use_policy_active_masks:
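The hunk above replaces the single importance ratio with a factored one: ratio_ppo compares the previous epoch's policy (self.last_logp) against the rollout-time policy, ratio_aug compares the current policy, presumably evaluated on the augmented observations given the commit message, against the previous epoch's, and their product telescopes back to exp(action_log_probs - old_action_log_probs_batch). A minimal standalone sketch of that identity, with made-up log-probabilities rather than repository tensors:

# Minimal sketch of the ratio factorization; the tensors here are made up.
import torch

old_logp = torch.tensor([-1.20, -0.80, -2.10])   # behaviour policy (rollout time)
last_logp = torch.tensor([-1.10, -0.95, -2.00])  # policy after the previous PPO epoch
new_logp = torch.tensor([-1.00, -0.90, -1.90])   # current policy on (augmented) obs

ratio_ppo = torch.exp(last_logp - old_logp)  # classic PPO part, frozen at the last epoch
ratio_aug = torch.exp(new_logp - last_logp)  # correction from the current update
ratio = ratio_aug * ratio_ppo                # equals exp(new_logp - old_logp)

assert torch.allclose(ratio, torch.exp(new_logp - old_logp))

Because self.last_logp is seeded with old_action_log_probs_batch on the first epoch, ratio_ppo starts at 1 and ratio initially reduces to the standard PPO ratio.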
@@ -318,6 +348,10 @@ def prepare_loss(
else:
policy_loss = policy_action_loss

ratio_aux = torch.exp(self.last_logp - old_action_log_probs_batch)
loss_aux = ratio_aux * (-action_log_probs)
policy_loss = policy_loss + loss_aux.mean() * 0.1

# critic update
if self._use_share_model:
value_normalizer = self.algo_module.models["model"].value_normalizer
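The new lines in this hunk also add an auxiliary term to the policy loss: ratio_aux is built from the previous epoch's detached log-probabilities, so gradients flow only through -action_log_probs, giving an importance-weighted log-likelihood term scaled by the hard-coded 0.1. A rough sketch of that term with placeholder tensors (not repository code):

# Sketch of the auxiliary policy term (coefficient 0.1 taken from the diff);
# all tensors below are placeholders, not OpenRL buffer contents.
import torch

old_logp = torch.tensor([-1.20, -0.80, -2.10])                       # rollout-time log-probs
last_logp = torch.tensor([-1.10, -0.95, -2.00])                      # previous epoch, detached
new_logp = torch.tensor([-1.00, -0.90, -1.90], requires_grad=True)   # current policy

ratio_aux = torch.exp(last_logp - old_logp)   # carries no gradient w.r.t. the current policy
loss_aux = ratio_aux * (-new_logp)            # importance-weighted negative log-likelihood
policy_loss = torch.tensor(0.0)               # stand-in for the clipped surrogate loss
policy_loss = policy_loss + loss_aux.mean() * 0.1
policy_loss.backward()                        # gradient reaches only new_logp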
@@ -336,7 +370,10 @@
loss_list = self.construct_loss_list(
policy_loss, dist_entropy, value_loss, turn_on
)
return loss_list, value_loss, policy_loss, dist_entropy, ratio

self.last_logp = action_log_probs.detach()

return loss_list, value_loss, policy_loss, dist_entropy, ratio, ratio_aug, ratio_ppo

def get_data_generator(self, buffer, advantages):
if self._use_recurrent_policy:
@@ -368,15 +405,16 @@ def train_ppo(self, buffer, turn_on):
].module.value_normalizer
else:
value_normalizer = self.algo_module.get_critic_value_normalizer()
advantages = buffer.returns[:-1] - value_normalizer.denormalize(
buffer.value_preds[:-1]
)

returns = buffer.returns[:-1]
values = value_normalizer.denormalize(buffer.value_preds[:-1])
advantages = returns - values

advantages = advantages.mean(axis=2, keepdims=True)
advantages = np.repeat(advantages, self.agent_num, axis=2)
else:
advantages = buffer.returns[:-1] - buffer.value_preds[:-1]

if self._use_adv_normalize:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

advantages_copy = advantages.copy()
advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan
mean_advantages = np.nanmean(advantages_copy)
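In train_ppo the advantage computation is now split into explicit returns and denormalized value predictions; in the branch shown, the per-agent advantages are averaged over the agent axis and broadcast back so all agents share one estimate, and, when use_adv_normalize is set (the flag visible in mpe_ppo.yaml), advantages are standardized with a small epsilon. A small numpy sketch of that pipeline with made-up shapes, not OpenRL's actual buffer API:

# Sketch of the advantage pipeline above; shapes and data are made up.
import numpy as np

T, n_envs, n_agents = 25, 4, 3                       # episode_length, rollouts, agents
returns = np.random.randn(T, n_envs, n_agents, 1)
values = np.random.randn(T, n_envs, n_agents, 1)     # denormalized value predictions

advantages = returns - values
advantages = advantages.mean(axis=2, keepdims=True)  # one shared advantage per step/env
advantages = np.repeat(advantages, n_agents, axis=2)

# use_adv_normalize branch: standardize with a small epsilon for stability
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)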
@@ -392,10 +430,13 @@ def train_ppo(self, buffer, turn_on):
train_info["actor_grad_norm"] = 0
train_info["critic_grad_norm"] = 0
train_info["ratio"] = 0
train_info["ratio_ppo"] = 0
train_info["ratio_aug"] = 0
if self.world_size > 1:
train_info["reduced_value_loss"] = 0
train_info["reduced_policy_loss"] = 0

self.last_logp = None
for _ in range(self.ppo_epoch):
data_generator = self.get_data_generator(buffer, advantages)

@@ -407,6 +448,8 @@
dist_entropy,
actor_grad_norm,
ratio,
ratio_aug,
ratio_ppo,
) = self.ppo_update(sample, turn_on)

if self.world_size > 1:
@@ -424,6 +467,8 @@
train_info["actor_grad_norm"] += actor_grad_norm
train_info["critic_grad_norm"] += critic_grad_norm
train_info["ratio"] += ratio.mean().item()
train_info["ratio_ppo"] += ratio_ppo.mean().item()
train_info["ratio_aug"] += ratio_aug.mean().item()

num_updates = self.ppo_epoch * self.num_mini_batch

