diff --git a/examples/distributed_training_with_flexible_runner.py b/examples/distributed_training_with_flexible_runner.py
index 20e7286695..99d2cf257d 100644
--- a/examples/distributed_training_with_flexible_runner.py
+++ b/examples/distributed_training_with_flexible_runner.py
@@ -94,6 +94,9 @@ def main():
                 initial_scale_power=15,
             ),
             inputs_to_half=[0],
+            # bf16=dict(
+            #     enabled=True,
+            # ),
             zero_optimization=dict(
                 stage=3,
                 allgather_partitions=True,
diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py
index 2565d18c7d..a439a4a952 100644
--- a/mmengine/_strategy/deepspeed.py
+++ b/mmengine/_strategy/deepspeed.py
@@ -18,7 +18,7 @@
 from mmengine.optim import BaseOptimWrapper, _ParamScheduler
 from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
                                STRATEGIES)
-from mmengine.utils import digit_version, get_git_hash
+from mmengine.utils import apply_to, digit_version, get_git_hash
 
 from .base import BaseStrategy
 
@@ -188,18 +188,22 @@ def _cast_inputs_half(self, inputs: Union[list, tuple, dict, None]):
         if self._inputs_to_half is None:
             return inputs
 
+        dtype = next(self.model.parameters()).dtype
         if isinstance(inputs, (list, tuple)):
             new_inputs = []
             for i, v in enumerate(inputs):
                 if i in self._inputs_to_half:
-                    new_inputs.append(v.half())
+                    new_inputs.append(
+                        apply_to(v, lambda x: hasattr(x, 'to'),
+                                 lambda x: x.to(dtype)))
                 else:
                     new_inputs.append(v)
             return inputs.__class__(new_inputs)
         elif isinstance(inputs, dict):
             for k, v in inputs.items():
                 if k in self._inputs_to_half:
-                    inputs[k] = v.half()
+                    inputs[k] = apply_to(v, lambda x: hasattr(x, 'to'),
+                                         lambda x: x.to(dtype))
             return inputs
         else:
             raise TypeError('inputs should be list, tuple or dict, '
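
For illustration only (not part of the patch): a minimal sketch of what the new casting path does. It assumes that apply_to from mmengine.utils walks nested lists/tuples/dicts and applies the cast only to objects exposing a .to method, so non-tensor entries pass through unchanged; torch.bfloat16 here merely stands in for the dtype read from the model's parameters, and the batch contents are hypothetical.

import torch
from mmengine.utils import apply_to

# In the strategy the dtype comes from ``next(self.model.parameters()).dtype``;
# bfloat16 is used here only as an example value.
dtype = torch.bfloat16

# Hypothetical nested batch: the tensors are cast, the string is left untouched
# because it has no ``.to`` attribute.
batch = {'inputs': [torch.randn(2, 3), torch.randn(2, 3)], 'img_path': 'demo.png'}
casted = apply_to(batch, lambda x: hasattr(x, 'to'), lambda x: x.to(dtype))

assert casted['inputs'][0].dtype == torch.bfloat16
assert casted['img_path'] == 'demo.png'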