# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""C4 (Colossal Cleaned Common Crawl) dataset.
This dataset is a colossal, cleaned version of Common Crawl's web crawl corpus and it is based on
the `Common Crawl <https://commoncrawl.org>`_ dataset.
"""
from typing import Any, Optional
from transformers.models.auto.tokenization_auto import AutoTokenizer
from streaming.base import StreamingDataset
__all__ = ['StreamingC4']
class StreamingC4(StreamingDataset):
"""Implementation of the C4 (Colossal Cleaned Common Crawl) dataset using StreamingDataset.
Args:
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
its data must exist locally. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
local (str, optional): Local working directory to download shards to. This is where shards
are cached while they are being used. Uses a temp directory if not set.
StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``.
split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
        keep_zip (bool): Whether to keep or delete the compressed form when decompressing
            downloaded shards. If ``False``, keep the compressed form only if the remote is local
            or there is no remote. Defaults to ``False``.
        epoch_size (int, optional): Number of samples to draw per epoch balanced across all
            streams. If ``None``, takes its value from the total number of underlying samples.
            Provide this field if you are weighting streams relative to one another to target a
            larger or smaller epoch size. Defaults to ``None``.
        predownload (int, optional): Target number of samples to download per worker in advance
            of the current sample. Workers will attempt to download ahead by this many samples
            during, but not before, training. It is recommended to provide a value greater than
            the per-device batch size so that at least one batch's worth of samples is cached
            locally. If ``None``, its value is derived from the per-device batch size and the
            number of canonical nodes as
            ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. Defaults to ``None``.
cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache.
Before downloading a shard, the least recently used resident shard(s) may be evicted
(deleted from the local cache) in order to stay under the limit. Set to ``None`` to
disable shard eviction. Defaults to ``None``.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
resumption. The sample space is divided evenly according to the number of canonical
nodes. The higher the value, the more independent non-overlapping paths the
StreamingDataset replicas take through the shards per model replica (increasing data
source diversity). Defaults to ``None``, which is interpreted as 64 times the number
            of nodes of the initial run.

            .. note::

                For sequential sample ordering, set ``shuffle`` to ``False`` and
                ``num_canonical_nodes`` to the number of physical nodes of the initial run.
batch_size (int, optional): Per-device batch size, the same as what is passed to the
DataLoader. This affects how the dataset is partitioned over the workers and is
necessary for deterministic resumption and optimal performance. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
``False``.
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
tokenizer_name (str): The name of the HuggingFace tokenizer to use to tokenize samples.
max_seq_len (int): The max sequence length of each token sample.
group_method (str): How to group text samples into token samples. Currently only supporting
``'truncate'``.
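
    Example (an illustrative sketch, not from the original docs; the remote/local paths and
    tokenizer name below are placeholders):

    .. code-block:: python

        dataset = StreamingC4(remote='s3://my-bucket/c4/',  # placeholder remote path
                              local='/tmp/c4',              # placeholder local cache dir
                              split='val',
                              tokenizer_name='gpt2',
                              max_seq_len=1024,
                              group_method='truncate')
        sample = dataset.get_item(0)  # tokenized sample with 'input_ids', 'attention_mask', ...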
"""

    def __init__(self,
*,
remote: Optional[str] = None,
local: Optional[str] = None,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
cache_limit: Optional[int] = None,
partition_algo: str = 'orig',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
shuffle_algo: str = 'py1s',
shuffle_seed: int = 9176,
shuffle_block_size: int = 1 << 18,
tokenizer_name: str,
max_seq_len: int,
group_method: str) -> None:
if group_method not in {'truncate'}:
raise ValueError(f"group_method='{group_method}' must be one of {'truncate'}.")
super().__init__(remote=remote,
local=local,
split=split,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip,
epoch_size=epoch_size,
predownload=predownload,
cache_limit=cache_limit,
partition_algo=partition_algo,
num_canonical_nodes=num_canonical_nodes,
batch_size=batch_size,
shuffle=shuffle,
shuffle_algo=shuffle_algo,
shuffle_seed=shuffle_seed,
shuffle_block_size=shuffle_block_size)
self.tokenizer_name = tokenizer_name
self.max_seq_len = max_seq_len
self.group_method = group_method
# Build tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
if self.tokenizer.pad_token is None:
# Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs
self.tokenizer.pad_token = self.tokenizer.eos_token

    def _tokenize(self, text_sample: dict[str, Any]):
        """Apply the tokenizer to a sample.

        Args:
            text_sample (dict[str, Any]): Sample to tokenize.

        Returns:
            The tokenized sample (token IDs, attention mask, etc.) produced by the tokenizer.
        """
if self.group_method == 'truncate':
truncation = True
padding = 'max_length'
max_length = self.max_seq_len
else:
raise ValueError(f'Got unknown group_method={self.group_method}.')
return self.tokenizer(text_sample['text'],
truncation=truncation,
padding=padding,
max_length=max_length)

    def get_item(self, idx: int) -> Any:
        """Get sample by global index, blocking to load its shard if missing.

        Args:
            idx (int): Sample index.

        Returns:
            Any: Sample data.
        """
text_sample = super().get_item(idx)
token_sample = self._tokenize(text_sample)
# Skip any token grouping, currently only supporting group_method='truncate'
return token_sample
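

# ---------------------------------------------------------------------------------
# Illustrative usage sketch (an addition for demonstration, not part of the original
# module). The remote/local paths below are placeholders for wherever your
# MDS-converted C4 shards live; swap in real locations before running. It assumes the
# standard pattern of feeding a StreamingDataset to a PyTorch DataLoader.
# ---------------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = StreamingC4(remote='s3://my-bucket/c4/',  # placeholder remote path
                          local='/tmp/c4',              # placeholder local cache dir
                          split='train',
                          shuffle=True,
                          batch_size=8,
                          tokenizer_name='gpt2',
                          max_seq_len=1024,
                          group_method='truncate')
    loader = DataLoader(dataset, batch_size=8)
    for batch in loader:
        # Each batch is a mapping with tokenized fields such as 'input_ids' and
        # 'attention_mask'; stop after one batch just to sanity-check the pipeline.
        print(sorted(batch.keys()))
        break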