From 4add1b3298e3fa753f3fd1c9c4e72180570f4cbb Mon Sep 17 00:00:00 2001
From: Priya Goyal
Date: Thu, 27 Feb 2020 08:58:07 -0800
Subject: [PATCH] Support LARC for SGD optimizer only in classy vision (#408)

Summary:
Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/408

Pull Request resolved: https://github.com/fairinternal/ClassyVision/pull/64

In an attempt to implement SimCLR for contrastive losses, I needed LARC to enable large-batch training. mannatsingh had already done work on this during the Classy Vision open source release: https://our.intern.facebook.com/intern/diff/D18542126/

I initially tried using that diff to build a separate, standalone LARC that works with any optimizer, but it turned out to be tricky to set up correctly since we need to wrap a given optimizer in LARC (the `getattr` and `setattr` functions were not working). I talked to vreis about it and we decided that, for now, we can support it for SGD only and file a task to support other optimizers later, after discussions with mannatsingh once he's back.

Differential Revision: D20139718

fbshipit-source-id: c8cf4d545e6ce94cca8e646f68d519197856f675
---
 classy_vision/optim/classy_optimizer.py |  5 ++++-
 classy_vision/optim/sgd.py              | 19 +++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/classy_vision/optim/classy_optimizer.py b/classy_vision/optim/classy_optimizer.py
index 1fa80925d2..0410a1dc30 100644
--- a/classy_vision/optim/classy_optimizer.py
+++ b/classy_vision/optim/classy_optimizer.py
@@ -278,7 +278,10 @@ def step(self, closure: Optional[Callable] = None):
         Args:
             closure: A closure that re-evaluates the model and returns the loss
         """
-        self.optimizer.step(closure)
+        if closure is None:
+            self.optimizer.step()
+        else:
+            self.optimizer.step(closure)

     def zero_grad(self):
         """
diff --git a/classy_vision/optim/sgd.py b/classy_vision/optim/sgd.py
index 5f9ddd96b9..c78e8eb90c 100644
--- a/classy_vision/optim/sgd.py
+++ b/classy_vision/optim/sgd.py
@@ -15,10 +15,12 @@ class SGD(ClassyOptimizer):
     def __init__(
         self,
+        larc_config: Dict[str, Any] = None,
         lr: float = 0.1,
         momentum: float = 0.0,
         weight_decay: float = 0.0,
         nesterov: bool = False,
+        use_larc: bool = False,
     ):
         super().__init__()

@@ -26,6 +28,8 @@ def __init__(
         self.parameters.momentum = momentum
         self.parameters.weight_decay = weight_decay
         self.parameters.nesterov = nesterov
+        self.parameters.use_larc = use_larc
+        self.larc_config = larc_config

     def init_pytorch_optimizer(self, model, **kwargs):
         super().init_pytorch_optimizer(model, **kwargs)
@@ -36,6 +40,12 @@ def init_pytorch_optimizer(self, model, **kwargs):
             momentum=self.parameters.momentum,
             weight_decay=self.parameters.weight_decay,
         )
+        if self.parameters.use_larc:
+            try:
+                from apex.parallel.LARC import LARC
+            except ImportError:
+                raise RuntimeError("Apex needed for LARC")
+            self.optimizer = LARC(optimizer=self.optimizer, **self.larc_config)

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "SGD":
@@ -53,6 +63,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
         config.setdefault("momentum", 0.0)
         config.setdefault("weight_decay", 0.0)
         config.setdefault("nesterov", False)
+        config.setdefault("use_larc", False)
+        config.setdefault(
+            "larc_config", {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02}
+        )

         assert (
             config["momentum"] >= 0.0
@@ -62,10 +76,15 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
         assert isinstance(
             config["nesterov"], bool
         ), "Config must contain a boolean 'nesterov' param for SGD optimizer"
+        assert isinstance(
+            config["use_larc"], bool
+        ), "Config must contain a boolean 'use_larc' param for SGD optimizer"

         return cls(
+            larc_config=config["larc_config"],
             lr=config["lr"],
             momentum=config["momentum"],
             weight_decay=config["weight_decay"],
             nesterov=config["nesterov"],
+            use_larc=config["use_larc"],
         )