Skip to content

Commit

Permalink
Merge pull request #486 from muupan/share-persistent-values
Browse files Browse the repository at this point in the history
Share persistent values among processes
  • Loading branch information
muupan authored Nov 7, 2019
2 parents ea98ae9 + b82b710 commit 4151087
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 24 deletions.
1 change: 1 addition & 0 deletions chainerrl/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from chainerrl.misc.draw_computational_graph import draw_computational_graph # NOQA
from chainerrl.misc.draw_computational_graph import is_graphviz_available # NOQA
from chainerrl.misc import env_modifiers # NOQA
from chainerrl.misc.namedpersistent import namedpersistent # NOQA
from chainerrl.misc.is_return_code_zero import is_return_code_zero # NOQA
from chainerrl.misc.random_seed import set_random_seed # NOQA
57 changes: 55 additions & 2 deletions chainerrl/misc/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import chainer
import numpy as np

import chainerrl
from chainerrl.misc import random_seed


Expand All @@ -32,19 +33,56 @@ def ensure_initialized_update_rule(param):
u.init_state(param)


def _set_persistent_values_recursively(link, persistent_name, shared_array):
if persistent_name.startswith('/'):
persistent_name = persistent_name[1:]
if hasattr(link, persistent_name):
attr_name = persistent_name
attr = getattr(link, attr_name)
if isinstance(attr, np.ndarray):
setattr(link, persistent_name, np.frombuffer(
shared_array, dtype=attr.dtype).reshape(attr.shape))
else:
assert np.isscalar(attr)
# We wrap scalars with np.ndarray because
# multiprocessing.RawValue cannot be used as a scalar, while
# np.ndarray can be.
typecode = np.asarray(attr).dtype.char
setattr(link, attr_name, np.frombuffer(
shared_array, dtype=typecode).reshape(()))
else:
assert isinstance(link, (chainer.Chain, chainer.ChainList))
assert '/' in persistent_name
child_name, remaining = persistent_name.split('/', 1)
if isinstance(link, chainer.Chain):
_set_persistent_values_recursively(
getattr(link, child_name), remaining, shared_array)
else:
_set_persistent_values_recursively(
link[int(child_name)], remaining, shared_array)


