typeguard can't handle from __future__ import annotations
christopher-hesse committed Feb 1, 2020
1 parent fa2cdba commit e80ac32
Showing 5 changed files with 97 additions and 41 deletions.
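
Why this change: PEP 563 (from __future__ import annotations) turns every annotation into a string that is never evaluated at definition time. That makes annotations like mp.Queue[_GlobTask] legal to write, since mp.Queue is a factory function and subscripting it at runtime raises a TypeError. But a runtime checker such as typeguard has to evaluate those strings to check arguments, and presumably trips the same TypeError, so this commit adds the future import, annotates the queues directly, and drops typeguard from the test run. A minimal sketch of the conflict (not code from this repository):

    from __future__ import annotations  # PEP 563: annotations are stored as strings

    import multiprocessing as mp

    # Legal to write under PEP 563 because the annotation is never evaluated at
    # definition time; without the future import, this def raises a TypeError,
    # since mp.Queue is a factory function and cannot be subscripted at runtime.
    def worker(tasks: mp.Queue[str]) -> None:
        ...

    # A runtime checker like typeguard must evaluate the string "mp.Queue[str]"
    # to check arguments, which hits that same TypeError at call time.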
124 changes: 89 additions & 35 deletions blobfile/ops.py
@@ -1,3 +1,6 @@
# https://mypy.readthedocs.io/en/stable/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
from __future__ import annotations

import calendar
import os
import tempfile
@@ -31,6 +34,7 @@
cast,
NamedTuple,
List,
Union,
)
from typing_extensions import Literal, Protocol, runtime_checkable

@@ -713,14 +717,65 @@ def _glob_full(pattern: str) -> Iterator[str]:
yield _strip_slash(path)


def glob(pattern: str) -> Iterator[str]:
class _GlobTask(NamedTuple):
cur: str
rem: Sequence[str]

class _GlobEntry(NamedTuple):
path: str

class _GlobTaskComplete(NamedTuple):
pass

def _process_glob_task(root: str, t: _GlobTask) -> Iterator[Union[_GlobTask, _GlobEntry]]:
cur = t.cur + t.rem[0]
rem = t.rem[1:]
if "**" in cur:
for path in _glob_full(root + cur + "".join(rem)):
yield _GlobEntry(path)
elif "*" in cur:
re_pattern = _compile_pattern(root + cur)
prefix, _, _ = cur.partition("*")
path = root + prefix
for blobpath in _list_blobs(path=path, delimiter="/"):
# in the case of dirname/* we should not return the path dirname/
if blobpath == path and blobpath.endswith("/"):
# we matched the parent directory
continue
if bool(re_pattern.match(blobpath)):
if len(rem) == 0:
yield _GlobEntry(_strip_slash(blobpath))
else:
assert path.startswith(root)
yield _GlobTask(blobpath[len(root):], rem)
else:
if len(rem) == 0:
            path = root + cur
if exists(path):
yield _GlobEntry(_strip_slash(path))
else:
yield _GlobTask(cur, rem)


def _glob_worker(root: str, tasks: mp.Queue[_GlobTask], results: mp.Queue[Union[_GlobEntry, _GlobTask, _GlobTaskComplete]]) -> None:
while True:
t = tasks.get()
for r in _process_glob_task(root=root, t=t):
results.put(r)
results.put(_GlobTaskComplete())


def glob(pattern: str, parallel: bool = False) -> Iterator[str]:
"""
    Find files and directories matching a pattern. Supports * and **.
    For local paths, this function uses glob.glob(), which has special handling for * and **
    that is not quite the same as for remote paths. See https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames#different-behavior-for-dot-files-in-local-file-system_1 for more information.
    Globs can have confusing performance; see the link above for details.
    You can set `parallel=True` to use multiple processes to perform the glob. The results
    will likely no longer be in order if that happens.
"""
assert "?" not in pattern and "[" not in pattern and "]" not in pattern

@@ -744,37 +799,38 @@ def glob(pattern: str) -> Iterator[str]:
assert "*" not in account and "*" not in container
root = f"as://{account}-{container}/"

stack = []
stack.append(("", _split_path(blob_prefix)))
while len(stack) > 0:
cur, rem = stack.pop()
part = rem[0]
if "**" in part:
yield from _glob_full(root + cur + "".join(rem))
elif "*" in part:
re_pattern = _compile_pattern(root + cur + part)
prefix, _, _ = part.partition("*")
path = root + cur + prefix
for blobpath in _list_blobs(path=path, delimiter="/"):
# in the case of dirname/* we should not return the path dirname/
if blobpath == path and blobpath.endswith("/"):
# we matched the parent directory
continue
if bool(re_pattern.match(blobpath)):
if len(rem) == 1:
yield _strip_slash(blobpath)
else:
assert path.startswith(root)
stack.append((blobpath[len(root):], rem[1:]))
else:
cur += part
if len(rem) == 1:
                path = root + cur
if exists(path):
yield _strip_slash(path)
else:
stack.append((cur, rem[1:]))
return
initial_task = _GlobTask("", _split_path(blob_prefix))

if parallel:
tasks = mp.Queue()
tasks.put(initial_task)
tasks_enqueued = 1
results = mp.Queue()

tasks_done = 0
        with mp.Pool(initializer=_glob_worker, initargs=(root, tasks, results)):
while tasks_done < tasks_enqueued:
r = results.get()
if isinstance(r, _GlobEntry):
yield r.path
elif isinstance(r, _GlobTask):
tasks.put(r)
tasks_enqueued += 1
elif isinstance(r, _GlobTaskComplete):
tasks_done += 1
else:
raise Exception("invalid result")
else:
dq: collections.deque[_GlobTask] = collections.deque()
dq.append(initial_task)
while len(dq) > 0:
t = dq.popleft()
for r in _process_glob_task(root=root, t=t):
if isinstance(r, _GlobEntry):
yield r.path
else:
dq.append(r)
else:
raise Exception("unrecognized path")

@@ -946,9 +1002,7 @@ def listdir(path: str, shard_prefix_length: int = 0) -> Iterator[str]:
raise Exception("unrecognized path")


# the Queues cannot be annotated without breaking pyimport when running pytest
# they should be mp.Queue[Tuple[str, str, bool]] and mp.Queue[str]
def _sharded_listdir_worker(prefixes: mp.Queue, items: mp.Queue) -> None:
def _sharded_listdir_worker(prefixes: mp.Queue[Tuple[str, str, bool]], items: mp.Queue[Optional[str]]) -> None:
while True:
base, prefix, exact = prefixes.get(True)
if exact:
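
A hypothetical usage sketch of the new parallel flag (the bucket name is invented), matching the docstring's caveat that parallel results may come back out of order:

    import blobfile as bf

    # Sequential glob (the default): matches stream back in traversal order.
    for path in bf.glob("gs://example-bucket/logs/**/*.json"):
        print(path)

    # Parallel glob: a pool of worker processes pulls _GlobTask items off a
    # shared queue and expands them concurrently, so sort if order matters.
    paths = sorted(bf.glob("gs://example-bucket/logs/**/*.json", parallel=True))

The termination bookkeeping in glob() works because each dequeued task causes exactly one _GlobTaskComplete to be enqueued, so the consumer can stop as soon as tasks_done catches up with tasks_enqueued.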
7 changes: 5 additions & 2 deletions blobfile/ops_test.py
@@ -404,7 +404,10 @@ def test_walk(ctx):
@pytest.mark.parametrize(
"ctx", [_get_temp_local_path, _get_temp_gcs_path, _get_temp_as_path]
)
def test_glob(ctx):
@pytest.mark.parametrize(
"parallel", [False, True]
)
def test_glob(ctx, parallel):
contents = b"meow!"
with ctx() as path:
dirpath = bf.dirname(path)
@@ -417,7 +420,7 @@ def test_glob(ctx):

def assert_listing_equal(path, desired):
desired = sorted([bf.join(dirpath, p) for p in desired])
actual = sorted(list(bf.glob(path)))
actual = sorted(list(bf.glob(path, parallel=parallel)))
assert actual == desired, f"{actual} != {desired}"

assert_listing_equal(bf.join(dirpath, "*b"), ["ab", "bb"])
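
For reference, stacked pytest parametrize decorators take the cross product of their arguments, so test_glob now runs each storage backend both sequentially and in parallel. A self-contained sketch (the test and parameter values are invented):

    import pytest

    # Runs 2 x 2 = 4 times: (a, False), (a, True), (b, False), (b, True).
    @pytest.mark.parametrize("ctx", ["a", "b"])
    @pytest.mark.parametrize("parallel", [False, True])
    def test_demo(ctx, parallel):
        assert ctx in ("a", "b") and parallel in (False, True)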
1 change: 0 additions & 1 deletion env.yaml
@@ -17,5 +17,4 @@ dependencies:
- imageio-ffmpeg==0.3.0
- xmltodict==0.12.0
- azure-cli==2.0.75
- typeguard==2.6.1
- typing-extensions==3.7.4.1
2 changes: 1 addition & 1 deletion run-tests.py
@@ -3,5 +3,5 @@

sp.run(["pip", "install", "-e", "."], check=True)
sp.run(
["pytest", "blobfile", "--typeguard-packages=blobfile"] + sys.argv[1:], check=True
["pytest", "blobfile"] + sys.argv[1:], check=True
)
4 changes: 2 additions & 2 deletions setup.py
@@ -19,7 +19,7 @@ def run(self):

setup_dict = dict(
name="blobfile",
version="0.11.0",
version="0.12.0",
description="Read GCS and local paths with the same interface, clone of tensorflow.io.gfile",
long_description=README,
long_description_content_type="text/markdown",
@@ -45,7 +45,7 @@ def run(self):
"typeguard",
]
},
python_requires=">=3.6.0",
python_requires=">=3.7.0",
# indicate that we have type information
package_data={"blobfile": ["*.pyi", "py.typed"]},
# mypy cannot find type information in zip files
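
The python_requires bump lines up with the future import: PEP 563 first shipped in Python 3.7, so the new ops.py no longer imports on 3.6. A guard one could use to surface that early (not part of this commit):

    import sys

    if sys.version_info < (3, 7):
        raise RuntimeError("blobfile needs Python 3.7+ for PEP 563 annotations")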
