`pytest-pytorch` is a lightweight `pytest` plugin that enhances the developer experience when working with the PyTorch test suite if you come from a `pytest` background.
Some test cases in the PyTorch test suite are generated automatically when a module is loaded in order to parametrize them. Trying to collect them with their names as written, e.g. `pytest test_foo.py::TestFoo` or `pytest test_foo.py::TestFoo::test_bar`, is unfortunately not possible. If you are used to this syntax or your IDE relies on it (PyCharm, VSCode), you can install `pytest-pytorch` to make it work.
You can install `pytest-pytorch` with `pip`:

```shell
$ pip install pytest-pytorch
```

or with `conda`:

```shell
$ conda install -c conda-forge pytest-pytorch
```
With `pytest-pytorch` installed, you can select test cases and tests as if the instantiation for different devices were performed by `@pytest.mark.parametrize`:
| Use case | Command |
|---|---|
| Run a test case against all devices | `pytest test_foo.py::TestFoo` |
| Run a test case against one device | `pytest test_foo.py::TestFoo -k "$DEVICE"` |
| Run a test against all devices | `pytest test_foo.py::TestFoo::test_bar` |
| Run a test against one device | `pytest test_foo.py::TestFoo::test_bar -k "$DEVICE"` |

Here `$DEVICE` is a placeholder for a device type such as `cpu` or `cuda`.
PyTorch uses its own method for generating tests that is, for the most part, compatible with `unittest` and `pytest`. Its custom test generation allows test templates to be written once and instantiated for different device types, data types, and operators. Consider the following module `test_foo.py`:
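```python
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_device_type import instantiate_device_type_tests


class TestFoo(TestCase):
    # Test templates: the device argument is filled in during instantiation.
    def test_bar(self, device):
        pass

    def test_baz(self, device):
        pass


# Removes the generic TestFoo from this module's namespace and adds one
# generated class per available device type, e.g. TestFooCPU and TestFooCUDA.
instantiate_device_type_tests(TestFoo, globals())
```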
Assuming `"cpu"` and `"cuda"` are available as devices, we can collect four tests: `test_foo.py::TestFooCPU::test_bar_cpu`, `test_foo.py::TestFooCPU::test_baz_cpu`, `test_foo.py::TestFooCUDA::test_bar_cuda`, and `test_foo.py::TestFooCUDA::test_baz_cuda`.
From a `pytest` perspective, this is similar to decorating `TestFoo` with `@pytest.mark.parametrize("device", ("cpu", "cuda"))`, which would result in `test_foo.py::TestFoo::test_bar[cpu]`, `test_foo.py::TestFoo::test_bar[cuda]`, `test_foo.py::TestFoo::test_baz[cpu]`, and `test_foo.py::TestFoo::test_baz[cuda]`.
Since the PyTorch test framework renames test cases and tests, naively running `pytest test_foo.py::TestFoo` or `pytest test_foo.py::TestFoo::test_bar` fails, because `pytest` can't find anything matching these names. Of course, you can work around this with the keyword expression matching (`-k` command line flag) that `pytest` offers, e.g. `pytest test_foo.py -k "TestFoo and test_bar"`.
`pytest-pytorch` performs this matching for you, so you can keep your familiar workflow and your IDE works out of the box.
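To illustrate the kind of matching meant here, the sketch below maps a template selector such as `TestFoo::test_bar` onto the generated names. It is a hypothetical illustration, not the plugin's actual implementation: the function `matches`, its signature, and the restriction to device-only suffixes are assumptions. PyTorch can also instantiate tests per data type and operator, which this sketch ignores.

```python
# Hypothetical sketch, NOT pytest-pytorch's implementation: decide whether a
# generated "Class::test" name belongs to a template selector like "TestFoo::test_bar".
def matches(requested, generated, devices=("cpu", "cuda")):
    req_cls, _, req_test = requested.partition("::")
    gen_cls, _, gen_test = generated.partition("::")
    for device in devices:
        # Generated classes append the upper-case device type, e.g. TestFooCPU.
        if gen_cls != req_cls + device.upper():
            continue
        # Selecting only the class matches every test in the instantiated class.
        if not req_test:
            return True
        # Generated tests append the device type, e.g. test_bar_cpu.
        if gen_test == f"{req_test}_{device}":
            return True
    return False


generated = [
    "TestFooCPU::test_bar_cpu",
    "TestFooCPU::test_baz_cpu",
    "TestFooCUDA::test_bar_cuda",
    "TestFooCUDA::test_baz_cuda",
]
print([name for name in generated if matches("TestFoo::test_bar", name)])
# ['TestFooCPU::test_bar_cpu', 'TestFooCUDA::test_bar_cuda']
```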
First and foremost: Thank you for your interest in the development of `pytest-pytorch`! We appreciate all contributions, be it code or something else. Check out our contribution guidelines for details.