Skip to content

Commit

Permalink
Improved tests and exposed implement_for in the documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ordinskiy committed Nov 2, 2022
1 parent 2e56128 commit 37c481d
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 40 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ API Reference
modules
objectives
trainers
utils
13 changes: 13 additions & 0 deletions docs/source/reference/utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. currentmodule:: torchrl._utils

torchrl._utils package
====================

Set of utility methods that are used internally by the library.


.. autosummary::
:toctree: generated/
:template: rl_template.rst

implement_for
82 changes: 45 additions & 37 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,49 +67,57 @@ def test_get_binary_env_var_wrong_value():
del os.environ[key]


# To test from+ range and that this function is correctly selected as the implementation.
@implement_for("_utils_internal", "0.3")
def implement_for_test_func():
return "0.3+"


# To test that right bound is not included.
@implement_for("_utils_internal", "0.2", "0.3")
def implement_for_test_func(): # noqa: F811
return "0.2-0.3"


# To test that function with missing from-to range is ignored.
@implement_for("_utils_internal", "0.1", "0.2")
def implement_for_test_func(): # noqa: F811
return "0.1-0.2"


# To test that calling decorated function with missing module raises an exception.
@implement_for("missing_module")
def implement_for_test_missing_module():
return "missing"


@implement_for("_utils_internal", None, "0.3")
def implement_for_test_missing_version():
return "0-0.3"


@implement_for("_utils_internal", "0.4")
def implement_for_test_missing_version(): # noqa: F811
return "0.4+"
class implement_for_test_functions:
"""
Groups functions that are used in tests for `implement_for` decorator.
"""

@staticmethod
@implement_for("_utils_internal", "0.3")
def select_correct_version():
"""To test from+ range and that this function is correctly selected as the implementation."""
return "0.3+"

@staticmethod
@implement_for("_utils_internal", "0.2", "0.3")
def select_correct_version(): # noqa: F811
"""To test that right bound is not included."""
return "0.2-0.3"

@staticmethod
@implement_for("_utils_internal", "0.1", "0.2")
def select_correct_version(): # noqa: F811
"""To test that function with missing from-to range is ignored."""
return "0.1-0.2"

@staticmethod
@implement_for("missing_module")
def missing_module():
"""To test that calling decorated function with missing module raises an exception."""
return "missing"

@staticmethod
@implement_for("_utils_internal", None, "0.3")
def missing_version():
return "0-0.3"

@staticmethod
@implement_for("_utils_internal", "0.4")
def missing_version(): # noqa: F811
return "0.4+"


def test_implement_for():
assert implement_for_test_func() == "0.3+"
assert implement_for_test_functions.select_correct_version() == "0.3+"


def test_implement_for_missing_module():
with pytest.raises(ModuleNotFoundError):
implement_for_test_missing_module()
msg = "Supported version of 'missing_module' has not been found."
with pytest.raises(ModuleNotFoundError, match=msg):
implement_for_test_functions.missing_module()


def test_implement_for_missing_version():
with pytest.raises(ModuleNotFoundError):
implement_for_test_missing_version()
msg = "Supported version of '_utils_internal' has not been found."
with pytest.raises(ModuleNotFoundError, match=msg):
implement_for_test_functions.missing_version()
6 changes: 3 additions & 3 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,9 @@ class implement_for:
to_version: version from which implementation is no longer compatible. Can be open (None).
Examples:
@implement_for(“gym”, “0.13”, “0.14”)
def fun(self, x):
...
>>> @implement_for(“gym”, “0.13”, “0.14”)
>>> def fun(self, x):
This indicates that the function is compatible with gym 0.13+, but doesn't with gym 0.14+.
"""

Expand Down

0 comments on commit 37c481d

Please sign in to comment.