Skip to content

Commit

Permalink
[BugFix] Fix locked params modif (pytorch#1307)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 22, 2023
1 parent e21e4cf commit 36d307b
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,8 @@ def test_functional(self, safe, spec_type):
assert hasattr(tdmodule, "__setitem__")
assert len(tdmodule) == 3
tdmodule[1] = tdmodule2
params["module", "1"] = params["module", "2"]
with params.unlock_():
params["module", "1"] = params["module", "2"]
assert len(tdmodule) == 3

assert hasattr(tdmodule, "__delitem__")
Expand Down Expand Up @@ -920,8 +921,9 @@ def test_functional_probabilistic(self, safe, spec_type):
assert len(tdmodule) == 4
tdmodule[1] = tdmodule2
tdmodule[2] = prob_module
params["module", "1"] = params["module", "2"]
params["module", "2"] = params["module", "3"]
with params.unlock_():
params["module", "1"] = params["module", "2"]
params["module", "2"] = params["module", "3"]
assert len(tdmodule) == 4

assert hasattr(tdmodule, "__delitem__")
Expand Down Expand Up @@ -995,7 +997,8 @@ def test_functional_with_buffer(self, safe, spec_type):
assert hasattr(tdmodule, "__setitem__")
assert len(tdmodule) == 3
tdmodule[1] = tdmodule2
params["module", "1"] = params["module", "2"]
with params.unlock_():
params["module", "1"] = params["module", "2"]
assert len(tdmodule) == 3

assert hasattr(tdmodule, "__delitem__")
Expand Down Expand Up @@ -1082,8 +1085,9 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type):
assert len(tdmodule) == 4
tdmodule[1] = tdmodule2
tdmodule[2] = prob_module
params["module", "1"] = params["module", "2"]
params["module", "2"] = params["module", "3"]
with params.unlock_():
params["module", "1"] = params["module", "2"]
params["module", "2"] = params["module", "3"]
assert len(tdmodule) == 4

assert hasattr(tdmodule, "__delitem__")
Expand Down Expand Up @@ -1163,7 +1167,8 @@ def test_vmap(self, safe, spec_type):
assert hasattr(tdmodule, "__setitem__")
assert len(tdmodule) == 3
tdmodule[1] = tdmodule2
params["module", "1"] = params["module", "2"]
with params.unlock_():
params["module", "1"] = params["module", "2"]
assert len(tdmodule) == 3

assert hasattr(tdmodule, "__delitem__")
Expand Down

0 comments on commit 36d307b

Please sign in to comment.