From e80ac32271bf0b8323b278129e0cbb895a92f46c Mon Sep 17 00:00:00 2001 From: Christopher Hesse <48501609+cshesse@users.noreply.github.com> Date: Sat, 1 Feb 2020 13:07:35 -0800 Subject: [PATCH] typeguard can't handle from __future__ import annotations --- blobfile/ops.py | 124 +++++++++++++++++++++++++++++++------------ blobfile/ops_test.py | 7 ++- env.yaml | 1 - run-tests.py | 2 +- setup.py | 4 +- 5 files changed, 97 insertions(+), 41 deletions(-) diff --git a/blobfile/ops.py b/blobfile/ops.py index 4a7d4b2..34d7891 100644 --- a/blobfile/ops.py +++ b/blobfile/ops.py @@ -1,3 +1,6 @@ +# https://mypy.readthedocs.io/en/stable/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime +from __future__ import annotations + import calendar import os import tempfile @@ -31,6 +34,7 @@ cast, NamedTuple, List, + Union, ) from typing_extensions import Literal, Protocol, runtime_checkable @@ -713,7 +717,55 @@ def _glob_full(pattern: str) -> Iterator[str]: yield _strip_slash(path) -def glob(pattern: str) -> Iterator[str]: +class _GlobTask(NamedTuple): + cur: str + rem: Sequence[str] + +class _GlobEntry(NamedTuple): + path: str + +class _GlobTaskComplete(NamedTuple): + pass + +def _process_glob_task(root: str, t:_GlobTask) -> Iterator[Union[_GlobTask, _GlobEntry]]: + cur = t.cur + t.rem[0] + rem = t.rem[1:] + if "**" in cur: + for path in _glob_full(root + cur + "".join(rem)): + yield _GlobEntry(path) + elif "*" in cur: + re_pattern = _compile_pattern(root + cur) + prefix, _, _ = cur.partition("*") + path = root + prefix + for blobpath in _list_blobs(path=path, delimiter="/"): + # in the case of dirname/* we should not return the path dirname/ + if blobpath == path and blobpath.endswith("/"): + # we matched the parent directory + continue + if bool(re_pattern.match(blobpath)): + if len(rem) == 0: + yield _GlobEntry(_strip_slash(blobpath)) + else: + assert path.startswith(root) + yield _GlobTask(blobpath[len(root):], rem) + else: + if len(rem) == 0: + path = root+cur + if exists(path): + yield _GlobEntry(_strip_slash(path)) + else: + yield _GlobTask(cur, rem) + + +def _glob_worker(root: str, tasks: mp.Queue[_GlobTask], results: mp.Queue[Union[_GlobEntry, _GlobTask, _GlobTaskComplete]]) -> None: + while True: + t = tasks.get() + for r in _process_glob_task(root=root, t=t): + results.put(r) + results.put(_GlobTaskComplete()) + + +def glob(pattern: str, parallel: bool = False) -> Iterator[str]: """ Find files and directories matching a pattern. Supports * and ** @@ -721,6 +773,9 @@ def glob(pattern: str) -> Iterator[str]: that is not quite the same as remote paths. See https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames#different-behavior-for-dot-files-in-local-file-system_1 for more information. Globs can have confusing performance, see https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames#different-behavior-for-dot-files-in-local-file-system_1 for more information. + + You can set `parallel=True` to use multiple processes to perform the glob. It's likely + that the results will no longer be in order if that happens. """ assert "?" not in pattern and "[" not in pattern and "]" not in pattern @@ -744,37 +799,38 @@ def glob(pattern: str) -> Iterator[str]: assert "*" not in account and "*" not in container root = f"as://{account}-{container}/" - stack = [] - stack.append(("", _split_path(blob_prefix))) - while len(stack) > 0: - cur, rem = stack.pop() - part = rem[0] - if "**" in part: - yield from _glob_full(root + cur + "".join(rem)) - elif "*" in part: - re_pattern = _compile_pattern(root + cur + part) - prefix, _, _ = part.partition("*") - path = root + cur + prefix - for blobpath in _list_blobs(path=path, delimiter="/"): - # in the case of dirname/* we should not return the path dirname/ - if blobpath == path and blobpath.endswith("/"): - # we matched the parent directory - continue - if bool(re_pattern.match(blobpath)): - if len(rem) == 1: - yield _strip_slash(blobpath) - else: - assert path.startswith(root) - stack.append((blobpath[len(root):], rem[1:])) - else: - cur += part - if len(rem) == 1: - path = root+cur - if exists(path): - yield _strip_slash(path) - else: - stack.append((cur, rem[1:])) - return + initial_task = _GlobTask("", _split_path(blob_prefix)) + + if parallel: + tasks = mp.Queue() + tasks.put(initial_task) + tasks_enqueued = 1 + results = mp.Queue() + + tasks_done = 0 + with mp.Pool(initializer=_glob_worker, initargs=(root, tasks, results) + ): + while tasks_done < tasks_enqueued: + r = results.get() + if isinstance(r, _GlobEntry): + yield r.path + elif isinstance(r, _GlobTask): + tasks.put(r) + tasks_enqueued += 1 + elif isinstance(r, _GlobTaskComplete): + tasks_done += 1 + else: + raise Exception("invalid result") + else: + dq: collections.deque[_GlobTask] = collections.deque() + dq.append(initial_task) + while len(dq) > 0: + t = dq.popleft() + for r in _process_glob_task(root=root, t=t): + if isinstance(r, _GlobEntry): + yield r.path + else: + dq.append(r) else: raise Exception("unrecognized path") @@ -946,9 +1002,7 @@ def listdir(path: str, shard_prefix_length: int = 0) -> Iterator[str]: raise Exception("unrecognized path") -# the Queues cannot be annotated without breaking pyimport when running pytest -# they should be mp.Queue[Tuple[str, str, bool]] and mp.Queue[str] -def _sharded_listdir_worker(prefixes: mp.Queue, items: mp.Queue) -> None: +def _sharded_listdir_worker(prefixes: mp.Queue[Tuple[str, str, bool]], items: mp.Queue[Optional[str]]) -> None: while True: base, prefix, exact = prefixes.get(True) if exact: diff --git a/blobfile/ops_test.py b/blobfile/ops_test.py index 0feacee..a58ced0 100644 --- a/blobfile/ops_test.py +++ b/blobfile/ops_test.py @@ -404,7 +404,10 @@ def test_walk(ctx): @pytest.mark.parametrize( "ctx", [_get_temp_local_path, _get_temp_gcs_path, _get_temp_as_path] ) -def test_glob(ctx): +@pytest.mark.parametrize( + "parallel", [False, True] +) +def test_glob(ctx, parallel): contents = b"meow!" with ctx() as path: dirpath = bf.dirname(path) @@ -417,7 +420,7 @@ def test_glob(ctx): def assert_listing_equal(path, desired): desired = sorted([bf.join(dirpath, p) for p in desired]) - actual = sorted(list(bf.glob(path))) + actual = sorted(list(bf.glob(path, parallel=parallel))) assert actual == desired, f"{actual} != {desired}" assert_listing_equal(bf.join(dirpath, "*b"), ["ab", "bb"]) diff --git a/env.yaml b/env.yaml index 1f99580..4b70452 100644 --- a/env.yaml +++ b/env.yaml @@ -17,5 +17,4 @@ dependencies: - imageio-ffmpeg==0.3.0 - xmltodict==0.12.0 - azure-cli==2.0.75 - - typeguard==2.6.1 - typing-extensions==3.7.4.1 diff --git a/run-tests.py b/run-tests.py index ae9e314..d67da7d 100644 --- a/run-tests.py +++ b/run-tests.py @@ -3,5 +3,5 @@ sp.run(["pip", "install", "-e", "."], check=True) sp.run( - ["pytest", "blobfile", "--typeguard-packages=blobfile"] + sys.argv[1:], check=True + ["pytest", "blobfile"] + sys.argv[1:], check=True ) diff --git a/setup.py b/setup.py index 5cf3cbd..2ea64d8 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def run(self): setup_dict = dict( name="blobfile", - version="0.11.0", + version="0.12.0", description="Read GCS and local paths with the same interface, clone of tensorflow.io.gfile", long_description=README, long_description_content_type="text/markdown", @@ -45,7 +45,7 @@ def run(self): "typeguard", ] }, - python_requires=">=3.6.0", + python_requires=">=3.7.0", # indicate that we have type information package_data={"blobfile": ["*.pyi", "py.typed"]}, # mypy cannot find type information in zip files