diff --git a/classy_vision/optim/classy_optimizer.py b/classy_vision/optim/classy_optimizer.py
index 1fa80925d2..0410a1dc30 100644
--- a/classy_vision/optim/classy_optimizer.py
+++ b/classy_vision/optim/classy_optimizer.py
@@ -278,7 +278,10 @@ def step(self, closure: Optional[Callable] = None):
         Args:
             closure: A closure that re-evaluates the model and returns the loss
         """
-        self.optimizer.step(closure)
+        if closure is None:
+            self.optimizer.step()
+        else:
+            self.optimizer.step(closure)
 
     def zero_grad(self):
         """
diff --git a/classy_vision/optim/sgd.py b/classy_vision/optim/sgd.py
index 5f9ddd96b9..c78e8eb90c 100644
--- a/classy_vision/optim/sgd.py
+++ b/classy_vision/optim/sgd.py
@@ -15,10 +15,12 @@ class SGD(ClassyOptimizer):
     def __init__(
         self,
+        larc_config: Dict[str, Any] = None,
         lr: float = 0.1,
         momentum: float = 0.0,
         weight_decay: float = 0.0,
         nesterov: bool = False,
+        use_larc: bool = False,
     ):
         super().__init__()
 
         self.parameters.lr = lr
@@ -26,6 +28,8 @@ def __init__(
         self.parameters.momentum = momentum
         self.parameters.weight_decay = weight_decay
         self.parameters.nesterov = nesterov
+        self.parameters.use_larc = use_larc
+        self.larc_config = larc_config
 
     def init_pytorch_optimizer(self, model, **kwargs):
         super().init_pytorch_optimizer(model, **kwargs)
@@ -36,6 +40,12 @@ def init_pytorch_optimizer(self, model, **kwargs):
             momentum=self.parameters.momentum,
             weight_decay=self.parameters.weight_decay,
         )
+        if self.parameters.use_larc:
+            try:
+                from apex.parallel.LARC import LARC
+            except ImportError:
+                raise RuntimeError("Apex needed for LARC")
+            self.optimizer = LARC(optimizer=self.optimizer, **self.larc_config)
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "SGD":
@@ -53,6 +63,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
         config.setdefault("momentum", 0.0)
         config.setdefault("weight_decay", 0.0)
         config.setdefault("nesterov", False)
+        config.setdefault("use_larc", False)
+        config.setdefault(
+            "larc_config", {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02}
+        )
         assert (
             config["momentum"] >= 0.0
             and config["momentum"] < 1.0
@@ -62,10 +76,15 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
         assert isinstance(
             config["nesterov"], bool
         ), "Config must contain a boolean 'nesterov' param for SGD optimizer"
+        assert isinstance(
+            config["use_larc"], bool
+        ), "Config must contain a boolean 'use_larc' param for SGD optimizer"
 
         return cls(
+            larc_config=config["larc_config"],
             lr=config["lr"],
             momentum=config["momentum"],
             weight_decay=config["weight_decay"],
             nesterov=config["nesterov"],
+            use_larc=config["use_larc"],
         )