-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathasyncio.py
93 lines (78 loc) · 2.69 KB
/
asyncio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Asynchronous progressbar decorator for iterators.
Includes a default `range` iterator printing to `stderr`.
Usage:
>>> from tqdm.asyncio import trange, tqdm
>>> async for i in trange(10):
... ...
"""
import asyncio
from sys import version_info
from .std import tqdm as std_tqdm
# Module authorship metadata.
__author__ = {
    "github.com/": ["casperdcl"],
}

# Names exported by `from <this module> import *`.
__all__ = [
    "tqdm_asyncio",
    "tarange",
    "tqdm",
    "trange",
]
class tqdm_asyncio(std_tqdm):
    """
    Asynchronous-friendly version of tqdm.

    Instances can be consumed with ``async for``; each step advances the bar
    by one. Sync iterators/iterables and async iterators/iterables are all
    accepted as the wrapped object.
    """
    def __init__(self, iterable=None, *args, **kwargs):
        """
        Parameters
        ----------
        iterable : iterable, iterator, async iterable or async iterator, optional
            Object to wrap. Dispatch order:
            - has ``__anext__``: its result is awaited on each step;
            - has ``__next__``: it is called directly on each step;
            - has ``__aiter__`` but no ``__iter__``: ``__aiter__()`` is called
              once and the resulting async iterator is awaited on each step;
            - otherwise: ``iter(iterable)`` is used.
        *args, **kwargs
            Forwarded unchanged to the base ``tqdm`` constructor.
        """
        super().__init__(iterable, *args, **kwargs)
        # True when `iterable_next()` returns an awaitable that must be awaited.
        self.iterable_awaitable = False
        if iterable is not None:
            if hasattr(iterable, "__anext__"):
                self.iterable_next = iterable.__anext__
                self.iterable_awaitable = True
            elif hasattr(iterable, "__next__"):
                self.iterable_next = iterable.__next__
            elif hasattr(iterable, "__aiter__") and not hasattr(iterable, "__iter__"):
                # Async iterable that is not itself an async iterator (e.g. a
                # class whose `__aiter__` returns an async generator). `iter()`
                # would raise TypeError on it, so obtain its async iterator.
                # Objects that are *also* sync-iterable keep the sync path below
                # for backward compatibility.
                self.iterable_iterator = iterable.__aiter__()
                self.iterable_next = self.iterable_iterator.__anext__
                self.iterable_awaitable = True
            else:
                self.iterable_iterator = iter(iterable)
                self.iterable_next = self.iterable_iterator.__next__

    def __aiter__(self):
        """Return ``self``: the bar is its own async iterator."""
        return self

    async def __anext__(self):
        """
        Fetch the next item from the wrapped object and advance the bar by one.

        Raises
        ------
        StopAsyncIteration
            When the wrapped object is exhausted (a sync ``StopIteration`` is
            translated); the bar is closed first.
        BaseException
            Any other error (including ``StopAsyncIteration`` from an async
            source) closes the bar and is re-raised unchanged.
        """
        try:
            if self.iterable_awaitable:
                res = await self.iterable_next()
            else:
                res = self.iterable_next()
            self.update()
            return res
        except StopIteration:
            # `async for` only understands StopAsyncIteration.
            self.close()
            raise StopAsyncIteration
        except BaseException:
            self.close()
            raise

    def send(self, *args, **kwargs):
        """Forward ``send(...)`` to the wrapped ``self.iterable`` and return its result."""
        return self.iterable.send(*args, **kwargs)

    @classmethod
    def as_completed(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
        """
        Wrapper for `asyncio.as_completed`.

        Yields the awaitables produced by ``asyncio.as_completed(fs, ...)``,
        advancing a progress bar once per yielded item.

        Parameters
        ----------
        fs : iterable of awaitables
            Sized collections are used as-is. Unsized iterables (e.g.
            generators) are materialised into a list when ``total`` is not
            given, so that their length can be used as the bar total.
        loop : optional
            Forwarded only on Python < 3.10; later versions of
            ``asyncio.as_completed`` no longer accept it.
        timeout : float, optional
            Forwarded to ``asyncio.as_completed``.
        total : int, optional
            Bar total; defaults to the number of awaitables in ``fs``.
        **tqdm_kwargs
            Forwarded to the progress-bar constructor.
        """
        if total is None:
            try:
                total = len(fs)
            except TypeError:
                # Unsized iterable: materialise it once. `asyncio.as_completed`
                # accepts any iterable of awaitables, so the list is equivalent.
                fs = list(fs)
                total = len(fs)
        kwargs = {}
        if version_info[:2] < (3, 10):
            kwargs['loop'] = loop
        yield from cls(asyncio.as_completed(fs, timeout=timeout, **kwargs),
                       total=total, **tqdm_kwargs)

    @classmethod
    async def gather(cls, *fs, loop=None, timeout=None, total=None, **tqdm_kwargs):
        """
        Wrapper for `asyncio.gather`.

        Awaits every awaitable in ``fs`` (through ``as_completed``, advancing
        the bar as each finishes) and returns their results as a list in the
        same order as ``fs``.
        """
        async def wrap_awaitable(i, f):
            # Tag each result with its position so the input order can be restored.
            return i, await f

        ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
        res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
                                                 total=total, **tqdm_kwargs)]
        # Indices are unique, so sorting (index, result) pairs never needs to
        # compare the results themselves.
        return [result for _, result in sorted(res)]
def tarange(*args, **kwargs):
    """
    Asynchronous-friendly progress bar over ``range(*args)``.

    Equivalent to ``tqdm.asyncio.tqdm(range(*args), **kwargs)``; positional
    arguments build the range, keyword arguments configure the bar.
    """
    counter = range(*args)
    return tqdm_asyncio(counter, **kwargs)
# Short public aliases: `tqdm` for the asyncio-friendly bar class and
# `trange` for its range shortcut.
tqdm, trange = tqdm_asyncio, tarange