From 9a26e2af6ef10ab9fcf1be433f514ded28dd4796 Mon Sep 17 00:00:00 2001 From: muupan Date: Mon, 17 Jun 2019 22:10:40 +0900 Subject: [PATCH 01/13] Share persistent values among processes so that BatchNormalization etc can be safely shared among processes --- chainerrl/misc/async_.py | 33 +++++++++- tests/misc_tests/test_async.py | 108 +++++++++++++++++++++++++++------ 2 files changed, 121 insertions(+), 20 deletions(-) diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index 317d20877..e2a84290b 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -33,7 +33,7 @@ def ensure_initialized_update_rule(param): def set_shared_params(a, b): - """Set shared params to a link. + """Set shared params (and persistent values) to a link. Args: a (chainer.Link): link whose params are to be replaced @@ -45,6 +45,22 @@ def set_shared_params(a, b): shared_param = b[param_name] param.array = np.frombuffer( shared_param, dtype=param.dtype).reshape(param.shape) + for persistent_name in a._persistent: + persistent = a.__dict__[persistent_name] + if persistent_name in b: + shared_param = b[persistent_name] + if isinstance(persistent, np.ndarray): + a.__dict__[persistent_name] = np.frombuffer( + shared_param, dtype=persistent.dtype).reshape( + persistent.shape) + else: + assert np.isscalar(persistent) + # We wrap scalars with np.ndarray because + # multiprocessing.RawValue cannot be used as a scalar, while + # np.ndarray can be. + typecode = np.asarray(persistent).dtype.char + a.__dict__[persistent_name] = np.frombuffer( + shared_param, dtype=typecode).reshape(()) def make_params_not_shared(a): @@ -85,7 +101,20 @@ def extract_params_as_shared_arrays(link): assert isinstance(link, chainer.Link) shared_arrays = {} for param_name, param in link.namedparams(): - shared_arrays[param_name] = mp.RawArray('f', param.array.ravel()) + typecode = param.array.dtype.char + shared_arrays[param_name] = mp.RawArray(typecode, param.array.ravel()) + for persistent_name in link._persistent: + persistent = link.__dict__[persistent_name] + if isinstance(persistent, np.ndarray): + typecode = persistent.dtype.char + shared_arrays[persistent_name] = mp.RawArray( + typecode, persistent.ravel()) + else: + assert np.isscalar(persistent) + persistent_as_array = np.asarray([persistent]) + typecode = persistent_as_array.dtype.char + shared_arrays[persistent_name] = mp.RawArray( + typecode, persistent_as_array) return shared_arrays diff --git a/tests/misc_tests/test_async.py b/tests/misc_tests/test_async.py index 6e52de85d..af533cb45 100644 --- a/tests/misc_tests/test_async.py +++ b/tests/misc_tests/test_async.py @@ -22,12 +22,52 @@ from chainerrl.misc import async_ +def _assert_same_pointers_to_persistent_values(a, b): + assert isinstance(a, chainer.Link) + assert isinstance(b, chainer.Link) + a_persistent_names = set(a._persistent) + b_persistent_names = set(b._persistent) + assert a_persistent_names == b_persistent_names + for key in a_persistent_names: + a_persistent = a.__dict__[key] + b_persistent = b.__dict__[key] + assert isinstance(a_persistent, np.ndarray) + assert isinstance(b_persistent, np.ndarray) + assert a_persistent.ctypes.data == b_persistent.ctypes.data + + +def _assert_same_pointers_to_param_data(a, b): + assert isinstance(a, chainer.Link) + assert isinstance(b, chainer.Link) + a_params = dict(a.namedparams()) + b_params = dict(b.namedparams()) + assert set(a_params.keys()) == set(b_params.keys()) + for key in a_params.keys(): + assert isinstance(a_params[key], chainer.Variable) + assert isinstance(b_params[key], chainer.Variable) + assert (a_params[key].array.ctypes.data + == b_params[key].array.ctypes.data) + + +def _assert_different_pointers_to_param_grad(a, b): + assert isinstance(a, chainer.Link) + assert isinstance(b, chainer.Link) + a_params = dict(a.namedparams()) + b_params = dict(b.namedparams()) + assert set(a_params.keys()) == set(b_params.keys()) + for key in a_params.keys(): + assert isinstance(a_params[key], chainer.Variable) + assert isinstance(b_params[key], chainer.Variable) + assert (a_params[key].grad.ctypes.data + != b_params[key].grad.ctypes.data) + + class TestAsync(unittest.TestCase): def setUp(self): pass - def test_share_params(self): + def test_share_params_linear(self): # A's params are shared with B and C so that all the three share the # same parameter arrays @@ -35,6 +75,8 @@ def test_share_params(self): model_a = L.Linear(2, 2) arrays = async_.share_params_as_shared_arrays(model_a) + assert isinstance(arrays, dict) + assert set(arrays.keys()) == {'/W', '/b'} model_b = L.Linear(2, 2) model_c = L.Linear(2, 2) @@ -42,28 +84,58 @@ def test_share_params(self): async_.set_shared_params(model_b, arrays) async_.set_shared_params(model_c, arrays) - a_params = dict(model_a.namedparams()) - b_params = dict(model_b.namedparams()) - c_params = dict(model_c.namedparams()) + # Pointers to parameters must be the same + _assert_same_pointers_to_param_data(model_a, model_b) + _assert_same_pointers_to_param_data(model_a, model_c) + # Pointers to gradients must be different + _assert_different_pointers_to_param_grad(model_a, model_b) + _assert_different_pointers_to_param_grad(model_a, model_c) + _assert_different_pointers_to_param_grad(model_b, model_c) + # Pointers to persistent values must be the same + _assert_same_pointers_to_persistent_values(model_a, model_b) + _assert_same_pointers_to_persistent_values(model_a, model_c) + + def test_share_params_batch_normalization(self): - def assert_same_pointers_to_data(a, b): - self.assertEqual(a['/W'].array.ctypes.data, - b['/W'].array.ctypes.data) - self.assertEqual(a['/b'].array.ctypes.data, - b['/b'].array.ctypes.data) + # A's params and persistent values are all shared with B and C - def assert_different_pointers_to_grad(a, b): - self.assertNotEqual(a['/W'].grad.ctypes.data, - b['/W'].grad.ctypes.data) - self.assertNotEqual(a['/b'].grad.ctypes.data, - b['/b'].grad.ctypes.data) + model_a = L.BatchNormalization(3) + + arrays = async_.share_params_as_shared_arrays(model_a) + assert isinstance(arrays, dict) + assert set(arrays.keys()) == { + '/gamma', '/beta', 'avg_mean', 'avg_var', 'N'} + + model_b = L.BatchNormalization(3) + model_c = L.BatchNormalization(3) + + async_.set_shared_params(model_b, arrays) + async_.set_shared_params(model_c, arrays) # Pointers to parameters must be the same - assert_same_pointers_to_data(a_params, b_params) - assert_same_pointers_to_data(a_params, c_params) + _assert_same_pointers_to_param_data(model_a, model_b) + _assert_same_pointers_to_param_data(model_a, model_c) # Pointers to gradients must be different - assert_different_pointers_to_grad(a_params, b_params) - assert_different_pointers_to_grad(a_params, c_params) + _assert_different_pointers_to_param_grad(model_a, model_b) + _assert_different_pointers_to_param_grad(model_a, model_c) + _assert_different_pointers_to_param_grad(model_b, model_c) + # Pointers to persistent values must be the same + _assert_same_pointers_to_persistent_values(model_a, model_b) + _assert_same_pointers_to_persistent_values(model_a, model_c) + + # Check if N is shared correctly among links + assert model_a.N == 0 + assert model_b.N == 0 + assert model_c.N == 0 + test_input = np.random.normal(size=(2, 3)).astype(np.float32) + model_a(test_input, finetune=True) + assert model_a.N == 1 + assert model_b.N == 1 + assert model_c.N == 1 + model_c(test_input, finetune=True) + assert model_a.N == 2 + assert model_b.N == 2 + assert model_c.N == 2 def test_share_states(self): From 5c120ba4a25baadde43f8fd513f49f7f6a3444ca Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 19 Jun 2019 15:43:29 +0900 Subject: [PATCH 02/13] Add namedpersistent function --- chainerrl/misc/namedpersistent.py | 37 +++++++++++++++++ tests/misc_tests/test_namedpersistent.py | 52 ++++++++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 chainerrl/misc/namedpersistent.py create mode 100644 tests/misc_tests/test_namedpersistent.py diff --git a/chainerrl/misc/namedpersistent.py b/chainerrl/misc/namedpersistent.py new file mode 100644 index 000000000..237a3ffff --- /dev/null +++ b/chainerrl/misc/namedpersistent.py @@ -0,0 +1,37 @@ +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +from builtins import * # NOQA +from future import standard_library +standard_library.install_aliases() # NOQA + +import chainer + + +def namedpersistent(link): + """Return a generator of all (path, persistent) pairs for a given link. + + This function is adopted from https://github.com/chainer/chainer/pull/6788. + Once it is merged into Chainer, we should use the property instead. + + Args: + link (chainer.Link): Link. + + Returns: + A generator object that generates all (path, persistent) pairs. + The paths are relative from this link. + """ + d = link.__dict__ + for name in sorted(link._persistent): + yield '/' + name, d[name] + if isinstance(link, chainer.Chain): + for name in sorted(link._children): + prefix = '/' + name + for path, persistent in namedpersistent(d[name]): + yield prefix + path, persistent + elif isinstance(link, chainer.ChainList): + for idx, link in enumerate(link._children): + prefix = '/{}'.format(idx) + for path, persistent in namedpersistent(link): + yield prefix + path, persistent diff --git a/tests/misc_tests/test_namedpersistent.py b/tests/misc_tests/test_namedpersistent.py new file mode 100644 index 000000000..b669a398e --- /dev/null +++ b/tests/misc_tests/test_namedpersistent.py @@ -0,0 +1,52 @@ +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +from builtins import * # NOQA +from future import standard_library +standard_library.install_aliases() # NOQA + +import chainer +import numpy + +import chainerrl + + +def test_namedpersistent(): + # This test case is adopted from + # https://github.com/chainer/chainer/pull/6788 + + l1 = chainer.Link() + with l1.init_scope(): + l1.x = chainer.Parameter(shape=(2, 3)) + + l2 = chainer.Link() + with l2.init_scope(): + l2.x = chainer.Parameter(shape=2) + l2.add_persistent( + 'l2_a', numpy.array([1, 2, 3], dtype=numpy.float32)) + + l3 = chainer.Link() + with l3.init_scope(): + l3.x = chainer.Parameter() + l3.add_persistent( + 'l3_a', numpy.array([1, 2, 3], dtype=numpy.float32)) + + c1 = chainer.Chain() + with c1.init_scope(): + c1.l1 = l1 + c1.add_link('l2', l2) + c1.add_persistent( + 'c1_a', numpy.array([1, 2, 3], dtype=numpy.float32)) + + c2 = chainer.Chain() + with c2.init_scope(): + c2.c1 = c1 + c2.l3 = l3 + c2.add_persistent( + 'c2_a', numpy.array([1, 2, 3], dtype=numpy.float32)) + namedpersistent = list(chainerrl.misc.namedpersistent(c2)) + assert ( + [(name, id(p)) for name, p in namedpersistent] == + [('/c2_a', id(c2.c2_a)), ('/c1/c1_a', id(c2.c1.c1_a)), + ('/c1/l2/l2_a', id(c2.c1.l2.l2_a)), ('/l3/l3_a', id(c2.l3.l3_a))]) From fbafd1a7f1bda70707ba1bc29b4c9aeb599bdd95 Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 19 Jun 2019 17:08:39 +0900 Subject: [PATCH 03/13] Support persistent values of child links --- chainerrl/misc/async_.py | 54 ++++++++++++++++++++++------------ tests/misc_tests/test_async.py | 51 +++++++++++++++++++++++++++----- 2 files changed, 80 insertions(+), 25 deletions(-) diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index e2a84290b..b57df0996 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -12,7 +12,7 @@ import chainer import numpy as np -from chainerrl.misc import random_seed +import chainerrl class AbnormalExitWarning(Warning): @@ -32,6 +32,36 @@ def ensure_initialized_update_rule(param): u.init_state(param) +def _set_persistent_values_recursively(link, persistent_name, shared_array): + if persistent_name.startswith('/'): + persistent_name = persistent_name[1:] + if hasattr(link, persistent_name): + attr_name = persistent_name + attr = getattr(link, attr_name) + if isinstance(attr, np.ndarray): + setattr(link, persistent_name, np.frombuffer( + shared_array, dtype=attr.dtype).reshape(attr.shape)) + else: + assert np.isscalar(attr) + # We wrap scalars with np.ndarray because + # multiprocessing.RawValue cannot be used as a scalar, while + # np.ndarray can be. + typecode = np.asarray(attr).dtype.char + setattr(link, attr_name, np.frombuffer( + shared_array, dtype=typecode).reshape(())) + else: + assert isinstance(link, (chainer.Chain, chainer.ChainList)) + assert '/' in persistent_name + if isinstance(link, chainer.Chain): + child_name, remaining = persistent_name.split('/') + _set_persistent_values_recursively( + getattr(link, child_name), remaining, shared_array) + else: + child_idx, remaining = persistent_name.split('/') + _set_persistent_values_recursively( + link._children[int(child_idx)], remaining, shared_array) + + def set_shared_params(a, b): """Set shared params (and persistent values) to a link. @@ -45,22 +75,10 @@ def set_shared_params(a, b): shared_param = b[param_name] param.array = np.frombuffer( shared_param, dtype=param.dtype).reshape(param.shape) - for persistent_name in a._persistent: - persistent = a.__dict__[persistent_name] + for persistent_name, _ in chainerrl.misc.namedpersistent(a): if persistent_name in b: - shared_param = b[persistent_name] - if isinstance(persistent, np.ndarray): - a.__dict__[persistent_name] = np.frombuffer( - shared_param, dtype=persistent.dtype).reshape( - persistent.shape) - else: - assert np.isscalar(persistent) - # We wrap scalars with np.ndarray because - # multiprocessing.RawValue cannot be used as a scalar, while - # np.ndarray can be. - typecode = np.asarray(persistent).dtype.char - a.__dict__[persistent_name] = np.frombuffer( - shared_param, dtype=typecode).reshape(()) + _set_persistent_values_recursively( + a, persistent_name, b[persistent_name]) def make_params_not_shared(a): @@ -103,8 +121,8 @@ def extract_params_as_shared_arrays(link): for param_name, param in link.namedparams(): typecode = param.array.dtype.char shared_arrays[param_name] = mp.RawArray(typecode, param.array.ravel()) - for persistent_name in link._persistent: - persistent = link.__dict__[persistent_name] + + for persistent_name, persistent in chainerrl.misc.namedpersistent(link): if isinstance(persistent, np.ndarray): typecode = persistent.dtype.char shared_arrays[persistent_name] = mp.RawArray( diff --git a/tests/misc_tests/test_async.py b/tests/misc_tests/test_async.py index af533cb45..48b863f26 100644 --- a/tests/misc_tests/test_async.py +++ b/tests/misc_tests/test_async.py @@ -19,18 +19,19 @@ import copy import numpy as np +import chainerrl from chainerrl.misc import async_ def _assert_same_pointers_to_persistent_values(a, b): assert isinstance(a, chainer.Link) assert isinstance(b, chainer.Link) - a_persistent_names = set(a._persistent) - b_persistent_names = set(b._persistent) - assert a_persistent_names == b_persistent_names - for key in a_persistent_names: - a_persistent = a.__dict__[key] - b_persistent = b.__dict__[key] + a_persistents = dict(chainerrl.misc.namedpersistent(a)) + b_persistents = dict(chainerrl.misc.namedpersistent(b)) + assert set(a_persistents.keys()) == set(b_persistents.keys()) + for key in a_persistents: + a_persistent = a_persistents[key] + b_persistent = b_persistents[key] assert isinstance(a_persistent, np.ndarray) assert isinstance(b_persistent, np.ndarray) assert a_persistent.ctypes.data == b_persistent.ctypes.data @@ -104,7 +105,7 @@ def test_share_params_batch_normalization(self): arrays = async_.share_params_as_shared_arrays(model_a) assert isinstance(arrays, dict) assert set(arrays.keys()) == { - '/gamma', '/beta', 'avg_mean', 'avg_var', 'N'} + '/gamma', '/beta', '/avg_mean', '/avg_var', '/N'} model_b = L.BatchNormalization(3) model_c = L.BatchNormalization(3) @@ -137,6 +138,42 @@ def test_share_params_batch_normalization(self): assert model_b.N == 2 assert model_c.N == 2 + def test_share_params_chain_list(self): + + model_a = chainer.ChainList( + L.BatchNormalization(3), + L.Linear(3, 5), + ) + + arrays = async_.share_params_as_shared_arrays(model_a) + assert isinstance(arrays, dict) + assert set(arrays.keys()) == { + '/0/gamma', '/0/beta', '/0/avg_mean', '/0/avg_var', '/0/N', + '/1/W', '/1/b'} + + model_b = chainer.ChainList( + L.BatchNormalization(3), + L.Linear(3, 5), + ) + model_c = chainer.ChainList( + L.BatchNormalization(3), + L.Linear(3, 5), + ) + + async_.set_shared_params(model_b, arrays) + async_.set_shared_params(model_c, arrays) + + # Pointers to parameters must be the same + _assert_same_pointers_to_param_data(model_a, model_b) + _assert_same_pointers_to_param_data(model_a, model_c) + # Pointers to gradients must be different + _assert_different_pointers_to_param_grad(model_a, model_b) + _assert_different_pointers_to_param_grad(model_a, model_c) + _assert_different_pointers_to_param_grad(model_b, model_c) + # Pointers to persistent values must be the same + _assert_same_pointers_to_persistent_values(model_a, model_b) + _assert_same_pointers_to_persistent_values(model_a, model_c) + def test_share_states(self): model = L.Linear(2, 2) From ae1efd1b3478926c17be4f5d913d15919f0f6661 Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 19 Jun 2019 17:13:35 +0900 Subject: [PATCH 04/13] Add namedpersistent under chainerrl.misc --- chainerrl/misc/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chainerrl/misc/__init__.py b/chainerrl/misc/__init__.py index 1219bee5e..c3c1266bb 100644 --- a/chainerrl/misc/__init__.py +++ b/chainerrl/misc/__init__.py @@ -4,5 +4,6 @@ from chainerrl.misc.draw_computational_graph import draw_computational_graph # NOQA from chainerrl.misc.draw_computational_graph import is_graphviz_available # NOQA from chainerrl.misc import env_modifiers # NOQA +from chainerrl.misc.namedpersistent import namedpersistent # NOQA from chainerrl.misc.is_return_code_zero import is_return_code_zero # NOQA from chainerrl.misc.random_seed import set_random_seed # NOQA From 70a2df127b5c9b23c272945ce7fc520a5eb51401 Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 19 Jun 2019 17:13:57 +0900 Subject: [PATCH 05/13] Restore import --- chainerrl/misc/async_.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index b57df0996..f492b3996 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -13,6 +13,7 @@ import numpy as np import chainerrl +from chainerrl.misc import random_seed class AbnormalExitWarning(Warning): From 03f46ccea346e193d8515abb4da38202e05a6efb Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 19 Jun 2019 17:34:13 +0900 Subject: [PATCH 06/13] Check if all the keys are consumed --- chainerrl/misc/async_.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index f492b3996..cb43fcbf7 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -71,15 +71,19 @@ def set_shared_params(a, b): b (dict): dict that consists of (param_name, multiprocessing.Array) """ assert isinstance(a, chainer.Link) + remaining_keys = set(b.keys()) for param_name, param in a.namedparams(): if param_name in b: shared_param = b[param_name] param.array = np.frombuffer( shared_param, dtype=param.dtype).reshape(param.shape) + remaining_keys.remove(param_name) for persistent_name, _ in chainerrl.misc.namedpersistent(a): if persistent_name in b: _set_persistent_values_recursively( a, persistent_name, b[persistent_name]) + remaining_keys.remove(persistent_name) + assert not remaining_keys def make_params_not_shared(a): From 2b9066d005948a9418b94bb7dd3204c45f2642c1 Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 19 Jun 2019 17:35:38 +0900 Subject: [PATCH 07/13] Fix error in test --- tests/misc_tests/test_async.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/misc_tests/test_async.py b/tests/misc_tests/test_async.py index 48b863f26..6b260e910 100644 --- a/tests/misc_tests/test_async.py +++ b/tests/misc_tests/test_async.py @@ -223,10 +223,8 @@ def test_shared_link(self): model_a = chainer.ChainList(head.copy(), L.Linear(2, 3)) model_b = chainer.ChainList(head.copy(), L.Linear(2, 4)) - a_arrays = async_.extract_params_as_shared_arrays( - chainer.ChainList(model_a)) - b_arrays = async_.extract_params_as_shared_arrays( - chainer.ChainList(model_b)) + a_arrays = async_.extract_params_as_shared_arrays(model_a) + b_arrays = async_.extract_params_as_shared_arrays(model_b) print(('model_a shared_arrays', a_arrays)) print(('model_b shared_arrays', b_arrays)) From ae87f4170312c662a6d794cda06a2665024bf622 Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 19 Jun 2019 18:20:42 +0900 Subject: [PATCH 08/13] Fix error of splitting persistent_name --- chainerrl/misc/async_.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index cb43fcbf7..b244534c3 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -53,14 +53,13 @@ def _set_persistent_values_recursively(link, persistent_name, shared_array): else: assert isinstance(link, (chainer.Chain, chainer.ChainList)) assert '/' in persistent_name + child_name, remaining = persistent_name.split('/', maxsplit=1) if isinstance(link, chainer.Chain): - child_name, remaining = persistent_name.split('/') _set_persistent_values_recursively( getattr(link, child_name), remaining, shared_array) else: - child_idx, remaining = persistent_name.split('/') _set_persistent_values_recursively( - link._children[int(child_idx)], remaining, shared_array) + link._children[int(child_name)], remaining, shared_array) def set_shared_params(a, b): From ae4b6be2b5f11df4a9181d441aff48bfbc2506cf Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 19 Jun 2019 18:21:48 +0900 Subject: [PATCH 09/13] Test with a deeper chain structure --- tests/misc_tests/test_async.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/misc_tests/test_async.py b/tests/misc_tests/test_async.py index 6b260e910..48fe8d644 100644 --- a/tests/misc_tests/test_async.py +++ b/tests/misc_tests/test_async.py @@ -142,22 +142,22 @@ def test_share_params_chain_list(self): model_a = chainer.ChainList( L.BatchNormalization(3), - L.Linear(3, 5), + chainer.ChainList(L.Linear(3, 5)), ) arrays = async_.share_params_as_shared_arrays(model_a) assert isinstance(arrays, dict) assert set(arrays.keys()) == { '/0/gamma', '/0/beta', '/0/avg_mean', '/0/avg_var', '/0/N', - '/1/W', '/1/b'} + '/1/0/W', '/1/0/b'} model_b = chainer.ChainList( L.BatchNormalization(3), - L.Linear(3, 5), + chainer.ChainList(L.Linear(3, 5)), ) model_c = chainer.ChainList( L.BatchNormalization(3), - L.Linear(3, 5), + chainer.ChainList(L.Linear(3, 5)), ) async_.set_shared_params(model_b, arrays) From 5755ccb11ba1daae3cc92f0a0ab8faf0f904ceef Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 6 Nov 2019 21:12:06 +0900 Subject: [PATCH 10/13] Fix the no keyword argument error of python 2 --- chainerrl/misc/async_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index b244534c3..922b46744 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -53,7 +53,7 @@ def _set_persistent_values_recursively(link, persistent_name, shared_array): else: assert isinstance(link, (chainer.Chain, chainer.ChainList)) assert '/' in persistent_name - child_name, remaining = persistent_name.split('/', maxsplit=1) + child_name, remaining = persistent_name.split('/', 1) if isinstance(link, chainer.Chain): _set_persistent_values_recursively( getattr(link, child_name), remaining, shared_array) From 5256c937989bb408527c489755bed5d6f0707159 Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 6 Nov 2019 22:01:57 +0900 Subject: [PATCH 11/13] Explain why a 1-dim array is used --- chainerrl/misc/async_.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index 922b46744..ffcf98e43 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -133,6 +133,8 @@ def extract_params_as_shared_arrays(link): typecode, persistent.ravel()) else: assert np.isscalar(persistent) + # Wrap by a 1-dim array because multiprocessing.RawArray does not + # accept a 0-dim array. persistent_as_array = np.asarray([persistent]) typecode = persistent_as_array.dtype.char shared_arrays[persistent_name] = mp.RawArray( From 4aaf7664f2a422089696c6741e43531996c6e508 Mon Sep 17 00:00:00 2001 From: Yasuhiro Fujita Date: Wed, 6 Nov 2019 23:00:16 +0900 Subject: [PATCH 12/13] Update chainerrl/misc/async_.py Co-Authored-By: Toshiki Kataoka --- chainerrl/misc/async_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index ffcf98e43..605e47bb0 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -59,7 +59,7 @@ def _set_persistent_values_recursively(link, persistent_name, shared_array): getattr(link, child_name), remaining, shared_array) else: _set_persistent_values_recursively( - link._children[int(child_name)], remaining, shared_array) + link[int(child_name)], remaining, shared_array) def set_shared_params(a, b): From b82b7104b4460302c2cb42a44239bfe20f1d4073 Mon Sep 17 00:00:00 2001 From: muupan Date: Thu, 7 Nov 2019 17:57:38 +0900 Subject: [PATCH 13/13] Simplify namedpersistent by _namedchildren --- chainerrl/misc/namedpersistent.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/chainerrl/misc/namedpersistent.py b/chainerrl/misc/namedpersistent.py index 237a3ffff..1439a2d26 100644 --- a/chainerrl/misc/namedpersistent.py +++ b/chainerrl/misc/namedpersistent.py @@ -9,6 +9,15 @@ import chainer +def _namedchildren(link): + if isinstance(link, chainer.Chain): + for name in sorted(link._children): + yield name, link.__dict__[name] + elif isinstance(link, chainer.ChainList): + for idx, child in enumerate(link._children): + yield str(idx), child + + def namedpersistent(link): """Return a generator of all (path, persistent) pairs for a given link. @@ -25,13 +34,7 @@ def namedpersistent(link): d = link.__dict__ for name in sorted(link._persistent): yield '/' + name, d[name] - if isinstance(link, chainer.Chain): - for name in sorted(link._children): - prefix = '/' + name - for path, persistent in namedpersistent(d[name]): - yield prefix + path, persistent - elif isinstance(link, chainer.ChainList): - for idx, link in enumerate(link._children): - prefix = '/{}'.format(idx) - for path, persistent in namedpersistent(link): - yield prefix + path, persistent + for name, child in _namedchildren(link): + prefix = '/' + name + for path, persistent in namedpersistent(child): + yield prefix + path, persistent