Skip to content

Commit

Permalink
Update to PyTorch v1.2.0 (#580)
Browse files Browse the repository at this point in the history
* Update .travis.yml

* Update .travis.yml

* Fixed tests and improved travis
  • Loading branch information
vfdev-5 authored Aug 15, 2019
1 parent f2c2441 commit f2ab1b5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
8 changes: 4 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ python:
- "3.6"

env:
- PYTORCH_PACKAGE=pytorch-cpu
- PYTORCH_PACKAGE=pytorch-nightly-cpu
- PYTORCH_CHANNEL=pytorch
- PYTORCH_CHANNEL=pytorch-nightly

stages:
- Lint check
Expand All @@ -25,7 +25,7 @@ before_install: &before_install
- conda update -q conda
# Useful for debugging any issues with conda
- conda info -a
- conda create -q -n test-environment -c pytorch python=$TRAVIS_PYTHON_VERSION $PYTORCH_PACKAGE
- conda create -q -n test-environment pytorch cpuonly python=$TRAVIS_PYTHON_VERSION -c $PYTORCH_CHANNEL
- source activate test-environment
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install enum34; fi
# Test contrib dependencies
Expand All @@ -39,7 +39,7 @@ install:
- pip install numpy mock pytest codecov pytest-cov
# Examples dependencies
- pip install matplotlib pandas
- conda install torchvision-cpu -c pytorch
- conda install torchvision -c $PYTORCH_CHANNEL
- pip install gym==0.10.11

script:
Expand Down
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ def _replicate_lr_scheduler(lr_scheduler, new_optimizer_param_groups=None):
if new_optimizer_param_groups is not None:
dummy_optimizer.param_groups = new_optimizer_param_groups
kwargs = lr_scheduler.state_dict()
del kwargs['base_lrs']
for k in ['base_lrs', '_step_count']:
del kwargs[k]
copy_lr_scheduler = lr_scheduler_cls(optimizer=dummy_optimizer, **kwargs)
copy_lr_scheduler.load_state_dict(lr_scheduler.state_dict())
return copy_lr_scheduler
Expand Down
9 changes: 7 additions & 2 deletions tests/ignite/contrib/handlers/test_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,11 @@ def _test(torch_lr_scheduler_cls, **kwargs):
lrs = []
lrs_true = []

trainer = Engine(lambda engine, batch: None)
def dummy_update(engine, batch):
optimizer1.step()
optimizer2.step()

trainer = Engine(dummy_update)

@trainer.on(Events.ITERATION_COMPLETED)
def torch_lr_scheduler_step(engine):
Expand Down Expand Up @@ -396,6 +400,7 @@ def save_true_lr(engine):
init_lr_scheduler_state = dict(lr_scheduler.state_dict())
copy_lr_scheduler = LRScheduler._replicate_lr_scheduler(lr_scheduler)
for _ in range(10):
optimizer.step()
lr_scheduler.step()

assert copy_lr_scheduler.state_dict() == init_lr_scheduler_state
Expand Down Expand Up @@ -444,7 +449,7 @@ def save_lr(engine):
lrs.append(optimizer.param_groups[0]['lr'])

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
trainer.run([0] * 25, max_epochs=2)

Expand Down

0 comments on commit f2ab1b5

Please sign in to comment.