Fix Trainer with a parallel model #9578
@@ -426,7 +426,6 @@ def __post_init__(self):

         if is_torch_available() and self.device.type != "cuda" and self.fp16:
             raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
-        self._n_gpu = torch.cuda.device_count()
Removing this from here; it is going to be set up entirely in `_setup_devices`.
@@ -381,9 +381,11 @@ def test_data_is_not_parallelized_when_model_is_parallel(self):
         # Make the Trainer believe it's a parallelized model
         model.is_parallelizable = True
         model.model_parallel = True
-        trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
+        args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
Make sure the test uses batch sizes of 16.
         # Check the Trainer was fooled
         self.assertTrue(trainer.is_model_parallel)
+        self.assertEqual(trainer.args.n_gpu, 1)
This was still set to 2 before, so this checks it is indeed 1.
LGTM, thanks @sgugger
* Fix Trainer with a parallel model
* More clean up
What does this PR do?
The test introduced in #9566 wasn't actually working, as the default batch size is 8, not 16...

So the problem was still there. The reason is that `_setup_devices` in `TrainingArguments` is a `cached_property`, so its result is computed once and for all at init. Had to change the behavior slightly, but it should be okay since it's a private method.

Fixes #9577 (the model is getting wrapped into `DataParallel` because the value of `self.args.n_gpu` is not updated).
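
To illustrate the caching behavior described above, here is a minimal sketch using the standard library's `functools.cached_property` (Python 3.8+). The `ToyArguments` class and its `visible_gpus` attribute are hypothetical stand-ins, not the real `TrainingArguments`, and this is not the change made in this PR; it only shows why a value computed once at first access stays stale when the state it depends on changes later.

```python
from functools import cached_property


class ToyArguments:
    """Hypothetical stand-in for TrainingArguments; not the real class."""

    def __init__(self, visible_gpus: int):
        self.visible_gpus = visible_gpus  # assumed GPU count, for illustration only
        self.model_parallel = False

    @cached_property
    def n_gpu(self) -> int:
        # Evaluated on first access only; the result is then stored on the instance.
        return 1 if self.model_parallel else self.visible_gpus


args = ToyArguments(visible_gpus=2)
first_read = args.n_gpu      # computed now, while model_parallel is still False
args.model_parallel = True   # a later change, e.g. made once the model is known
stale_read = args.n_gpu      # served from the cache, not recomputed

# Deleting the cached attribute forces a fresh computation on the next access.
del args.n_gpu
fresh_read = args.n_gpu

print(first_read, stale_read, fresh_read)
```

In this toy version, the second read still reports the value from before `model_parallel` was set, which mirrors the stale `self.args.n_gpu` described in #9577.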