Skip to content

Commit

Permalink
Update maximum batch size for UT; Include recompute modes (microsoft#5444)
Browse files Browse the repository at this point in the history

* Update MaxBatchSize and include recompute mode
* Minor fix for frontend test

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
  • Loading branch information
SherlockNoMad and Sherlock Huang authored Oct 12, 2020
1 parent dbc626d commit 60dbd8a
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
16 changes: 10 additions & 6 deletions orttraining/orttraining/core/optimizer/localized_recompute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ bool GeluRecompute::SatisfyCondition(const Node& node) const {

Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder();

for (NodeIndex i : order) {
Node& node = *graph.GetNode(i);
// Traverse backward from the bottom of the graph, so that the recompute nodes
// for lower layers are executed earlier
for (int i = static_cast<int>(node_ids.size() - 1); i >= 0; --i) {
Node& node = *graph.GetNode(node_ids[i]);

if (!SatisfyCondition(node)) {
continue;
Expand Down Expand Up @@ -70,10 +72,12 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Node& node) const {

Status AttentionDropoutRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder();

for (NodeIndex i : order) {
Node& node = *graph.GetNode(i);
// Traverse backward from the bottom of the graph, so that the recompute nodes
// for lower layers are executed earlier
for (int i = static_cast<int>(node_ids.size() - 1); i >= 0; --i) {
Node& node = *graph.GetNode(node_ids[i]);

if (!SatisfyCondition(node)) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,6 @@ def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers
set_seed(seed)

# Setup ORTTrainer
loss_scaler = amp.DynamicLossScaler()
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
'graph_transformer' : {
'attn_dropout_recompute': attn_dropout,
Expand Down
25 changes: 20 additions & 5 deletions orttraining/tools/ci_test/run_batch_size_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,26 @@ def parse_args():
def main():
args = parse_args()

Config = collections.namedtuple("Config", ["enable_mixed_precision", "sequence_length", "max_batch_size"])
Config = collections.namedtuple("Config", ["enable_mixed_precision",
"sequence_length",
"max_batch_size",
"max_predictions_per_seq",
"additional_options"])
configs = [
Config(True, 128, 66),
Config(True, 512, 10),
Config(False, 128, 33),
Config(False, 512, 5),
Config(True, 128, 76, 20, ""),
Config(True, 512, 11, 80, ""),
Config(False, 128, 39, 20, ""),
Config(False, 512, 6, 80, ""),

# BertLarge Phase 1 recompute
Config(True, 128, 91, 20, "--gelu_recompute"),
Config(True, 128, 83, 20, "--attn_dropout_recompute"),
Config(True, 128, 344, 20, "--transformer_layer_recompute"),

# BertLarge Phase 2 recompute
Config(True, 512, 12, 80, "--gelu_recompute"),
Config(True, 512, 14, 80, "--attn_dropout_recompute"),
Config(True, 512, 50, 80, "--transformer_layer_recompute"),
]

# run BERT training
Expand Down Expand Up @@ -52,6 +66,7 @@ def main():
"--use_nccl",
"--seed", "42",
"--enable_grad_norm_clip=false",
config.additional_options
]

if config.enable_mixed_precision:
Expand Down

0 comments on commit 60dbd8a

Please sign in to comment.