From 89d78851e68867140169a920c94d084445cf528d Mon Sep 17 00:00:00 2001 From: Supriya Rao Date: Thu, 27 May 2021 17:01:13 -0700 Subject: [PATCH] [quant][refactor tests] Move qtensor serialization tests from test_deprecated_jit (#59089) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59089 Move these tests into test_quantized_tensor Test Plan: python test/test_quantization.py Imported from OSS Reviewed By: jerryzh168 Differential Revision: D28750065 fbshipit-source-id: 5c4350d49dd07710b86ba330de80369403c6013c --- .../core/test_quantized_tensor.py | 63 ++++++++++++++++++- .../jit/test_deprecated_jit_quant.py | 59 ----------------- test/test_quantization.py | 1 + 3 files changed, 63 insertions(+), 60 deletions(-) diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py index 4a89332c46100b..70f309f1d2a9c4 100644 --- a/test/quantization/core/test_quantized_tensor.py +++ b/test/quantization/core/test_quantized_tensor.py @@ -6,7 +6,7 @@ from copy import deepcopy from hypothesis import given from hypothesis import strategies as st - +from torch.testing._internal.common_utils import TemporaryFileName from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM import torch.testing._internal.hypothesis_utils as hu @@ -903,3 +903,64 @@ def test_choose_qparams_optimized(self): ref = param_search_greedy(x.numpy(), bit_rate=bit_width) self.assertEqual(y[0].numpy(), ref[0]) self.assertEqual(y[1].numpy(), ref[1]) + + def _test_pickle_checkpoint_qtensor(self, device): + with TemporaryFileName() as fname: + class M(torch.jit.ScriptModule): + __constants__ = ['fname'] + + def __init__(self): + super(M, self).__init__() + self.fname = fname + + @torch.jit.script_method + def forward(self, x, y): + torch.save((x, y), self.fname) + return y + + q = torch.quantize_per_tensor( + torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device) + qc = torch.quantize_per_channel( + torch.rand(2, 3, dtype=torch.float), + scales=torch.tensor([0.1, 0.5, 0.01]), + zero_points=torch.tensor([10, 0, 20]), + axis=1, dtype=torch.quint8).to(device) + m = M() + m(q, qc) + with open(fname, "rb") as handle: + loaded_q, loaded_qc = torch.load(fname) + self.assertEqual(loaded_q, q) + self.assertEqual(loaded_qc, qc) + + def test_pickle_checkpoint_qtensor(self): + self._test_pickle_checkpoint_qtensor('cpu') + + def test_jit_serialization(self): + class SimpleQTensor(torch.jit.ScriptModule): + def __init__(self, per_channel): + super(SimpleQTensor, self).__init__() + x = torch.rand(5, 5).float() + if not per_channel: + x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8) + else: + s = torch.rand(5, dtype=torch.float64) + 0.1 + zp = torch.randint(5, 15, (5,)) + x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8) + self.register_buffer('x', x_q) + + @torch.jit.script_method + def forward(self): + return self.x + + for per_channel in [False, True]: + model = SimpleQTensor(per_channel) + buffer = io.BytesIO() + torch.jit.save(model, buffer) + buffer.seek(0) + model_loaded = torch.jit.load(buffer) + self.assertEqual(model_loaded(), model()) + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_quantization.py TESTNAME\n\n" + "instead.") diff --git a/test/quantization/jit/test_deprecated_jit_quant.py b/test/quantization/jit/test_deprecated_jit_quant.py index d98778f2d1dedf..662ead35bcf012 100644 --- a/test/quantization/jit/test_deprecated_jit_quant.py +++ b/test/quantization/jit/test_deprecated_jit_quant.py @@ -2,15 +2,12 @@ from torch.testing._internal.common_quantization import ( skipIfNoFBGEMM ) -from torch.testing._internal.common_utils import TemporaryFileName from torch.testing._internal.common_utils import suppress_warnings from torch.testing._internal.jit_utils import JitTestCase from typing import Tuple import copy -import io -# TODO: Move some tensor tests here like test_serialize_qtensor to test_quantize_tensor.py class TestDeprecatedJitQuantized(JitTestCase): @skipIfNoFBGEMM def test_rnn_cell_quantized(self): @@ -258,62 +255,6 @@ def forward(self, x): torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3) torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3) - def _test_pickle_checkpoint_qtensor(self, device): - with TemporaryFileName() as fname: - class M(torch.jit.ScriptModule): - __constants__ = ['fname'] - - def __init__(self): - super(M, self).__init__() - self.fname = fname - - @torch.jit.script_method - def forward(self, x, y): - torch.save((x, y), self.fname) - return y - - q = torch.quantize_per_tensor( - torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device) - qc = torch.quantize_per_channel( - torch.rand(2, 3, dtype=torch.float), - scales=torch.tensor([0.1, 0.5, 0.01]), - zero_points=torch.tensor([10, 0, 20]), - axis=1, dtype=torch.quint8).to(device) - m = M() - m(q, qc) - with open(fname, "rb") as handle: - loaded_q, loaded_qc = torch.load(fname) - self.assertEqual(loaded_q, q) - self.assertEqual(loaded_qc, qc) - - def test_pickle_checkpoint_qtensor(self): - self._test_pickle_checkpoint_qtensor('cpu') - - def test_serialize_qtensor(self): - class SimpleQTensor(torch.jit.ScriptModule): - def __init__(self, per_channel): - super(SimpleQTensor, self).__init__() - x = torch.rand(5, 5).float() - if not per_channel: - x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8) - else: - s = torch.rand(5, dtype=torch.float64) + 0.1 - zp = torch.randint(5, 15, (5,)) - x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8) - self.register_buffer('x', x_q) - - @torch.jit.script_method - def forward(self): - return self.x - - for per_channel in [False, True]: - model = SimpleQTensor(per_channel) - buffer = io.BytesIO() - torch.jit.save(model, buffer) - buffer.seek(0) - model_loaded = torch.jit.load(buffer) - self.assertEqual(model_loaded(), model()) - @skipIfNoFBGEMM def test_erase_class_tensor_shapes(self): class Linear(torch.nn.Module): diff --git a/test/test_quantization.py b/test/test_quantization.py index 6627a0435adaa0..0821fa1f5a8a29 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -10,6 +10,7 @@ # - quantized tensor # 1. Quantized Kernels +# TODO: merge the different quantized op tests into one test class from quantization.core.test_quantized_op import TestQuantizedOps # noqa: F401 from quantization.core.test_quantized_op import TestQNNPackOps # noqa: F401 from quantization.core.test_quantized_op import TestQuantizedLinear # noqa: F401