Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] RLHF dataloading #1309

Merged
merged 25 commits into from
Jun 27, 2023
Merged

[Feature] RLHF dataloading #1309

merged 25 commits into from
Jun 27, 2023

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Jun 23, 2023

cc @tcbegley @apbard

There are 3 TODO i need help with

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 23, 2023
@vmoens vmoens added the enhancement New feature or request label Jun 23, 2023
Args:
split (str): One of ``"train"`` or ``"valid"``.
dataset_name (str): the name or path of the dataset.
pre_tokenization_hook (callable): TODO
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pre_tokenization_hook is called on the Dataset before tokenization and should return a modified Dataset object. The intended use is for carrying out tasks that require modifying the dataset as a whole as opposed to modifying individual datapoints, for example discarding certain datapoints based on a particular condition. Tokenization and other "elementwise" operations on the data are performed by the process function which is mapped over the dataset.

max_length (int): the maximum sequence length.
dataset_name (str): the name of the dataset.
make_process_fn (callable): a preprocess function.
pre_tokenization_hook (callable): TODO
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment below, this is passed on to the load_dataset function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to commit all of these arrow files etc. to the repo?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a dummy dataset. I agree it isn't perfect but I'm not sure what choice we have. could we zip it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Vision has something of similar taste
https://github.com/pytorch/vision/tree/main/test/assets
I can put it all in a .pt

Copy link
Contributor

@apbard apbard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in general I noticed that everything is assuming we use gpt2 models but we are not saying explicitly. E.g. tokenize function

LGTM

test/assets/generate.py Outdated Show resolved Hide resolved
)[prefix]


def load_dataset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is only used internally? should we put a "_" in front?

Suggested change
def load_dataset(
def _load_dataset(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it? IDK if we want to expose it or not. To me it's a well self-contained function properly documented that could be used somewhere else, but I'm happy to make it private if there's no use.

Copy link
Contributor

@apbard apbard Jun 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the common case user will want to (possibly) create and load. Hence my proposal to keep this internal and expose only the maybe-renamed other function

_has_datasets = importlib.util.find_spec("datasets") is not None


def create_or_load_dataset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about calling this just "load_dataset" and maybe add an argument such as "exists" in mkdir?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure i see what you mean but the reason I named it that way is that it can either create a dataset for the first time or load it if the memmap tensordict copy is already there. load_dataset already exists and does just that, nothing more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion/discussion is whether we could rename this to load_dataset without changing it's current behaviour and add an argument to specify the behaviour we want in case it does not exists.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im fine with renaming it, though I think that users should be informed that this function caches the results and the name says that quite explicitly.
I'm not sure I see what the extra arg should do. If it does not exist we must load it, and if it does we should use the cache. What extra control should we give?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a flag like exists_ok of mkdir that raises if it does not exists or something similar to from_disk but control whether we want to store cache or not.
I don't have strong opinion on this I am just reasoning if it's really necessary such a long name (even if definitely more explicit) or if it will be enough to say that we cache in the doc.

Copy link
Contributor

@tcbegley tcbegley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, a couple of comments on the generated data / zipping + unzipping it.

lmemmap_save = deepcopy(TensorDict.load_memmap)
mocked_hello = mocker.patch("tensordict.TensorDict.load_memmap")
mocked_hello.side_effect = lmemmap_save
with zipfile.ZipFile(dataset_path, "r") as zip_ref:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of doing this in every test that loads the data, perhaps we should write a conftest.py file that extracts on setup, then cleans up on completion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point yeah
But then how do we tell conftest to do it only when the test_rlhf.py script is executed?

Copy link
Contributor

@tcbegley tcbegley Jun 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then how do we tell conftest to do it only when the test_rlhf.py script is executed?

How about instead of conftest.py we add a fixture that has session scope and yields a directory which is the location of the extracted data? I think the fixture should only be executed if it is used, and it will only be executed once.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think fixtures are executed in each test, otherwise the same temporary directories would be reused multiple times

Copy link
Contributor

@tcbegley tcbegley Jun 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's the default, but you can change that behaviour

@pytest.fixture(scope="session")
def minidata_dir():
    # unzip
    # yield path

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh now i remember why i did not do that
we have 2 mini datasets but you can't pass a fixture to parametrize, so I couldn't write 2 different fixtures for the tests
The "ugly" solution would be

@pytest.mark.parametrize("dataset", ["comp", "prompt"])
def test_stuff(dataset, comp_fixture, prompt_fixture):
    if dataset == "comp":
        dataset = comp_fixture
    else:
        dataset = prompt_fixture

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add logic here to zip the generated data so that the committed .zip files are reproducible?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not use python for that but we could do it

@vmoens
Copy link
Contributor Author

vmoens commented Jun 27, 2023

in general I noticed that everything is assuming we use gpt2 models but we are not saying explicitly. E.g. tokenize function

Would the tokenization change with another model?
Any pointer to a place where that would break?

vmoens and others added 5 commits June 27, 2023 12:04
Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com>
Co-authored-by: Tom Begley <tomcbegley@gmail.com>
Co-authored-by: Tom Begley <tomcbegley@gmail.com>
@apbard
Copy link
Contributor

apbard commented Jun 27, 2023

in general I noticed that everything is assuming we use gpt2 models but we are not saying explicitly. E.g. tokenize function

Would the tokenization change with another model? Any pointer to a place where that would break?

If the tokenizer is huggingface-like I think the tokenize function should work provided that we somehow expose: https://github.com/pytorch/rl/pull/1309/files#r1243540647

num_workers = max(os.cpu_count() // 2, 1)
if excluded_features is None:
excluded_features = {"text", "prompt", "label", "valid_sample"}
tokenizer = AutoTokenizer.from_pretrained("gpt2")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it
Let's parametrize it then

return dataset


def tokenize(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tokenizers already have a method tokenize. maybe it's better to be more explicit about what this does

Suggested change
def tokenize(
def tokenize_dataset(

Copy link
Contributor

@apbard apbard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@tcbegley tcbegley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

vmoens and others added 3 commits June 27, 2023 15:15
@vmoens vmoens merged commit 2c47fa3 into main Jun 27, 2023
@vmoens vmoens deleted the rlhf_data branch June 27, 2023 17:16
@vmoens vmoens restored the rlhf_data branch June 28, 2023 09:04
@vmoens vmoens deleted the rlhf_data branch July 11, 2023 15:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants