Moves CUDAPoolExecutor into rebar
andyljones committed Mar 2, 2021
1 parent 87897c3 commit adcc336
Showing 3 changed files with 59 additions and 34 deletions.
28 changes: 24 additions & 4 deletions boardlaw/arena/multi.py
@@ -4,6 +4,7 @@
import torch
from rebar import arrdict, dotdict
from logging import getLogger
+from itertools import combinations

log = getLogger(__name__)

@@ -92,7 +93,7 @@ def suggest(self, seats):

        return name, mask, self.live[mask]

-class Evaluator:
+class ChunkEvaluator:
    # Idea: keep lots and lots of envs in memory at once, play
    # every agent against every agent simultaneously

@@ -183,6 +184,26 @@ def step(self):
        results = self.record(transitions, live, start, end)
        return results

+def evaluate(worldfunc, agentfunc, games, n_envs_per, chunksize=64):
+    assert (games.index == games.columns).all()
+
+    names = list(games.index)
+    chunks = [names[i:i+chunksize] for i in range(0, len(names), chunksize)]
+
+    jobs = []
+    # Diagonal pieces
+    for chunk in chunks:
+        jobs.append(games.loc[chunk, chunk])
+
+    # Skew pieces
+    for first, second in combinations(chunks, 2):
+        combined = first + second
+        subgames = games.loc[combined, combined].copy()
+        subgames.loc[first, first] = n_envs_per
+        subgames.loc[second, second] = n_envs_per
+        jobs.append(subgames)
+    return jobs

class MockAgent:

    def __init__(self, id):
@@ -259,8 +280,7 @@ def test_tracker():
    assert len(counts) == len(agents)*(len(agents)-1)
    assert set(counts.values()) == {n_envs_per}


-def test_evaluator():
+def test_chunk_evaluator():
    from pavlov import runs, storage
    from boardlaw.arena import common

@@ -273,7 +293,7 @@ def test_evaluator():
    agents = {k: agents[k] for k in list(agents)[:100]}

    worldfunc = lambda n_envs: common.worlds(df.index[0], n_envs, device='cuda')
-    evaluator = Evaluator(worldfunc, agents, 512)
+    evaluator = ChunkEvaluator(worldfunc, agents, 512)

    from IPython import display
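For context on the new evaluate helper above: it tiles the all-pairs games matrix into chunk-sized jobs, one per diagonal block (agents within a chunk playing each other) plus one per pair of chunks for the cross-chunk games. A minimal sketch of that tiling on a toy matrix, with invented names and sizes; filling the within-chunk blocks with n_envs_per presumably marks them as already covered by the diagonal jobs:

# Illustrative sketch, not part of this commit: the tiling performed by
# evaluate(), on a toy 4-agent games matrix with chunksize=2.
import pandas as pd
from itertools import combinations

names = ['a', 'b', 'c', 'd']
games = pd.DataFrame(0, index=names, columns=names)
n_envs_per, chunksize = 512, 2

chunks = [names[i:i+chunksize] for i in range(0, len(names), chunksize)]
# chunks == [['a', 'b'], ['c', 'd']]

# Diagonal jobs: each chunk plays itself.
diagonal = [games.loc[c, c] for c in chunks]

# Skew jobs: each pair of chunks plays across, with the within-chunk
# blocks filled in so only the cross-chunk games remain to schedule.
skew = []
for first, second in combinations(chunks, 2):
    combined = first + second
    sub = games.loc[combined, combined].copy()
    sub.loc[first, first] = n_envs_per
    sub.loc[second, second] = n_envs_per
    skew.append(sub)

print(len(diagonal), len(skew))  # 2 diagonal jobs, 1 skew job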
25 changes: 2 additions & 23 deletions grid/eval.py
@@ -8,6 +8,7 @@
from geotorch.exceptions import InManifoldError
from logging import getLogger
from . import data, asymdata
+from rebar.parallel import CUDAPoolExecutor

log = getLogger(__name__)

@@ -43,28 +44,6 @@ def update(games, wins, results):
    wins.loc[result.names[1], result.names[0]] += result.wins[1]
    return games, wins

-class DeviceExecutor(ProcessPoolExecutor):
-    # Passes the index of the process to the init, so that we can balance CUDA jobs
-
-    def _adjust_process_count(self):
-        from concurrent.futures.process import _process_worker
-        for i in range(len(self._processes), self._max_workers):
-            p = self._mp_context.Process(
-                target=_process_worker,
-                args=(self._call_queue,
-                      self._result_queue,
-                      self._initializer,
-                      (*self._initargs, i)))
-            p.start()
-            self._processes[p.pid] = p
-
-
-def init(i):
-    import os
-    #TODO: Support variable number of GPUs
-    device = i % 2
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(device)

def solve(games, wins, soln=None):
    try:
        return activelo.solve(games, wins, soln=soln)
@@ -88,7 +67,7 @@ def activelo_eval(boardsize=9, n_workers=6):

    solver, soln, σ = None, None, None
    futures = {}
-    with DeviceExecutor(n_workers+1, initializer=init) as pool:
+    with CUDAPoolExecutor(n_workers+1) as pool:
        while True:
            if solver is None:
                log.info('Submitting solve task')
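With the local DeviceExecutor gone, grid/eval.py leans on the shared implementation from rebar. A minimal sketch of the pattern it relies on (assuming torch is installed and at least one GPU is visible; with zero GPUs the modulo in _device_init would fail):

# Illustrative sketch, not part of this commit: each worker process is
# pinned to a GPU before it runs any task, so tasks can just use device 0.
from rebar.parallel import CUDAPoolExecutor

def which_gpu(tag):
    import os
    return tag, os.environ['CUDA_VISIBLE_DEVICES']

if __name__ == '__main__':
    with CUDAPoolExecutor(4) as pool:
        futures = [pool.submit(which_gpu, i) for i in range(4)]
        # On a 2-GPU box, expect the workers split between '0' and '1'
        print([f.result() for f in futures])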
40 changes: 33 additions & 7 deletions rebar/parallel.py
@@ -20,25 +20,51 @@ def submit(self, f, *args, **kwargs):
        future.set_result(f(*args, **kwargs))
        return future

+class CUDAPoolExecutor(ProcessPoolExecutor):
+    # Passes the index of the process to the init, so that we can balance CUDA jobs
+
+    @staticmethod
+    def _device_init(i):
+        import os
+        import torch
+        device = i % torch.cuda.device_count()
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(device)
+
+    def _adjust_process_count(self):
+        assert self._initargs == (), 'Device executor doesn\'t currently support custom initializers'
+        from concurrent.futures.process import _process_worker
+        for i in range(len(self._processes), self._max_workers):
+            p = self._mp_context.Process(
+                target=_process_worker,
+                args=(self._call_queue,
+                      self._result_queue,
+                      self._device_init,
+                      (i,)))
+            p.start()
+            self._processes[p.pid] = p
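The balancing itself is just modular arithmetic over the worker index; for instance, on a 2-GPU machine _device_init would assign:

# Illustrative sketch, not part of this commit: worker index -> GPU id,
# as computed by _device_init with torch.cuda.device_count() == 2.
device_count = 2
for i in range(6):
    print(i, '->', i % device_count)  # 0->0, 1->1, 2->0, 3->1, 4->0, 5->1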

@contextmanager
-def VariableExecutor(N=None, processes=True, **kwargs):
+def VariableExecutor(N=None, executor='process', **kwargs):
"""An executor that can be easily switched between serial, thread and parallel execution.
If N=0, a serial executor will be used.
"""

    N = multiprocessing.cpu_count() if N is None else N

    if N == 0:
-        executor = SerialExecutor
-    elif processes:
-        executor = ProcessPoolExecutor
-    else:
-        executor = ThreadPoolExecutor
+        executor = 'serial'

+    executors = {
+        'serial': SerialExecutor,
+        'process': ProcessPoolExecutor,
+        'thread': ThreadPoolExecutor,
+        'cuda': CUDAPoolExecutor}
+    executor = executors[executor]

    log.debug('Launching a {} with {} processes'.format(executor.__name__, N))
    with executor(N, **kwargs) as pool:
        yield pool



@contextmanager
def parallel(f, progress=True, **kwargs):
"""Sugar for using the VariableExecutor. Call as
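With the executors now selected by name, switching backends is a one-word change. A minimal usage sketch of the new signature, with behaviour inferred from the code above:

# Illustrative sketch, not part of this commit.
from rebar.parallel import VariableExecutor

def square(x):
    return x * x

# N=0 falls back to the serial executor, which is handy for debugging.
with VariableExecutor(N=0) as pool:
    assert pool.submit(square, 3).result() == 9

# executor='thread' or 'cuda' selects the matching pool from the dict above.
with VariableExecutor(N=4, executor='thread') as pool:
    print([pool.submit(square, i).result() for i in range(8)])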
