Commit

update baseline
WentseChen committed Nov 22, 2023
1 parent a12e89a commit 1de8bcb
Showing 2 changed files with 76 additions and 87 deletions.
70 changes: 29 additions & 41 deletions openrl/algorithms/ppo.py
@@ -275,48 +275,36 @@ def prepare_loss(
# action_masks_batch,
# active_masks_batch,
# )

if self.last_logp is None:
self.last_logp = old_action_log_probs_batch

# if self.use_joint_action_loss:
# action_log_probs_copy = (
# action_log_probs.reshape(-1, self.agent_num, action_log_probs.shape[-1])
# .sum(dim=(1, -1), keepdim=True)
# .reshape(-1, 1)
# )
# old_action_log_probs_batch_copy = (
# old_action_log_probs_batch.reshape(
# -1, self.agent_num, old_action_log_probs_batch.shape[-1]
# )
# .sum(dim=(1, -1), keepdim=True)
# .reshape(-1, 1)
# )

# active_masks_batch = active_masks_batch.reshape(-1, self.agent_num, 1)
# active_masks_batch = active_masks_batch[:, 0, :]

# ratio = torch.exp(action_log_probs_copy - old_action_log_probs_batch_copy)
# else:
# ratio = torch.exp(action_log_probs - old_action_log_probs_batch)

# if self.dual_clip_ppo:
# ratio = torch.min(ratio, self.dual_clip_coeff)
if self.use_joint_action_loss:
action_log_probs_copy = (
action_log_probs.reshape(-1, self.agent_num, action_log_probs.shape[-1])
.sum(dim=(1, -1), keepdim=True)
.reshape(-1, 1)
)
old_action_log_probs_batch_copy = (
old_action_log_probs_batch.reshape(
-1, self.agent_num, old_action_log_probs_batch.shape[-1]
)
.sum(dim=(1, -1), keepdim=True)
.reshape(-1, 1)
)

# surr1 = ratio * adv_targ
# surr2 = (
# torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
# )
active_masks_batch = active_masks_batch.reshape(-1, self.agent_num, 1)
active_masks_batch = active_masks_batch[:, 0, :]

# surr_final = torch.min(surr1, surr2)
ratio = torch.exp(action_log_probs_copy - old_action_log_probs_batch_copy)
else:
ratio = torch.exp(action_log_probs - old_action_log_probs_batch)

# clip_param = 100.
ratio_ppo = torch.exp(self.last_logp - old_action_log_probs_batch)
ratio_aug = torch.exp(action_log_probs - self.last_logp)
ratio = ratio_aug * ratio_ppo
if self.dual_clip_ppo:
ratio = torch.min(ratio, self.dual_clip_coeff)

surr1 = ratio * adv_targ
surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
surr2 = (
torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
)

surr_final = torch.min(surr1, surr2)

if self._use_policy_active_masks:
@@ -348,9 +336,9 @@ def prepare_loss(
else:
policy_loss = policy_action_loss

ratio_aux = torch.exp(self.last_logp - old_action_log_probs_batch)
loss_aux = ratio_aux * (-action_log_probs)
policy_loss = policy_loss + loss_aux.mean() * 0.1
# ratio_aux = torch.exp(logp - old_action_log_probs_batch)
# loss_aux = ratio_aux * (-action_log_probs)
# policy_loss = policy_loss + loss_aux.mean() * 0.1

# critic update
if self._use_share_model:
@@ -371,7 +359,8 @@ def prepare_loss(
policy_loss, dist_entropy, value_loss, turn_on
)

self.last_logp = action_log_probs.detach()
ratio_aug = ratio
ratio_ppo = ratio

return loss_list, value_loss, policy_loss, dist_entropy, ratio, ratio_aug, ratio_ppo

@@ -436,7 +425,6 @@ def train_ppo(self, buffer, turn_on):
train_info["reduced_value_loss"] = 0
train_info["reduced_policy_loss"] = 0

self.last_logp = None
for _ in range(self.ppo_epoch):
data_generator = self.get_data_generator(buffer, advantages)

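Note: the ppo.py change drops the previous baseline's `last_logp`-based decomposition (`ratio = ratio_aug * ratio_ppo`) and restores the standard clipped PPO surrogate, including the joint-action variant that sums per-agent log-probabilities before forming the importance ratio. Below is a minimal, self-contained sketch of that surrogate under stated assumptions; the function name `clipped_surrogate` and its keyword arguments are illustrative stand-ins for the trainer attributes (`self.clip_param`, `self.agent_num`, `self.dual_clip_coeff`), not the repository's API, and `adv_targ` is assumed to already match the ratio's shape.

    import torch


    def clipped_surrogate(action_log_probs, old_action_log_probs, adv_targ,
                          clip_param=0.2, agent_num=None, dual_clip_coeff=None):
        if agent_num is not None:
            # Joint-action variant: sum per-agent log-probabilities before
            # forming the importance ratio, so all agents share one ratio.
            action_log_probs = (
                action_log_probs.reshape(-1, agent_num, action_log_probs.shape[-1])
                .sum(dim=(1, -1), keepdim=True)
                .reshape(-1, 1)
            )
            old_action_log_probs = (
                old_action_log_probs.reshape(-1, agent_num, old_action_log_probs.shape[-1])
                .sum(dim=(1, -1), keepdim=True)
                .reshape(-1, 1)
            )

        ratio = torch.exp(action_log_probs - old_action_log_probs)
        if dual_clip_coeff is not None:
            # Dual-clip PPO additionally caps the ratio from above.
            ratio = torch.min(ratio, torch.as_tensor(dual_clip_coeff, dtype=ratio.dtype))

        surr1 = ratio * adv_targ
        surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv_targ
        return torch.min(surr1, surr2)

In the commit itself this surrogate is then masked by `active_masks_batch` (reduced to agent 0's mask in the joint case) before being averaged into the policy loss, the auxiliary `loss_aux` term is commented out, and `ratio_aug` / `ratio_ppo` are simply aliased to `ratio` so the return signature stays unchanged.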
93 changes: 47 additions & 46 deletions openrl/buffers/replay_data.py
@@ -1071,52 +1071,53 @@ def recurrent_generator(self, advantages, num_mini_batch, data_chunk_length):

epi_len, env_num, agent_num, _ = policy_obs.shape

# for smac
order = np.arange(agent_num)
np.random.shuffle(order)
policy_obs = policy_obs.reshape([epi_len*env_num*agent_num, -1])
enemy = policy_obs[:,4:49].copy()
enemy = enemy.reshape([epi_len*env_num*agent_num, self.num_agents, 9])
enemy = np.transpose(enemy, [1,0,2])
enemy = enemy[order]
enemy = np.transpose(enemy, [1,0,2])
enemy = enemy.reshape([epi_len*env_num*agent_num, 45])
policy_obs[:,4:49] = enemy
policy_obs = policy_obs.reshape([epi_len*env_num*agent_num, -1])
allies = policy_obs[:,49:85].copy()
allies = allies.reshape([epi_len*env_num, agent_num, self.num_agents-1, 9])
for agent_idx in range(self.num_agents):
agent_order = order.copy()
agent_order = np.delete(agent_order, np.argwhere(agent_order==agent_idx))
agent_order = np.where(agent_order>agent_idx,agent_order-1,agent_order)
ally = allies[:,agent_idx]
ally = np.transpose(ally, [1,0,2])
ally = ally[agent_order]
ally = np.transpose(ally, [1,0,2])
ally = ally.reshape([epi_len*env_num, self.num_agents-1, 9])
allies[:,agent_idx] = ally
allies = allies.reshape([epi_len*env_num*agent_num, -1])
policy_obs[:,49:85] = allies
policy_obs = policy_obs.reshape([epi_len, env_num, agent_num, -1])

order = np.arange(agent_num)
np.random.shuffle(order)
critic_obs = critic_obs.reshape([epi_len*env_num*agent_num, -1])
enemy = critic_obs[:,40:75].copy()
enemy = enemy.reshape([epi_len*env_num*agent_num, self.num_agents, 7])
enemy = np.transpose(enemy, [1,0,2])
enemy = enemy[order]
enemy = np.transpose(enemy, [1,0,2])
enemy = enemy.reshape([epi_len*env_num*agent_num, 35])
critic_obs[:,40:75] = enemy
allies = critic_obs[:,0:40].copy()
allies = allies.reshape([epi_len*env_num*agent_num, self.num_agents, 8])
allies = np.transpose(allies, [1,0,2])
allies = allies[order]
allies = np.transpose(allies, [1,0,2])
allies = allies.reshape([epi_len*env_num*agent_num, 40])
critic_obs[:,0:40] = allies
critic_obs = critic_obs.reshape([epi_len, env_num, agent_num, -1])
# # for smac
# order = np.arange(agent_num)
# np.random.shuffle(order)
# policy_obs = policy_obs.reshape([epi_len*env_num*agent_num, -1])
# enemy = policy_obs[:,4:49].copy()
# enemy = enemy.reshape([epi_len*env_num*agent_num, self.num_agents, 9])
# enemy = np.transpose(enemy, [1,0,2])
# enemy = enemy[order]
# enemy = np.transpose(enemy, [1,0,2])
# enemy = enemy.reshape([epi_len*env_num*agent_num, 45])
# policy_obs[:,4:49] = enemy
# policy_obs = policy_obs.reshape([epi_len*env_num*agent_num, -1])
# allies = policy_obs[:,49:85].copy()
# allies = allies.reshape([epi_len*env_num, agent_num, self.num_agents-1, 9])
# for agent_idx in range(self.num_agents):
# agent_order = order.copy()
# agent_order = np.delete(agent_order, np.argwhere(agent_order==agent_idx))
# agent_order = np.where(agent_order>agent_idx,agent_order-1,agent_order)
# ally = allies[:,agent_idx]
# ally = np.transpose(ally, [1,0,2])
# ally = ally[agent_order]
# ally = np.transpose(ally, [1,0,2])
# ally = ally.reshape([epi_len*env_num, self.num_agents-1, 9])
# allies[:,agent_idx] = ally
# allies = allies.reshape([epi_len*env_num*agent_num, -1])
# policy_obs[:,49:85] = allies
# policy_obs = policy_obs.reshape([epi_len, env_num, agent_num, -1])

# for agent_id in range(self.num_agents):
# order = np.arange(agent_num)
# np.random.shuffle(order)
# critic_o = critic_obs[:,:,agent_id].copy()
# critic_o = critic_o.reshape([epi_len*env_num, -1])
# enemy = critic_o[:,40:75].copy()
# enemy = enemy.reshape([epi_len*env_num, self.num_agents, 7])
# enemy = np.transpose(enemy, [1,0,2])
# enemy = enemy[order]
# enemy = np.transpose(enemy, [1,0,2])
# enemy = enemy.reshape([epi_len, env_num, 35])
# critic_obs[:,:,agent_id,40:75] = enemy
# allies = critic_o[:,0:40].copy()
# allies = allies.reshape([epi_len*env_num, self.num_agents, 8])
# allies = np.transpose(allies, [1,0,2])
# allies = allies[order]
# allies = np.transpose(allies, [1,0,2])
# allies = allies.reshape([epi_len, env_num, 40])
# critic_obs[:,:,agent_id,0:40] = allies

# # for mpe
# order = np.arange(agent_num)
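Note: the replay_data.py change disables the SMAC-specific data augmentation: the block that reorders the per-enemy and per-ally feature slices of the observations with a random agent permutation is now fully commented out, so the recurrent generator yields observations unshuffled ("baseline" behaviour). As a reference for what the disabled block did, here is a minimal sketch of the enemy-slice permutation for the policy observation; the helper name and signature are illustrative, and the slice `[4:49]` with per-enemy width 9 is taken from the diff (so 45 = num_agents * 9, i.e. 5 enemies here).

    import numpy as np


    def shuffle_enemy_block(policy_obs, num_agents, start=4, end=49, width=9):
        """policy_obs: (batch, obs_dim) array; returns a copy with the per-enemy
        sub-vectors reordered by one shared random permutation of the agents."""
        assert end - start == num_agents * width
        order = np.arange(num_agents)
        np.random.shuffle(order)

        obs = policy_obs.copy()
        enemy = obs[:, start:end].reshape(-1, num_agents, width)
        # Equivalent to the transpose / index / transpose-back sequence in the diff.
        enemy = enemy[:, order, :]
        obs[:, start:end] = enemy.reshape(-1, end - start)
        return obs

The original block applies the same idea to the ally slice `[49:85]` with a per-agent adjusted ordering (each agent's own index removed from the permutation), and the removed version also shuffled the critic observation's ally and enemy slices with a second permutation.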
