# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Script used to generate the mini datasets."""
import multiprocessing as mp

# Force the "spawn" start method so dataloader worker processes behave
# consistently across platforms (and safely with CUDA).  This must run
# before any process pools are created.
try:
    mp.set_start_method("spawn")
except RuntimeError:
    # The start method was already set (e.g. by an embedding process or a
    # previous import) -- keep whatever is in place rather than crashing.
    pass
from tempfile import TemporaryDirectory
from datasets import Dataset, DatasetDict, load_dataset
from torchrl.data.rlhf.dataset import get_dataloader
from torchrl.data.rlhf.prompt import PromptData
def generate_small_dataset(comparison=True, num_samples=50):
    """Build a miniature copy of an OpenAI summarization dataset for tests.

    Downloads either the ``openai_summarize_comparisons`` or the
    ``openai_summarize_tldr`` dataset from the Hugging Face hub, truncates
    every ``train``/``valid``/``test`` split to its first ``num_samples``
    rows, and saves the result under ``test/datasets_mini``.

    Args:
        comparison (bool): if ``True`` (default), use the comparisons
            dataset; otherwise use the TL;DR dataset.
        num_samples (int): number of rows kept per split. Defaults to 50,
            matching the historical behavior of this script.
    """
    if comparison:
        dataset = load_dataset("CarperAI/openai_summarize_comparisons")
    else:
        dataset = load_dataset("CarperAI/openai_summarize_tldr")
    # Keep only the first `num_samples` rows of each relevant split; other
    # splits (if any) are dropped from the mini dataset.
    small_splits = {
        key: Dataset.from_dict(dataset[key][:num_samples])
        for key in dataset.keys()
        if key.startswith(("train", "valid", "test"))
    }
    small_dataset = DatasetDict(small_splits)
    if comparison:
        small_dataset.save_to_disk("test/datasets_mini/openai_summarize_comparisons")
    else:
        small_dataset.save_to_disk("test/datasets_mini/openai_summarize_tldr")
def get_minibatch():
    """Materialize a single prompt mini-batch as a memory-mapped tensorclass.

    Loads the mini TL;DR dataset from disk through ``get_dataloader`` (using
    a temporary directory as the dataloader's working root), takes the first
    batch only, and writes a memory-mapped clone of it to
    ``test/datasets_mini/tldr_batch/``.
    """
    with TemporaryDirectory() as tmpdir:
        loader = get_dataloader(
            batch_size=16,
            block_size=33,
            tensorclass_type=PromptData,
            dataset_name="../datasets_mini/openai_summarize_tldr",
            device="cpu",
            num_workers=2,
            infinite=False,
            prefetch=0,
            split="train",
            from_disk=True,
            root_dir=tmpdir,
        )
        # Only the first batch is needed; stop iterating immediately after
        # persisting it to disk.
        for batch in loader:
            batch.clone().memmap_("test/datasets_mini/tldr_batch/")
            break
# Script entry point: regenerate the memory-mapped TL;DR mini-batch fixture.
if __name__ == "__main__":
    get_minibatch()