-
Notifications
You must be signed in to change notification settings - Fork 225
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #486 from muupan/share-persistent-values
Share persistent values among processes
- Loading branch information
Showing
5 changed files
with
277 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from __future__ import unicode_literals | ||
from __future__ import print_function | ||
from __future__ import division | ||
from __future__ import absolute_import | ||
from builtins import * # NOQA | ||
from future import standard_library | ||
standard_library.install_aliases() # NOQA | ||
|
||
import chainer | ||
|
||
|
||
def _namedchildren(link): | ||
if isinstance(link, chainer.Chain): | ||
for name in sorted(link._children): | ||
yield name, link.__dict__[name] | ||
elif isinstance(link, chainer.ChainList): | ||
for idx, child in enumerate(link._children): | ||
yield str(idx), child | ||
|
||
|
||
def namedpersistent(link): | ||
"""Return a generator of all (path, persistent) pairs for a given link. | ||
This function is adopted from https://github.com/chainer/chainer/pull/6788. | ||
Once it is merged into Chainer, we should use the property instead. | ||
Args: | ||
link (chainer.Link): Link. | ||
Returns: | ||
A generator object that generates all (path, persistent) pairs. | ||
The paths are relative from this link. | ||
""" | ||
d = link.__dict__ | ||
for name in sorted(link._persistent): | ||
yield '/' + name, d[name] | ||
for name, child in _namedchildren(link): | ||
prefix = '/' + name | ||
for path, persistent in namedpersistent(child): | ||
yield prefix + path, persistent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from __future__ import unicode_literals | ||
from __future__ import print_function | ||
from __future__ import division | ||
from __future__ import absolute_import | ||
from builtins import * # NOQA | ||
from future import standard_library | ||
standard_library.install_aliases() # NOQA | ||
|
||
import chainer | ||
import numpy | ||
|
||
import chainerrl | ||
|
||
|
||
def test_namedpersistent(): | ||
# This test case is adopted from | ||
# https://github.com/chainer/chainer/pull/6788 | ||
|
||
l1 = chainer.Link() | ||
with l1.init_scope(): | ||
l1.x = chainer.Parameter(shape=(2, 3)) | ||
|
||
l2 = chainer.Link() | ||
with l2.init_scope(): | ||
l2.x = chainer.Parameter(shape=2) | ||
l2.add_persistent( | ||
'l2_a', numpy.array([1, 2, 3], dtype=numpy.float32)) | ||
|
||
l3 = chainer.Link() | ||
with l3.init_scope(): | ||
l3.x = chainer.Parameter() | ||
l3.add_persistent( | ||
'l3_a', numpy.array([1, 2, 3], dtype=numpy.float32)) | ||
|
||
c1 = chainer.Chain() | ||
with c1.init_scope(): | ||
c1.l1 = l1 | ||
c1.add_link('l2', l2) | ||
c1.add_persistent( | ||
'c1_a', numpy.array([1, 2, 3], dtype=numpy.float32)) | ||
|
||
c2 = chainer.Chain() | ||
with c2.init_scope(): | ||
c2.c1 = c1 | ||
c2.l3 = l3 | ||
c2.add_persistent( | ||
'c2_a', numpy.array([1, 2, 3], dtype=numpy.float32)) | ||
namedpersistent = list(chainerrl.misc.namedpersistent(c2)) | ||
assert ( | ||
[(name, id(p)) for name, p in namedpersistent] == | ||
[('/c2_a', id(c2.c2_a)), ('/c1/c1_a', id(c2.c1.c1_a)), | ||
('/c1/l2/l2_a', id(c2.c1.l2.l2_a)), ('/l3/l3_a', id(c2.l3.l3_a))]) |