Skip to content

Commit

Permalink
[quant][refactor tests] Move qtensor serialization tests from test_de…
Browse files Browse the repository at this point in the history
…precated_jit (pytorch#59089)

Summary:
Pull Request resolved: pytorch#59089

Move these tests into test_quantized_tensor

Test Plan:
python test/test_quantization.py

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D28750065

fbshipit-source-id: 5c4350d49dd07710b86ba330de80369403c6013c
  • Loading branch information
supriyar authored and facebook-github-bot committed May 28, 2021
1 parent 886a2dd commit 89d7885
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 60 deletions.
63 changes: 62 additions & 1 deletion test/quantization/core/test_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import deepcopy
from hypothesis import given
from hypothesis import strategies as st

from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM
import torch.testing._internal.hypothesis_utils as hu
Expand Down Expand Up @@ -903,3 +903,64 @@ def test_choose_qparams_optimized(self):
ref = param_search_greedy(x.numpy(), bit_rate=bit_width)
self.assertEqual(y[0].numpy(), ref[0])
self.assertEqual(y[1].numpy(), ref[1])

def _test_pickle_checkpoint_qtensor(self, device):
with TemporaryFileName() as fname:
class M(torch.jit.ScriptModule):
__constants__ = ['fname']

def __init__(self):
super(M, self).__init__()
self.fname = fname

@torch.jit.script_method
def forward(self, x, y):
torch.save((x, y), self.fname)
return y

q = torch.quantize_per_tensor(
torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device)
qc = torch.quantize_per_channel(
torch.rand(2, 3, dtype=torch.float),
scales=torch.tensor([0.1, 0.5, 0.01]),
zero_points=torch.tensor([10, 0, 20]),
axis=1, dtype=torch.quint8).to(device)
m = M()
m(q, qc)
with open(fname, "rb") as handle:
loaded_q, loaded_qc = torch.load(fname)
self.assertEqual(loaded_q, q)
self.assertEqual(loaded_qc, qc)

def test_pickle_checkpoint_qtensor(self):
self._test_pickle_checkpoint_qtensor('cpu')

def test_jit_serialization(self):
class SimpleQTensor(torch.jit.ScriptModule):
def __init__(self, per_channel):
super(SimpleQTensor, self).__init__()
x = torch.rand(5, 5).float()
if not per_channel:
x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8)
else:
s = torch.rand(5, dtype=torch.float64) + 0.1
zp = torch.randint(5, 15, (5,))
x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
self.register_buffer('x', x_q)

@torch.jit.script_method
def forward(self):
return self.x

for per_channel in [False, True]:
model = SimpleQTensor(per_channel)
buffer = io.BytesIO()
torch.jit.save(model, buffer)
buffer.seek(0)
model_loaded = torch.jit.load(buffer)
self.assertEqual(model_loaded(), model())

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead.")
59 changes: 0 additions & 59 deletions test/quantization/jit/test_deprecated_jit_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
from torch.testing._internal.common_quantization import (
skipIfNoFBGEMM
)
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase

from typing import Tuple
import copy
import io

# TODO: Move some tensor tests here like test_serialize_qtensor to test_quantize_tensor.py
class TestDeprecatedJitQuantized(JitTestCase):
@skipIfNoFBGEMM
def test_rnn_cell_quantized(self):
Expand Down Expand Up @@ -258,62 +255,6 @@ def forward(self, x):
torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3)
torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3)

def _test_pickle_checkpoint_qtensor(self, device):
with TemporaryFileName() as fname:
class M(torch.jit.ScriptModule):
__constants__ = ['fname']

def __init__(self):
super(M, self).__init__()
self.fname = fname

@torch.jit.script_method
def forward(self, x, y):
torch.save((x, y), self.fname)
return y

q = torch.quantize_per_tensor(
torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device)
qc = torch.quantize_per_channel(
torch.rand(2, 3, dtype=torch.float),
scales=torch.tensor([0.1, 0.5, 0.01]),
zero_points=torch.tensor([10, 0, 20]),
axis=1, dtype=torch.quint8).to(device)
m = M()
m(q, qc)
with open(fname, "rb") as handle:
loaded_q, loaded_qc = torch.load(fname)
self.assertEqual(loaded_q, q)
self.assertEqual(loaded_qc, qc)

def test_pickle_checkpoint_qtensor(self):
self._test_pickle_checkpoint_qtensor('cpu')

def test_serialize_qtensor(self):
class SimpleQTensor(torch.jit.ScriptModule):
def __init__(self, per_channel):
super(SimpleQTensor, self).__init__()
x = torch.rand(5, 5).float()
if not per_channel:
x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8)
else:
s = torch.rand(5, dtype=torch.float64) + 0.1
zp = torch.randint(5, 15, (5,))
x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
self.register_buffer('x', x_q)

@torch.jit.script_method
def forward(self):
return self.x

for per_channel in [False, True]:
model = SimpleQTensor(per_channel)
buffer = io.BytesIO()
torch.jit.save(model, buffer)
buffer.seek(0)
model_loaded = torch.jit.load(buffer)
self.assertEqual(model_loaded(), model())

@skipIfNoFBGEMM
def test_erase_class_tensor_shapes(self):
class Linear(torch.nn.Module):
Expand Down
1 change: 1 addition & 0 deletions test/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# - quantized tensor

# 1. Quantized Kernels
# TODO: merge the different quantized op tests into one test class
from quantization.core.test_quantized_op import TestQuantizedOps # noqa: F401
from quantization.core.test_quantized_op import TestQNNPackOps # noqa: F401
from quantization.core.test_quantized_op import TestQuantizedLinear # noqa: F401
Expand Down

0 comments on commit 89d7885

Please sign in to comment.