Skip to content

Commit

Permalink
[Feature] Train group map
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jan 11, 2024
1 parent 6419c79 commit bffd97e
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import copy
import importlib

import os
Expand Down Expand Up @@ -391,6 +392,7 @@ def _setup_task(self):
self.action_mask_spec = self.task.action_mask_spec(test_env)
self.action_spec = self.task.action_spec(test_env)
self.group_map = self.task.group_map(test_env)
self.train_group_map = copy.deepcopy(self.group_map)
self.max_steps = self.task.max_steps(test_env)

transforms = [self.task.get_reward_sum_transform(test_env)]
Expand Down Expand Up @@ -548,7 +550,7 @@ def _collection_loop(self):

# Loop over groups
training_start = time.time()
for group in self.group_map.keys():
for group in self.train_group_map.keys():
group_batch = batch.exclude(*self._get_excluded_keys(group))
group_batch = self.algorithm.process_batch(group, group_batch)
group_batch = group_batch.reshape(-1)
Expand Down

0 comments on commit bffd97e

Please sign in to comment.