-
Notifications
You must be signed in to change notification settings - Fork 725
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reinitialize GAIL dataloader after saving #847
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
An alternative approach would be to implement a custom pickler in ExpertDataset
and drop the prepare_pickle
method entirely.
I think something like this should do the trick in ExpertDataset
:
def __getstate__(self):
state = dict(__dict__)
del state['dataloader']
del state['train_loader']
del state['val_loader']
return state
That said, since I believe the intention is to drop TRPO MPI and GAIL in Stable Baselines3, I'm coding this as approve -- up to you if you want to spend the time to refactor this.
very good point and much cleaner, I will do that ;) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the revision, few minor suggested changes.
Excludes processes that are not pickleable | ||
""" | ||
# Remove processes in order to pickle the dataset. | ||
excluded = {'dataloader', 'train_loader', 'val_loader'} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider making this a module or class constant since it's used in __getstate__
and __setstate__
?
""" | ||
# Remove processes in order to pickle the dataset. | ||
excluded = {'dataloader', 'train_loader', 'val_loader'} | ||
state = {key:val for key, val in self.__dict__.items() if key not in excluded} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think our codebase normally has spaces after :
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
true, but I'm not sure for list comprehension... I will check
tests/test_gail.py
Outdated
dataset = ExpertDataset(expert_path=EXPERT_PATH_PENDULUM, traj_limitation=10, | ||
sequential_preprocessing=True, verbose=0) | ||
model = GAIL("MlpPolicy", "Pendulum-v0", dataset) | ||
checkpoint_callback = CheckpointCallback(save_freq=500, save_path='./logs/gail/', name_prefix='gail') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: maybe use a TemporaryDirectory instance rather than a hardcoded path? This would remove the need to rmtree
afterwards as the contextmanager automatically cleans it up.
We use hardcoded paths all over the tests, but one day I dream of changing this -- it's caused me headaches in the past from race conditions between tests when trying to parallelize.
@AdamGleave I took advantage of your comments to reformat some part of the code (using flake8 to spot the issues ;)) I may do a PR soon to sync what I'm doing with SB3. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes LGTM apart from the test failure
Tests fail because of python 3.5 :/ (it runs locally with python 3.6). Will fix that... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Description
Motivation and Context
closes #830
Types of changes
Checklist:
pytest
andpytype
both pass (by runningmake pytest
andmake type
).