simplify ScaledAdam interface #782

Open
wants to merge 2 commits into base: master
Changes from 1 commit
change current abbreviation
glynpu committed Dec 21, 2022
commit b011448a287784fa14d44d9c6a5aeb8dc9705c27
34 changes: 17 additions & 17 deletions egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -132,7 +132,7 @@ class ScaledAdam(BatchedOptimizer):

Args:
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
- Unlike common optimizers, which accepts model.parameters() or groups of parameters(),
+ Unlike common optimizers, which accept model.parameters() or groups of parameters(),
this optimizer could accept model.named_parameters() or groups of named_parameters().
See comments of function _get_names_of_parameters for its 4 possible cases.
lr: The learning rate. We will typically use a learning rate schedule that starts
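
For readers skimming the diff, the four cases mentioned in the docstring map onto the following constructor calls. This is only a usage sketch inferred from the case comments further down; the learning-rate values and the model.encoder / model.decoder attribute names are illustrative, not part of the PR.

# Assume `model` is a torch.nn.Module; `encoder` and `decoder` are hypothetical submodules.

# Case 1: an iterable of parameters (plain torch.Optimizer style).
opt = ScaledAdam(model.parameters(), lr=0.045)

# Case 2: groups of parameters, each group a dict with a "params" key.
opt = ScaledAdam(
    [
        {"params": model.encoder.parameters(), "lr": 0.045},
        {"params": model.decoder.parameters(), "lr": 0.02},
    ]
)

# Case 3: an iterable of (name, parameter) tuples; the names make the
# show_dominant_parameters diagnostics meaningful.
opt = ScaledAdam(model.named_parameters(), lr=0.045)

# Case 4: groups whose "params" entries are (name, parameter) tuples.
opt = ScaledAdam(
    [
        {"params": model.encoder.named_parameters(), "lr": 0.045},
        {"params": model.decoder.named_parameters(), "lr": 0.02},
    ]
)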
@@ -259,7 +259,7 @@ def _get_names_of_parameters(
# p is short for param.
# np is short for named_param.
# p_or_np is short for param_or_named_param.
- # curt is short for current.
+ # cur is short for current.
# group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
# groups is a List[group]

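As a concrete picture of the group structure described in the comment above (the model.encoder attribute name is hypothetical, and any extra fields are simply whatever options the caller wants to override for that group):

# A single group: a dict whose "params" field is an iterable of parameters
# or of (name, parameter) tuples, plus optional per-group optimizer options.
group = {"params": model.encoder.named_parameters(), "lr": 0.05}
# `groups` is just a list of such dicts.
groups = [group]
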
@@ -277,8 +277,8 @@ def _get_names_of_parameters(
if not isinstance(iterable_or_groups[0], dict):
# case 1 or case 3,
# the input is an iterable of parameter or named parameter.
- param_iterable_curt_group = []
- param_names_curt_group = []
+ param_iterable_cur_group = []
+ param_names_cur_group = []
for p_or_np in iterable_or_groups:
if isinstance(p_or_np, tuple):
# case 3
@@ -290,17 +290,17 @@ def _get_names_of_parameters(
# Assign a dummy name as a placeholder
name = "foo"
self.show_dominant_parameters = False
- param_iterable_curt_group.append(param)
- param_names_curt_group.append(name)
- param_groups.append({"params": param_iterable_curt_group})
- param_groups_names.append(param_names_curt_group)
+ param_iterable_cur_group.append(param)
+ param_names_cur_group.append(name)
+ param_groups.append({"params": param_iterable_cur_group})
+ param_groups_names.append(param_names_cur_group)
else:
# case 2 or case 4
# the input is groups of parameter or named parameter.
- for p_or_np_curt_group in iterable_or_groups:
- param_iterable_curt_group = []
- param_names_curt_group = []
- p_or_np_iterable = p_or_np_curt_group["params"]
+ for p_or_np_cur_group in iterable_or_groups:
+ param_iterable_cur_group = []
+ param_names_cur_group = []
+ p_or_np_iterable = p_or_np_cur_group["params"]
for p_or_np in p_or_np_iterable:
if isinstance(p_or_np, tuple):
# case 4
@@ -312,18 +312,18 @@ def _get_names_of_parameters(
# Assign a dummy name as a placeholder
name = "foo"
self.show_dominant_parameters = False
- param_iterable_curt_group.append(param)
- param_names_curt_group.append(name)
+ param_iterable_cur_group.append(param)
+ param_names_cur_group.append(name)

# The original `params` field contains named_parameters.
# After the following assignment,
# it will be changed to an iterable of parameters,
# and other fields, if any, keep their original values.
# So param_groups can be used to initialize
# an underlying torch.Optimizer later.
p_or_np_curt_group["params"] = param_iterable_curt_group
param_groups.append(p_or_np_curt_group)
param_groups_names.append(param_names_curt_group)
p_or_np_cur_group["params"] = param_iterable_cur_group
param_groups.append(p_or_np_cur_group)
param_groups_names.append(param_names_cur_group)
return param_groups, param_groups_names

def __setstate__(self, state):
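
To make the output of _get_names_of_parameters concrete, here is a hedged sketch for a case-3 input (an iterable of (name, parameter) tuples). The tensors w and b, their names, and the lr value are invented for illustration; the import path just points at the optim.py file touched by this diff.

import torch
from optim import ScaledAdam  # egs/librispeech/ASR/pruned_transducer_stateless7/optim.py

# Hypothetical tensors standing in for real model parameters.
w = torch.nn.Parameter(torch.zeros(2, 2))
b = torch.nn.Parameter(torch.zeros(2))

# Case 3 input: an iterable of (name, parameter) tuples.
optimizer = ScaledAdam([("encoder.weight", w), ("encoder.bias", b)], lr=0.045)

# Internally, _get_names_of_parameters turns that input into two parallel lists:
#   param_groups       -> [{"params": [w, b]}]                   # used to build the underlying torch.Optimizer
#   param_groups_names -> [["encoder.weight", "encoder.bias"]]   # used when reporting dominant parameters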