use submitit for local execution
gwenzek committed May 15, 2020
1 parent 8b0810a commit 432fb71
Showing 2 changed files with 138 additions and 119 deletions.
253 changes: 137 additions & 116 deletions cc_net/execution.py
@@ -5,23 +5,34 @@
# LICENSE file in the root directory of this source tree.
#

import functools
import itertools
import logging
import multiprocessing
import os
import sys
import time
import warnings
from pathlib import Path
from typing import Callable, Dict, Iterable, Optional, Sized
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Sized

from typing_extensions import Protocol

import submitit


class Executor(Protocol):
def __call__(self, function: Callable[..., str], *args: Iterable) -> None:
...


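# Note: subclassing submitit.helpers.Checkpointable gives the wrapped callable
# a default checkpoint() method, so submitit can requeue the exact same call
# when a job is preempted or hits its time limit, hence "retry on timeout".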
class SubmititRetryOnTimeout(submitit.helpers.Checkpointable):
def __init__(self, fn: Callable):
self.fn = fn

def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)


def get_executor(
name: str,
log_dir: Path,
@@ -37,48 +48,13 @@ def get_executor(
options.update(
{kv.split("=", 1)[0]: kv.split("=", 1)[1] for kv in execution.split(",")[1:]}
)
if execution_mode == "slurm":
ex = get_submitit_executor(
name, log_dir, timeout_hour, mem_gb, cpus, task_parallelism, options
)
if ex is not None:
return ex

if execution_mode == "mp":
return MpExecutor(log_dir, cpus, task_parallelism)

return debug_executor


def get_submitit_executor(
name: str,
log_dir: Path,
timeout_hour: float,
mem_gb: int,
cpus: int,
task_parallelism: int,
options: dict,
) -> Optional[Executor]:
try:
import submitit

ex = submitit.AutoExecutor(log_dir)
except ImportError:
warnings.warn(f"Failed to import submitit, will try another executor.")
return None
except RuntimeError as e:
warnings.warn(
f"Failed to create submitit.AutoExecutor, will try another executor. ({e})"
)
return None

class SubmititRetryOnTimeout(submitit.helpers.Checkpointable):
def __init__(self, fn: Callable):
self.fn = fn

def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)
warnings.warn("Execution mode 'mp' is deprecated, use 'local'.")
execution_mode = "local"

cluster = None if execution_mode == "auto" else execution_mode
ex = submitit.AutoExecutor(log_dir, cluster=cluster)
ex.update_parameters(
name=name,
timeout_min=int(timeout_hour * 60),
@@ -87,43 +63,51 @@ def __call__(self, *args, **kwargs):
slurm_array_parallelism=task_parallelism,
**options,
)
if ex.cluster == "local":
# LocalExecutor doesn't respect task_parallelism
return functools.partial(custom_map_array, ex, task_parallelism)
if ex.cluster == "debug":
return debug_executor

return functools.partial(map_array_and_wait, ex)


def submit_and_wait(function: Callable[..., str], *args: Iterable):
f_name = function.__name__

assert len(args) > 0, f"No arguments passed to {f_name}"
approx_length = _approx_length(*args)

print(f"Submitting {f_name} in a job array ({approx_length} jobs)")
jobs = ex.map_array(function, *args)
if not jobs:
return
failed_jobs = []
done = 0
total = len(jobs)
job_array_id = jobs[0].job_id.split("_")[0]
print(f"Started {f_name} in job array {job_array_id} ({len(jobs)} jobs).")
for job in submitit.helpers.as_completed(jobs):
done += 1
e = job.exception()
if not e:
print(f"Finished job {job.job_id} ({done} / {total}).", job.result())
continue

print(f"Failed job {job.job_id} ({done} / {total}):", e)
failed_jobs.append(job)

if failed_jobs:
n_failures = 10
message = f"{len(failed_jobs)} / {done} jobs failed while running {f_name}"
print(message)
for job in failed_jobs[:n_failures]:
print(f"Failed {job.job_id} -> {job.paths.stderr}")
if len(failed_jobs) > n_failures:
print(f"... ({len(failed_jobs) - n_failures} failed job skipped)")
raise Exception(message)

return submit_and_wait
def map_array_and_wait(
ex: submitit.AutoExecutor, function: Callable[..., str], *args: Iterable
):
f_name = function.__name__

assert len(args) > 0, f"No arguments passed to {f_name}"
approx_length = _approx_length(*args)

print(f"Submitting {f_name} in a job array ({approx_length} jobs)")
jobs = ex.map_array(function, *args)
if not jobs:
return
failed_jobs = []
done = 0
total = len(jobs)
job_array_id = jobs[0].job_id.split("_")[0]
print(f"Started {f_name} in job array {job_array_id} ({len(jobs)} jobs).")
for job in submitit.helpers.as_completed(jobs):
done += 1
e = job.exception()
if not e:
print(f"Finished job {job.job_id} ({done} / {total}).", job.result())
continue

print(f"Failed job {job.job_id} ({done} / {total}):", e)
failed_jobs.append(job)

if failed_jobs:
n_failures = 10
message = f"{len(failed_jobs)} / {done} jobs failed while running {f_name}"
print(message)
for job in failed_jobs[:n_failures]:
print(f"Failed {job.job_id} -> {job.paths.stderr}")
if len(failed_jobs) > n_failures:
print(f"... ({len(failed_jobs) - n_failures} failed job skipped)")
raise Exception(message)


def debug_executor(function: Callable[..., Optional[str]], *args: Iterable) -> None:
@@ -154,42 +138,79 @@ def _approx_length(*args: Iterable):
return -1


GLOBAL_FUNCTIONS: Dict[str, Callable[..., Optional[str]]] = {}


def global_fn(args) -> Optional[str]:
f_name = args[0]
f = GLOBAL_FUNCTIONS[f_name]
return f(*args[1:])


class MpExecutor(Executor):
def __init__(self, log_dir: Path, cpus: int, task_parallelism: int):
self.log_dir = log_dir
if task_parallelism < 0:
task_parallelism = os.cpu_count() or 1
self.processes = min(task_parallelism // cpus, os.cpu_count())

def __call__(self, function: Callable[..., Optional[str]], *args: Iterable):

f_name = function.__name__
global GLOBAL_FUNCTIONS
if f_name in GLOBAL_FUNCTIONS:
assert (
function == GLOBAL_FUNCTIONS[f_name]
), f"Conflicting name between {function} and {GLOBAL_FUNCTIONS[f_name]}"
else:
GLOBAL_FUNCTIONS[f_name] = function

approx_length = _approx_length(*args)

print(
f"Starting {f_name} over {self.processes} processes ({approx_length} tasks)."
)
with multiprocessing.Pool(processes=self.processes) as pool:
i = 0
for message in pool.imap_unordered(
global_fn, zip(itertools.repeat(f_name), *args)
):
i += 1
print(message, f"({i} / {approx_length})")
def custom_map_array(
ex: submitit.AutoExecutor,
parallelism: int,
function: Callable[..., Optional[str]],
*args: Iterable,
) -> None:
f_name = function.__name__
assert len(args) > 0, f"No arguments passed to {f_name}"

jobs_args = list(zip(*args))
total = len(jobs_args)
print(f"Submitting {total} jobs for {f_name}, with parallelism={parallelism}")
enqueued = 0
done = 0
running_jobs: List[submitit.Job] = []
failed_jobs: List[submitit.Job] = []

while done < len(jobs_args):
# Try to queue more jobs if we have some bandwidth.
if enqueued < total and len(running_jobs) < parallelism:
running_jobs.append(ex.submit(function, *jobs_args[enqueued]))
enqueued += 1
continue

# Else wait for some job to finish
if not running_jobs:
warnings.warn(
f"No more running jobs, yet we submitted only {enqueued} / {total} and finished {done} / {total}"
)
break

job = get_next_job(running_jobs)
running_jobs.remove(job)
done += 1
e = job.exception()
if not e:
print(f"Finished job {job.job_id} ({done} / {total}).", job.result())
continue

print(f"Failed job {job.job_id} ({done} / {total}):", e)
failed_jobs.append(job)

if failed_jobs:
n_failures = 10
message = f"{len(failed_jobs)} / {done} jobs failed while running {f_name}"
print(message)
for job in failed_jobs[:n_failures]:
print(f"Failed {job.job_id} -> {job.paths.stderr}")
if len(failed_jobs) > n_failures:
print(f"... ({len(failed_jobs) - n_failures} failed job skipped)")
raise Exception(message)


def get_next_job(
jobs: Sequence[submitit.Job], poll_frequency: float = 10
) -> submitit.Job:
"""
Waits for any of the jobs to finish and returns it.
jobs: list of jobs to wait on
poll_frequency: frequency in seconds at which we check job status
"""
start = time.time()
waiting = False
while True:
for job in jobs:
if job.done():
return job
if not waiting:
job_ids = [j.job_id for j in jobs[:4]]
suffix = "..." if len(jobs) > 4 else ""
print(
f"Waiting on {len(jobs)} running jobs. Job ids: {','.join(job_ids)}{suffix}"
)
waiting = True
time.sleep(poll_frequency)
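
For reference, here is a minimal sketch of how the reworked entry point can be driven. The worker function, paths, and argument values are illustrative only (they are not part of this commit); the keyword names follow the signatures visible in the hunks above.

from pathlib import Path

from cc_net.execution import get_executor

def process_shard(shard: str) -> str:
    # Hypothetical worker: handle one input shard, return a status message.
    return f"done with {shard}"

# "local" now routes through submitit's LocalExecutor (throttled by
# custom_map_array), "slurm" submits a SLURM job array, "debug" runs the
# function inline, and "auto" lets submitit pick an available backend.
executor = get_executor(
    name="process_shard",
    log_dir=Path("logs"),
    execution="local",
    timeout_hour=1.0,
    mem_gb=1,
    cpus=1,
    task_parallelism=4,
)

# Per the Executor protocol, pass one iterable per worker argument.
executor(process_shard, ["shard_0", "shard_1", "shard_2"])
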
4 changes: 1 addition & 3 deletions setup.py
@@ -39,15 +39,13 @@
"func_argparse>=1.1.1",
"psutil>=5.6.3",
"sacremoses",
"submitit>=1.0.0",
"typing_extensions",
],
extras_require={
"dev": ["mypy>=0.730", "pytest", "black", "isort"],
# To use scripts inside cc_net/tools
"tools": ["lxml", "sentence_splitter"],
# TODO: include submitit by default and use local executor
# Allows to run on a SLURM cluster.
"slurm": ["submitit"],
# Memory-efficient hashset.
# This fork only compiles the kind of dict used by cc_net.
# Full version is at https://github.com/atom-moyer/getpy
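
With submitit now a hard requirement, a quick sanity check (illustrative only, not part of this commit) is to confirm the pinned dependency resolved:

import submitit

# install_requires now pins "submitit>=1.0.0"; a missing install fails at
# import time instead of falling back to another executor as before.
print(submitit.__version__)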
