This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Support LARC for SGD optimizer only in classy vision #408

Closed · wants to merge 1 commit
5 changes: 4 additions & 1 deletion classy_vision/optim/classy_optimizer.py
@@ -278,7 +278,10 @@ def step(self, closure: Optional[Callable] = None):
         Args:
             closure: A closure that re-evaluates the model and returns the loss
         """
-        self.optimizer.step(closure)
+        if closure is None:
+            self.optimizer.step()
+        else:
+            self.optimizer.step(closure)
 
     def zero_grad(self):
         """
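Why the branch matters: apex's LARC wrapper (adopted for SGD in sgd.py below) defines step() without a closure parameter, so the previous unconditional self.optimizer.step(closure) would fail once the optimizer is wrapped. A minimal sketch of the failure mode, using a hypothetical stand-in wrapper for illustration:

class ClosurelessWrapper:
    # Stands in for a wrapper like apex.parallel.LARC, whose step()
    # accepts no closure argument (an assumption implied by this PR's fix).
    def __init__(self, optimizer):
        self.optim = optimizer

    def step(self):  # note: no `closure` parameter
        self.optim.step()

# ClosurelessWrapper(base).step(None) raises
# "TypeError: step() takes 1 positional argument but 2 were given",
# hence the explicit `if closure is None` branch above.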
19 changes: 19 additions & 0 deletions classy_vision/optim/sgd.py
@@ -15,17 +15,21 @@
 class SGD(ClassyOptimizer):
     def __init__(
         self,
+        larc_config: Dict[str, Any] = None,
         lr: float = 0.1,
         momentum: float = 0.0,
         weight_decay: float = 0.0,
         nesterov: bool = False,
+        use_larc: bool = False,
     ):
         super().__init__()
 
         self.parameters.lr = lr
         self.parameters.momentum = momentum
         self.parameters.weight_decay = weight_decay
         self.parameters.nesterov = nesterov
+        self.parameters.use_larc = use_larc
+        self.larc_config = larc_config
 
     def init_pytorch_optimizer(self, model, **kwargs):
         super().init_pytorch_optimizer(model, **kwargs)
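For reference, the new constructor arguments can be exercised directly; larc_config is stored but only consulted when use_larc is set (values below are illustrative):

sgd = SGD(
    larc_config={"clip": True, "eps": 1e-08, "trust_coefficient": 0.02},
    lr=0.1,
    momentum=0.9,
    use_larc=True,
)
# The actual LARC wrapping is deferred to init_pytorch_optimizer below,
# once the underlying torch.optim.SGD instance exists.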
@@ -36,6 +40,12 @@ def init_pytorch_optimizer(self, model, **kwargs):
             momentum=self.parameters.momentum,
             weight_decay=self.parameters.weight_decay,
         )
+        if self.parameters.use_larc:
+            try:
+                from apex.parallel.LARC import LARC
+            except ImportError:
+                raise RuntimeError("Apex needed for LARC")
+            self.optimizer = LARC(optimizer=self.optimizer, **self.larc_config)
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "SGD":
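Outside Classy Vision, the same deferred-wrapping pattern looks like this (a sketch assuming NVIDIA apex is installed; the keyword arguments mirror the defaults this PR sets in from_config below):

import torch
from apex.parallel.LARC import LARC  # requires NVIDIA apex

model = torch.nn.Linear(10, 2)  # toy model for illustration
base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# LARC wraps an existing optimizer and adapts each layer's effective
# learning rate using the trust coefficient.
optimizer = LARC(optimizer=base, clip=True, eps=1e-08, trust_coefficient=0.02)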
@@ -53,6 +63,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
         config.setdefault("momentum", 0.0)
         config.setdefault("weight_decay", 0.0)
         config.setdefault("nesterov", False)
+        config.setdefault("use_larc", False)
+        config.setdefault(
+            "larc_config", {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02}
+        )
 
         assert (
             config["momentum"] >= 0.0
@@ -62,10 +76,15 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
         assert isinstance(
             config["nesterov"], bool
         ), "Config must contain a boolean 'nesterov' param for SGD optimizer"
+        assert isinstance(
+            config["use_larc"], bool
+        ), "Config must contain a boolean 'use_larc' param for SGD optimizer"
 
         return cls(
+            larc_config=config["larc_config"],
             lr=config["lr"],
             momentum=config["momentum"],
             weight_decay=config["weight_decay"],
             nesterov=config["nesterov"],
+            use_larc=config["use_larc"],
         )
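Putting it together, a config that enables LARC only needs the new boolean; "larc_config" can be omitted and falls back to the defaults set above (other field values here are example choices):

config = {
    "lr": 0.1,
    "momentum": 0.9,
    "use_larc": True,
    # "larc_config" omitted on purpose: from_config fills in
    # {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02}.
}
sgd = SGD.from_config(config)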