Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up Model tests by 20% #5574

Merged
merged 9 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified test/expect/ModelTester.test_alexnet_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_convnext_base_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_convnext_large_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b6_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b7_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_v2_l_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_v2_m_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_x_16gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_x_32gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_128gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_16gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_32gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_resnext101_32x8d_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_wide_resnet101_2_expect.pkl
Binary file not shown.
107 changes: 61 additions & 46 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import pkgutil
import sys
import traceback
import warnings
from collections import OrderedDict
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -119,27 +118,16 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)


def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False, eager_out=None):
"""Check that a nn.Module's results in TorchScript match eager and that it can be exported"""

def assert_export_import_module(m, args):
"""Check that the results of a model are the same after saving and loading"""

def get_export_import_copy(m):
"""Save and load a TorchScript model"""
with TemporaryDirectory() as dir:
path = os.path.join(dir, "script.pt")
m.save(path)
imported = torch.jit.load(path)
return imported

m_import = get_export_import_copy(m)
with torch.no_grad(), freeze_rng_state():
results = m(*args)
with torch.no_grad(), freeze_rng_state():
results_from_imported = m_import(*args)
tol = 3e-4
torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol)
def get_export_import_copy(m):
"""Save and load a TorchScript model"""
with TemporaryDirectory() as dir:
path = os.path.join(dir, "script.pt")
m.save(path)
imported = torch.jit.load(path)
return imported

TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1"
if not TEST_WITH_SLOW or skip:
Expand All @@ -157,23 +145,33 @@ def get_export_import_copy(m):

sm = torch.jit.script(nn_module)

with torch.no_grad(), freeze_rng_state():
eager_out = nn_module(*args)
if eager_out is None:
with torch.no_grad(), freeze_rng_state():
if unwrapper:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this line needed? It looks like it eager_out wasn't unwrapped before

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not mandatory to have it for most of the existing models but this is only due to implementation details. Since this is a general purpose tool for checking JIT-scriptability, I opted for consistently unwrapping the output in all place.

FYI the reason many models don't have to get unwrapped is due to idioms like this:

@torch.jit.unused
def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1)
else:
return x # type: ignore[return-value]
def forward(self, x: Tensor) -> GoogLeNetOutputs:
x = self._transform_input(x)
x, aux1, aux2 = self._forward(x)
aux_defined = self.training and self.aux_logits
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
return GoogLeNetOutputs(x, aux2, aux1)
else:
return self.eager_outputs(x, aux2, aux1)

New non-detection models don't use this idiom any more (returning different output depending on jit/training flag), so I think it's safer to handle it explicitly.

eager_out = nn_module(*args)

with torch.no_grad(), freeze_rng_state():
script_out = sm(*args)
if unwrapper:
script_out = unwrapper(script_out)

torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)
assert_export_import_module(sm, args)

m_import = get_export_import_copy(sm)
with torch.no_grad(), freeze_rng_state():
imported_script_out = m_import(*args)
if unwrapper:
imported_script_out = unwrapper(imported_script_out)

torch.testing.assert_close(script_out, imported_script_out, atol=3e-4, rtol=3e-4)


def _check_fx_compatible(model, inputs):
def _check_fx_compatible(model, inputs, eager_out=None):
model_fx = torch.fx.symbolic_trace(model)
out = model(inputs)
out_fx = model_fx(inputs)
torch.testing.assert_close(out, out_fx)
if eager_out is None:
eager_out = model(inputs)
fx_out = model_fx(inputs)
torch.testing.assert_close(eager_out, fx_out)


def _check_input_backprop(model, inputs):
Expand Down Expand Up @@ -298,6 +296,24 @@ def _check_input_backprop(model, inputs):
"rpn_post_nms_top_n_test": 1000,
},
}
# speeding up slow models:
slow_models = [
"convnext_base",
"convnext_large",
"resnext101_32x8d",
"wide_resnet101_2",
"efficientnet_b6",
"efficientnet_b7",
"efficientnet_v2_m",
"efficientnet_v2_l",
"regnet_y_16gf",
"regnet_y_32gf",
"regnet_y_128gf",
"regnet_x_16gf",
"regnet_x_32gf",
]
for m in slow_models:
_model_params[m] = {"input_shape": (1, 3, 64, 64)}


# The following contains configuration and expected values to be used tests that are model specific
Expand Down Expand Up @@ -564,8 +580,8 @@ def test_classification_model(model_fn, dev):
out = model(x)
_assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == num_classes
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x, eager_out=out)

if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
Expand Down Expand Up @@ -595,7 +611,7 @@ def test_segmentation_model(model_fn, dev):
model.eval().to(device=dev)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)["out"]
out = model(x)

def check_out(out):
prec = 0.01
Expand All @@ -615,17 +631,17 @@ def check_out(out):

return True # Full validation performed

full_validation = check_out(out)
full_validation = check_out(out["out"])

_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x, eager_out=out)

if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
out = model(x)["out"]
out = model(x)
# See autocast_flaky_numerics comment at top of file.
if model_name not in autocast_flaky_numerics:
full_validation &= check_out(out)
full_validation &= check_out(out["out"])

if not full_validation:
msg = (
Expand Down Expand Up @@ -716,7 +732,7 @@ def compute_mean_std(tensor):
return True # Full validation performed

full_validation = check_out(out)
_check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None))
_check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)

if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
Expand Down Expand Up @@ -780,8 +796,8 @@ def test_video_model(model_fn, dev):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x, eager_out=out)
assert out.shape[-1] == 50

if dev == torch.device("cuda"):
Expand Down Expand Up @@ -821,8 +837,13 @@ def test_quantized_classification_model(model_fn):
if model_name not in quantized_flaky_models:
_assert_expected(out, model_name + "_quantized", prec=0.1)
assert out.shape[-1] == 5
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x, eager_out=out)
else:
try:
torch.jit.script(model)
except Exception as e:
raise AssertionError("model cannot be scripted.") from e

kwargs["quantize"] = False
for eval_mode in [True, False]:
Expand All @@ -843,12 +864,6 @@ def test_quantized_classification_model(model_fn):

torch.ao.quantization.convert(model, inplace=True)

try:
torch.jit.script(model)
except Exception as e:
tb = traceback.format_exc()
raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e


@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
Expand Down