[GraphModule] Back out changes to module root version of __init__ (#53791)

Summary: Pull Request resolved: #53791

Reviewed By: houseroad

Differential Revision: D26970869

fbshipit-source-id: 80684516f57fd2d1aca794f17fe488b2fe2b2f64
jfix71 authored and facebook-github-bot committed Mar 11, 2021
1 parent 37ab711 commit 1053c96
Showing 2 changed files with 6 additions and 6 deletions.
test/test_fx_experimental.py (7 changes: 3 additions & 4 deletions)
@@ -888,19 +888,18 @@ def foo(x):

         traced = symbolic_trace_with_rewrite(foo)
 
-    # TODO: Add support and coverage for pickling non-parameter/buffer Tensor
-    # attributes.
     def test_to_folder(self):
         class Test(torch.nn.Module):
             def __init__(self):
                 super(Test, self).__init__()
                 self.W = torch.nn.Parameter(torch.randn(2))
                 self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
                 self.linear = torch.nn.Linear(2, 2)
-                self.register_buffer('attr', torch.randn(2))
+                self.attr = torch.randn(2)
+                self.register_buffer('attr2', torch.randn(2))
 
             def forward(self, x):
-                return self.linear(self.seq(self.W + self.attr + x))
+                return self.linear(self.seq(self.W + self.attr + self.attr2 + x))
 
         mod = symbolic_trace(Test())
         module_name = 'Foo'
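The updated test exercises GraphModule.to_folder with both a plain tensor attribute (attr) and a registered buffer (attr2). Below is a minimal, self-contained sketch of that round trip; the Small module and the 'dumped' folder name are illustrative stand-ins, not from the commit, and the exact checks the test performs are not shown in this diff:

import tempfile
from pathlib import Path

import torch
from torch.fx import symbolic_trace

class Small(torch.nn.Module):
    # Hypothetical stand-in for the Test module in the hunk above.
    def __init__(self):
        super().__init__()
        self.attr = torch.randn(2)                     # plain Tensor attribute
        self.register_buffer('attr2', torch.randn(2))  # registered buffer

    def forward(self, x):
        return x + self.attr + self.attr2

mod = symbolic_trace(Small())
with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp) / 'dumped'
    # to_folder emits a standalone module.py defining class Foo, alongside
    # serialized tensors that the generated code reloads in its __init__.
    mod.to_folder(folder, 'Foo')
    assert (folder / 'module.py').exists()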
torch/fx/graph_module.py (5 changes: 3 additions & 2 deletions)
@@ -124,8 +124,9 @@ def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target:
         from_module, to_module = f, t
 
     orig = getattr(from_module, field)
-    # Register it as a named buffer in to_module if it was a buffer in from_module.
-    if field in from_module._buffers.keys():
+    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
+    # So, we register it as a named buffer in the target module.
+    if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
         to_module.register_buffer(field, orig)
     else:
         setattr(to_module, field, orig)
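If this reading of the change is right, the practical effect of the rewritten check is that a plain Tensor attribute referenced by the traced graph ends up registered as a named buffer on the resulting GraphModule, and is therefore part of its serialized state. A minimal sketch, not part of the commit:

import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.randn(2)  # plain Tensor: neither Parameter nor registered buffer

    def forward(self, x):
        return x + self.attr

gm = symbolic_trace(M())
# _copy_attr sees a Tensor that is not a Parameter and calls register_buffer,
# so the attribute now appears among the GraphModule's named buffers:
assert 'attr' in dict(gm.named_buffers())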