def set_shared_params(a, b):
"""Set shared params to a link.
"""Set shared params (and persistent values) to a link.
Args:
a (chainer.Link): link whose params are to be replaced
b (dict): dict that consists of (param_name, multiprocessing.Array)
"""
assert isinstance(a, chainer.Link)
remaining_keys = set(b.keys())
for param_name, param in a.namedparams():
if param_name in b:
shared_param = b[param_name]
param.array = np.frombuffer(
shared_param, dtype=param.dtype).reshape(param.shape)
remaining_keys.remove(param_name)
for persistent_name, _ in chainerrl.misc.namedpersistent(a):
if persistent_name in b:
_set_persistent_values_recursively(
a, persistent_name, b[persistent_name])
remaining_keys.remove(persistent_name)
assert not remaining_keys


def make_params_not_shared(a):
Expand Down Expand Up @@ -85,7 +123,22 @@ def extract_params_as_shared_arrays(link):
assert isinstance(link, chainer.Link)
shared_arrays = {}
for param_name, param in link.namedparams():
shared_arrays[param_name] = mp.RawArray('f', param.array.ravel())
typecode = param.array.dtype.char
shared_arrays[param_name] = mp.RawArray(typecode, param.array.ravel())

for persistent_name, persistent in chainerrl.misc.namedpersistent(link):
if isinstance(persistent, np.ndarray):
typecode = persistent.dtype.char
shared_arrays[persistent_name] = mp.RawArray(
typecode, persistent.ravel())
else:
assert np.isscalar(persistent)
# Wrap by a 1-dim array because multiprocessing.RawArray does not
# accept a 0-dim array.
persistent_as_array = np.asarray([persistent])
typecode = persistent_as_array.dtype.char
shared_arrays[persistent_name] = mp.RawArray(
typecode, persistent_as_array)
return shared_arrays


Expand Down
40 changes: 40 additions & 0 deletions chainerrl/misc/namedpersistent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA

import chainer


def _namedchildren(link):
if isinstance(link, chainer.Chain):
for name in sorted(link._children):
yield name, link.__dict__[name]
elif isinstance(link, chainer.ChainList):
for idx, child in enumerate(link._children):
yield str(idx), child


def namedpersistent(link):
"""Return a generator of all (path, persistent) pairs for a given link.
This function is adopted from https://github.com/chainer/chainer/pull/6788.
Once it is merged into Chainer, we should use the property instead.
Args:
link (chainer.Link): Link.
Returns:
A generator object that generates all (path, persistent) pairs.
The paths are relative from this link.
"""
d = link.__dict__
for name in sorted(link._persistent):
yield '/' + name, d[name]
for name, child in _namedchildren(link):
prefix = '/' + name
for path, persistent in namedpersistent(child):
yield prefix + path, persistent
151 changes: 129 additions & 22 deletions tests/misc_tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,51 +19,160 @@
import copy
import numpy as np

import chainerrl
from chainerrl.misc import async_


def _assert_same_pointers_to_persistent_values(a, b):
assert isinstance(a, chainer.Link)
assert isinstance(b, chainer.Link)
a_persistents = dict(chainerrl.misc.namedpersistent(a))
b_persistents = dict(chainerrl.misc.namedpersistent(b))
assert set(a_persistents.keys()) == set(b_persistents.keys())
for key in a_persistents:
a_persistent = a_persistents[key]
b_persistent = b_persistents[key]
assert isinstance(a_persistent, np.ndarray)
assert isinstance(b_persistent, np.ndarray)
assert a_persistent.ctypes.data == b_persistent.ctypes.data


def _assert_same_pointers_to_param_data(a, b):
assert isinstance(a, chainer.Link)
assert isinstance(b, chainer.Link)
a_params = dict(a.namedparams())
b_params = dict(b.namedparams())
assert set(a_params.keys()) == set(b_params.keys())
for key in a_params.keys():
assert isinstance(a_params[key], chainer.Variable)
assert isinstance(b_params[key], chainer.Variable)
assert (a_params[key].array.ctypes.data
== b_params[key].array.ctypes.data)


def _assert_different_pointers_to_param_grad(a, b):
assert isinstance(a, chainer.Link)
assert isinstance(b, chainer.Link)
a_params = dict(a.namedparams())
b_params = dict(b.namedparams())
assert set(a_params.keys()) == set(b_params.keys())
for key in a_params.keys():
assert isinstance(a_params[key], chainer.Variable)
assert isinstance(b_params[key], chainer.Variable)
assert (a_params[key].grad.ctypes.data
!= b_params[key].grad.ctypes.data)


class TestAsync(unittest.TestCase):

def setUp(self):
pass

def test_share_params(self):
def test_share_params_linear(self):

# A's params are shared with B and C so that all the three share the
# same parameter arrays

model_a = L.Linear(2, 2)

arrays = async_.share_params_as_shared_arrays(model_a)
assert isinstance(arrays, dict)
assert set(arrays.keys()) == {'/W', '/b'}

model_b = L.Linear(2, 2)
model_c = L.Linear(2, 2)

async_.set_shared_params(model_b, arrays)
async_.set_shared_params(model_c, arrays)

a_params = dict(model_a.namedparams())
b_params = dict(model_b.namedparams())
c_params = dict(model_c.namedparams())
# Pointers to parameters must be the same
_assert_same_pointers_to_param_data(model_a, model_b)
_assert_same_pointers_to_param_data(model_a, model_c)
# Pointers to gradients must be different
_assert_different_pointers_to_param_grad(model_a, model_b)
_assert_different_pointers_to_param_grad(model_a, model_c)
_assert_different_pointers_to_param_grad(model_b, model_c)
# Pointers to persistent values must be the same
_assert_same_pointers_to_persistent_values(model_a, model_b)
_assert_same_pointers_to_persistent_values(model_a, model_c)

def test_share_params_batch_normalization(self):

# A's params and persistent values are all shared with B and C

model_a = L.BatchNormalization(3)

arrays = async_.share_params_as_shared_arrays(model_a)
assert isinstance(arrays, dict)
assert set(arrays.keys()) == {
'/gamma', '/beta', '/avg_mean', '/avg_var', '/N'}

def assert_same_pointers_to_data(a, b):
self.assertEqual(a['/W'].array.ctypes.data,
b['/W'].array.ctypes.data)
self.assertEqual(a['/b'].array.ctypes.data,
b['/b'].array.ctypes.data)
model_b = L.BatchNormalization(3)
model_c = L.BatchNormalization(3)

def assert_different_pointers_to_grad(a, b):
self.assertNotEqual(a['/W'].grad.ctypes.data,
b['/W'].grad.ctypes.data)
self.assertNotEqual(a['/b'].grad.ctypes.data,
b['/b'].grad.ctypes.data)
async_.set_shared_params(model_b, arrays)
async_.set_shared_params(model_c, arrays)

# Pointers to parameters must be the same
_assert_same_pointers_to_param_data(model_a, model_b)
_assert_same_pointers_to_param_data(model_a, model_c)
# Pointers to gradients must be different
_assert_different_pointers_to_param_grad(model_a, model_b)
_assert_different_pointers_to_param_grad(model_a, model_c)
_assert_different_pointers_to_param_grad(model_b, model_c)
# Pointers to persistent values must be the same
_assert_same_pointers_to_persistent_values(model_a, model_b)
_assert_same_pointers_to_persistent_values(model_a, model_c)

# Check if N is shared correctly among links
assert model_a.N == 0
assert model_b.N == 0
assert model_c.N == 0
test_input = np.random.normal(size=(2, 3)).astype(np.float32)
model_a(test_input, finetune=True)
assert model_a.N == 1
assert model_b.N == 1
assert model_c.N == 1
model_c(test_input, finetune=True)
assert model_a.N == 2
assert model_b.N == 2
assert model_c.N == 2

def test_share_params_chain_list(self):

model_a = chainer.ChainList(
L.BatchNormalization(3),
chainer.ChainList(L.Linear(3, 5)),
)

arrays = async_.share_params_as_shared_arrays(model_a)
assert isinstance(arrays, dict)
assert set(arrays.keys()) == {
'/0/gamma', '/0/beta', '/0/avg_mean', '/0/avg_var', '/0/N',
'/1/0/W', '/1/0/b'}

model_b = chainer.ChainList(
L.BatchNormalization(3),
chainer.ChainList(L.Linear(3, 5)),
)
model_c = chainer.ChainList(
L.BatchNormalization(3),
chainer.ChainList(L.Linear(3, 5)),
)

async_.set_shared_params(model_b, arrays)
async_.set_shared_params(model_c, arrays)

# Pointers to parameters must be the same
assert_same_pointers_to_data(a_params, b_params)
assert_same_pointers_to_data(a_params, c_params)
_assert_same_pointers_to_param_data(model_a, model_b)
_assert_same_pointers_to_param_data(model_a, model_c)
# Pointers to gradients must be different
assert_different_pointers_to_grad(a_params, b_params)
assert_different_pointers_to_grad(a_params, c_params)
_assert_different_pointers_to_param_grad(model_a, model_b)
_assert_different_pointers_to_param_grad(model_a, model_c)
_assert_different_pointers_to_param_grad(model_b, model_c)
# Pointers to persistent values must be the same
_assert_same_pointers_to_persistent_values(model_a, model_b)
_assert_same_pointers_to_persistent_values(model_a, model_c)

def test_share_states(self):

Expand Down Expand Up @@ -114,10 +223,8 @@ def test_shared_link(self):
model_a = chainer.ChainList(head.copy(), L.Linear(2, 3))
model_b = chainer.ChainList(head.copy(), L.Linear(2, 4))

a_arrays = async_.extract_params_as_shared_arrays(
chainer.ChainList(model_a))
b_arrays = async_.extract_params_as_shared_arrays(
chainer.ChainList(model_b))
a_arrays = async_.extract_params_as_shared_arrays(model_a)
b_arrays = async_.extract_params_as_shared_arrays(model_b)

print(('model_a shared_arrays', a_arrays))
print(('model_b shared_arrays', b_arrays))
Expand Down
52 changes: 52 additions & 0 deletions tests/misc_tests/test_namedpersistent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA

import chainer
import numpy

import chainerrl


def test_namedpersistent():
# This test case is adopted from
# https://github.com/chainer/chainer/pull/6788

l1 = chainer.Link()
with l1.init_scope():
l1.x = chainer.Parameter(shape=(2, 3))

l2 = chainer.Link()
with l2.init_scope():
l2.x = chainer.Parameter(shape=2)
l2.add_persistent(
'l2_a', numpy.array([1, 2, 3], dtype=numpy.float32))

l3 = chainer.Link()
with l3.init_scope():
l3.x = chainer.Parameter()
l3.add_persistent(
'l3_a', numpy.array([1, 2, 3], dtype=numpy.float32))

c1 = chainer.Chain()
with c1.init_scope():
c1.l1 = l1
c1.add_link('l2', l2)
c1.add_persistent(
'c1_a', numpy.array([1, 2, 3], dtype=numpy.float32))

c2 = chainer.Chain()
with c2.init_scope():
c2.c1 = c1
c2.l3 = l3
c2.add_persistent(
'c2_a', numpy.array([1, 2, 3], dtype=numpy.float32))
namedpersistent = list(chainerrl.misc.namedpersistent(c2))
assert (
[(name, id(p)) for name, p in namedpersistent] ==
[('/c2_a', id(c2.c2_a)), ('/c1/c1_a', id(c2.c1.c1_a)),
('/c1/l2/l2_a', id(c2.c1.l2.l2_a)), ('/l3/l3_a', id(c2.l3.l3_a))])

0 comments on commit 4151087

Please sign in to comment.