diff --git a/submitit/core/job_environment.py b/submitit/core/job_environment.py
index 0f50502..89e044e 100644
--- a/submitit/core/job_environment.py
+++ b/submitit/core/job_environment.py
@@ -154,7 +154,10 @@ def _handle_signals(self, paths: JobPaths, submission: DelayedSubmission) -> Non
         # A priori we don't need other signals anymore,
         # but still log them to make it easier to debug.
         signal.signal(signal.SIGTERM, handler.bypass)
-        signal.signal(signal.SIGCONT, handler.bypass)
+        try:
+            signal.signal(signal.SIGCONT, handler.bypass)
+        except AttributeError:  # no SIGCONT on Windows
+            pass
 
     # pylint: disable=unused-argument
     def _requeue(self, countdown: int) -> None:
diff --git a/submitit/core/utils.py b/submitit/core/utils.py
index ab0bd01..7c6d4e9 100644
--- a/submitit/core/utils.py
+++ b/submitit/core/utils.py
@@ -16,6 +16,7 @@
 import tarfile
 import typing as tp
 from pathlib import Path
+from threading import Thread
 
 import cloudpickle
 
@@ -233,7 +234,6 @@ def cloudpickle_dump(obj: tp.Any, filename: tp.Union[str, Path]) -> None:
         cloudpickle.dump(obj, ofile, pickle.HIGHEST_PROTOCOL)
 
 
-# pylint: disable=too-many-locals
 def copy_process_streams(
     process: subprocess.Popen, stdout: io.StringIO, stderr: io.StringIO, verbose: bool = False
 ):
@@ -250,10 +250,56 @@ def raw(stream: tp.Optional[tp.IO[bytes]]) -> tp.IO[bytes]:
         return stream
 
     p_stdout, p_stderr = raw(process.stdout), raw(process.stderr)
-    stream_by_fd: tp.Dict[int, tp.Tuple[tp.IO[bytes], io.StringIO, tp.IO[str]]] = {
-        p_stdout.fileno(): (p_stdout, stdout, sys.stdout),
-        p_stderr.fileno(): (p_stderr, stderr, sys.stderr),
+    stream_by_fd: tp.Dict[int, tp.Tuple[tp.IO[bytes], io.StringIO, tp.Optional[tp.IO[str]]]] = {
+        p_stdout.fileno(): (p_stdout, stdout, sys.stdout if verbose else None),
+        p_stderr.fileno(): (p_stderr, stderr, sys.stderr if verbose else None),
     }
+
+    if os.name == "nt":
+        _copy_streams_threaded(stream_by_fd)
+    else:
+        _copy_streams_select_pipes(stream_by_fd)
+
+
+def _read_and_copy(p_stream: tp.IO[bytes], string: io.StringIO, std: tp.Optional[tp.IO[str]]) -> bool:
+    """
+    Returns False iff there is definitely no more to read.
+    """
+    raw_buf = p_stream.read(2**16)
+    if not raw_buf:
+        return False
+    buf = raw_buf.decode()
+    string.write(buf)
+    string.flush()
+    if std is not None:
+        std.write(buf)
+        std.flush()
+    return True
+
+
+def _read_and_copy_whole_stream_blocking(
+    p_stream: tp.IO[bytes], string: io.StringIO, std: tp.IO[str]
+) -> None:
+    while True:
+        if not _read_and_copy(p_stream, string, std):
+            return
+
+
+def _copy_streams_threaded(
+    stream_by_fd: tp.Dict[int, tp.Tuple[tp.IO[bytes], io.StringIO, tp.Optional[tp.IO[str]]]]
+) -> None:
+    threads: tp.List[Thread] = []
+    for p_stream, string, std in stream_by_fd.values():
+        t = Thread(target=_read_and_copy_whole_stream_blocking, args=(p_stream, string, std), daemon=True)
+        t.start()
+        threads.append(t)
+    for t in threads:
+        t.join()
+
+
+def _copy_streams_select_pipes(
+    stream_by_fd: tp.Dict[int, tp.Tuple[tp.IO[bytes], io.StringIO, tp.Optional[tp.IO[str]]]]
+) -> None:
     fds = list(stream_by_fd.keys())
     poller = select.poll()
     for fd in stream_by_fd:
@@ -263,17 +309,9 @@
         ready = poller.poll()
         for fd, _ in ready:
             p_stream, string, std = stream_by_fd[fd]
-            raw_buf = p_stream.read(2**16)
-            if not raw_buf:
+            if not _read_and_copy(p_stream, string, std):
                 fds.remove(fd)
                 poller.unregister(fd)
-                continue
-            buf = raw_buf.decode()
-            string.write(buf)
-            string.flush()
-            if verbose:
-                std.write(buf)
-                std.flush()
 
 
 # used in "_core", so cannot be in "helpers"