forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_utils.py
123 lines (89 loc) · 3.41 KB
/
test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import re

import pytest

from torchrl._utils import get_binary_env_var, implement_for
@pytest.mark.parametrize("value", ["True", "1", "true"])
def test_get_binary_env_var_positive(value):
    """Truthy spellings of the variable make ``get_binary_env_var`` return a truthy value.

    Args:
        value: the string written to the environment variable.
    """
    key = "SOME_ENVIRONMENT_VARIABLE_UNLIKELY_TO_BE_IN_ENVIRONMENT"
    # Checked before entering ``try``: if the variable already exists, the
    # test fails here without the ``finally`` clause deleting a value this
    # test never set.
    assert key not in os.environ
    try:
        os.environ[key] = value
        assert get_binary_env_var(key)
    finally:
        os.environ.pop(key, None)
@pytest.mark.parametrize("value", ["False", "0", "false"])
def test_get_binary_env_var_negative(value):
    """Falsy spellings of the variable make ``get_binary_env_var`` return a falsy value.

    The variable is first set to ``"True"`` and checked to read as truthy,
    then overwritten with ``value``. Presumably this confirms the lookup
    reflects the variable's current contents.

    Args:
        value: the string written to the environment variable.
    """
    key = "SOME_ENVIRONMENT_VARIABLE_UNLIKELY_TO_BE_IN_ENVIRONMENT"
    # Checked before entering ``try`` so a pre-existing variable is never
    # deleted by the cleanup below.
    assert key not in os.environ
    try:
        os.environ[key] = "True"
        assert get_binary_env_var(key)
        os.environ[key] = value
        assert not get_binary_env_var(key)
    finally:
        os.environ.pop(key, None)
def test_get_binary_env_var_missing():
    """An unset variable makes ``get_binary_env_var`` return a falsy value."""
    key = "SOME_ENVIRONMENT_VARIABLE_UNLIKELY_TO_BE_IN_ENVIRONMENT"
    # Checked before entering ``try`` so a pre-existing variable is never
    # deleted by the cleanup below.
    assert key not in os.environ
    try:
        assert not get_binary_env_var(key)
    finally:
        # Defensive: remove the key in case the lookup created it.
        os.environ.pop(key, None)
def test_get_binary_env_var_wrong_value():
    """An unrecognised value such as ``"smthwrong"`` makes ``get_binary_env_var`` raise ``ValueError``."""
    key = "SOME_ENVIRONMENT_VARIABLE_UNLIKELY_TO_BE_IN_ENVIRONMENT"
    # Checked before entering ``try`` so a pre-existing variable is never
    # deleted by the cleanup below.
    assert key not in os.environ
    try:
        os.environ[key] = "smthwrong"
        with pytest.raises(ValueError):
            get_binary_env_var(key)
    finally:
        os.environ.pop(key, None)
class implement_for_test_functions:
    """
    Groups functions that are used in tests for the `implement_for` decorator.

    Each method is decorated with ``implement_for(module_name, from_version,
    to_version)`` and then wrapped in ``staticmethod``. Several methods
    intentionally reuse the same name, hence the ``# noqa: F811`` markers.
    ``test_implement_for`` expects ``select_correct_version()`` to return
    ``"0.3+"``. That is the *first* definition, not the last, so the variant
    actually called is chosen by ``implement_for`` rather than by ordinary
    name rebinding.

    NOTE(review): ``"_utils_internal"`` is presumably a helper module
    importable from the test directory, whose version ``implement_for``
    compares against the ranges below. The tests imply that version is
    >= 0.3 and < 0.4; confirm.
    """
    @staticmethod
    @implement_for("_utils_internal", "0.3")
    def select_correct_version():
        """To test a from+ range (no upper bound) and that this function is correctly selected as the implementation."""
        return "0.3+"
    @staticmethod
    @implement_for("_utils_internal", "0.2", "0.3")
    def select_correct_version():  # noqa: F811
        """To test that the right bound of the range is not included."""
        return "0.2-0.3"
    @staticmethod
    @implement_for("_utils_internal", "0.1", "0.2")
    def select_correct_version():  # noqa: F811
        """To test that a function whose from-to range excludes the installed version is ignored."""
        return "0.1-0.2"
    @staticmethod
    @implement_for("missing_module")
    def missing_module():
        """To test that calling a decorated function whose module is missing raises an exception."""
        return "missing"
    @staticmethod
    @implement_for("_utils_internal", None, "0.3")
    def missing_version():
        """Variant with no lower bound (``None``), up to but excluding 0.3.

        Together with the "0.4"-onward variant below, this is meant to leave
        the installed version uncovered, so ``missing_version()`` raises
        ``ModuleNotFoundError`` (see ``test_implement_for_missing_version``).
        """
        return "0-0.3"
    @staticmethod
    @implement_for("_utils_internal", "0.4")
    def missing_version():  # noqa: F811
        """Variant for versions from 0.4 onward; paired with the variant above to leave a gap."""
        return "0.4+"
def test_implement_for():
    """The multiply-defined ``select_correct_version`` resolves to the "0.3"-onward variant."""
    selected = implement_for_test_functions.select_correct_version()
    assert selected == "0.3+"
def test_implement_for_missing_module():
    """Calling a function registered for ``"missing_module"`` raises ``ModuleNotFoundError``.

    ``"missing_module"`` is presumably not installed in the test environment.
    """
    msg = "Supported version of 'missing_module' has not been found."
    # ``match`` is applied with ``re.search``. Escape the message so its
    # trailing "." is matched literally rather than as a regex wildcard.
    with pytest.raises(ModuleNotFoundError, match=re.escape(msg)):
        implement_for_test_functions.missing_module()
def test_implement_for_missing_version():
    """Raise ``ModuleNotFoundError`` when no registered range covers the installed version."""
    msg = "Supported version of '_utils_internal' has not been found."
    # ``match`` is applied with ``re.search``. Escape the message so its
    # trailing "." is matched literally rather than as a regex wildcard.
    with pytest.raises(ModuleNotFoundError, match=re.escape(msg)):
        implement_for_test_functions.missing_version()