generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
Copy pathrloo_trainer.py
678 lines (603 loc) · 32.4 KB
/
rloo_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import math
import os
import textwrap
import time
from collections import defaultdict
from typing import Optional, Union
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
DataCollatorWithPadding,
FeatureExtractionMixin,
GenerationConfig,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
TrainerControl,
is_wandb_available,
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
OnlineTrainerState,
batch_generation,
disable_dropout_in_model,
exact_div,
first_true_indices,
forward,
get_reward,
prepare_deepspeed,
print_rich_table,
truncate_response,
)
from .rloo_config import RLOOConfig
from .utils import generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment
# `wandb` is optional: the module name is only bound when it is installed. `create_model_card` reads
# `wandb.run` behind an `is_wandb_available()` check, so the name is never touched when it is missing.
if is_wandb_available():
    import wandb

# Sentinel log-probability written (via `torch.masked_fill`) into every response position past the end of a
# sequence, so padded tokens carry a fixed value instead of a real log-probability.
INVALID_LOGPROB = 1.0
class RLOOTrainer(Trainer):
    """
    Trainer implementing REINFORCE Leave-One-Out (RLOO) for a causal language-model policy.

    Each update repeats every prompt of a batch `rloo_k` times, samples completions from the policy, scores the
    (stop-token-truncated) completions with `reward_model`, subtracts a KL penalty against `ref_policy`, and uses
    as advantage each completion's reward minus the mean reward of the other `rloo_k - 1` completions of the same
    prompt. The policy is then optimized with a clipped (PPO-style) policy-gradient loss for `num_ppo_epochs`
    epochs over shuffled mini/micro batches.

    Args:
        config (`RLOOConfig`):
            Training configuration. Derived fields (`world_size`, `local_batch_size`, `micro_batch_size`,
            `batch_size`, `mini_batch_size`, `local_mini_batch_size`, `num_total_batches`, `run_name`, and when
            applicable `total_episodes` / `stop_token_id`) are written back onto this object.
        processing_class:
            Tokenizer/processor providing `pad_token_id`, `eos_token_id` and `batch_decode`.
        policy (`nn.Module`):
            Model being trained.
        ref_policy (`nn.Module`):
            Reference model for the KL penalty. Must not be the same object as `policy`.
        reward_model (`nn.Module`):
            Model used through `get_reward` to score query + response sequences.
        train_dataset (`Dataset`):
            Prompt dataset; collated batches must contain `input_ids`.
        data_collator (`DataCollatorWithPadding`, *optional*):
            Defaults to `DataCollatorWithPadding(processing_class)`.
        eval_dataset (`Dataset`, *optional*):
            Prompts used by `generate_completions`. When `None`, completion sampling is skipped.
        optimizers (`tuple`, *optional*):
            `(optimizer, lr_scheduler)`; `None` entries are created by `create_optimizer_and_scheduler`.
        callbacks (`list[TrainerCallback]`, *optional*):
            Extra callbacks appended to the default ones.

    Raises:
        ValueError: If `ref_policy is policy`, or if `config.rloo_k < 2`.
    """

    _tag_names = ["trl", "rloo"]

    def __init__(
        self,
        config: RLOOConfig,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ],
        policy: nn.Module,
        ref_policy: nn.Module,
        reward_model: nn.Module,
        train_dataset: Dataset,
        data_collator: Optional[DataCollatorWithPadding] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        # less commonly used
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: Optional[list[TrainerCallback]] = None,
    ) -> None:
        if ref_policy is policy:
            raise ValueError(
                "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
                "same as `policy`, you must pass a copy of it, or `None` if you use peft."
            )
        # The leave-one-out baseline in `train` divides by `rloo_k - 1`; fail fast instead of training on NaN/inf.
        if config.rloo_k < 2:
            raise ValueError(f"`rloo_k` must be at least 2 for the leave-one-out baseline, got {config.rloo_k}.")

        self.args = config
        args = config
        self.processing_class = processing_class
        self.policy = policy

        # Define the collator if not provided
        if data_collator is None:
            data_collator = DataCollatorWithPadding(self.processing_class)

        # Disable `eos_token_id` and `pad_token_id` in the policy's generation config because we just want to
        # generate tokens without truncation / padding; responses are truncated at `stop_token_id` afterwards.
        self.policy.generation_config.eos_token_id = None
        self.policy.generation_config.pad_token_id = None

        self.ref_policy = ref_policy
        self.reward_model = reward_model
        self.train_dataset = train_dataset
        self.train_dataset_len = len(train_dataset)
        self.data_collator = data_collator
        self.eval_dataset = eval_dataset
        self.optimizer, self.lr_scheduler = optimizers
        self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

        #########
        # calculate various batch sizes
        #########
        if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
            args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
        accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
        self.accelerator = accelerator
        args.world_size = accelerator.num_processes
        # Rows (prompt repetitions included) processed by this process in one update.
        args.local_batch_size = (
            args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
        )
        args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
        args.batch_size = int(args.local_batch_size * args.world_size)
        args.mini_batch_size = exact_div(
            args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
        )
        args.local_mini_batch_size = exact_div(
            args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
        )
        # Rounded up, so we may train for more than `total_episodes`.
        args.num_total_batches = math.ceil(args.total_episodes / args.batch_size)
        time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
        time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
        args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
        self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
        if args.num_sample_generations > 0:
            self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
        # RLOO logic: each prompt is repeated `rloo_k` times, so the dataloader only yields
        # `local_batch_size / rloo_k` distinct prompts per update.
        self.local_dataloader_batch_size = exact_div(
            args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
        )

        #########
        # setup model, optimizer, and others
        #########
        for module in [policy, ref_policy, reward_model]:
            disable_dropout_in_model(module)
        if args.stop_token == "eos":
            args.stop_token_id = self.processing_class.eos_token_id
        self.model = policy
        # Note that `self.lr_scheduler.step()` is called manually, once per update, in `train`.
        self.create_optimizer_and_scheduler(num_training_steps=args.num_total_batches)

        #########
        ### trainer specifics
        #########
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
        self.control = TrainerControl()
        self.state = OnlineTrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ],
        )

        self.current_flos = 0
        self.hp_search_backend = None
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        # Create distant repo and output directory if needed
        self.hub_model_id = None
        if self.args.push_to_hub:
            self.init_hf_repo()
        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)
        self.backup_model = None

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        #########
        ### setup dataloader
        #########
        self.dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.local_dataloader_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            drop_last=True,  # needed; otherwise the last batch will be of ragged shape
        )
        # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
        # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
        torch.manual_seed(args.seed)
        self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
        torch.manual_seed(self.local_seed)  # reset the local seed again

        if self.eval_dataset is not None:
            self.eval_dataloader = DataLoader(
                self.eval_dataset,
                batch_size=args.per_device_eval_batch_size,
                collate_fn=self.data_collator,
                drop_last=True,
            )  # no need to shuffle eval dataset
            self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
        else:
            # `eval_dataset` is optional; a DataLoader over `None` would only fail later, when iterated.
            self.eval_dataloader = None

        if self.is_deepspeed_enabled:
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
            self.ref_policy = prepare_deepspeed(
                self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
            )
            self.deepspeed = self.model
        else:
            self.ref_policy = self.ref_policy.to(self.accelerator.device)
            self.reward_model = self.reward_model.to(self.accelerator.device)

    def get_train_dataloader(self) -> DataLoader:
        """Return the accelerator-prepared, shuffled training dataloader built in `__init__`."""
        return self.dataloader

    def get_eval_dataloader(self) -> Optional[DataLoader]:
        """Return the prepared eval dataloader, or `None` when no `eval_dataset` was given."""
        return self.eval_dataloader

    def train(self):
        """
        Run the RLOO training loop for `args.num_total_batches` updates.

        Each update:
            1. samples `rloo_k` completions per prompt and computes policy/reference log-probs,
            2. truncates at `stop_token_id`, scores with the reward model and applies the missing-EOS penalty,
            3. builds KL-penalized rewards and leave-one-out advantages,
            4. runs `num_ppo_epochs` epochs of clipped policy-gradient updates over shuffled mini/micro batches,
            5. logs metrics, steps the LR scheduler, fires callbacks and optionally checkpoints / samples.
        """
        args = self.args
        accelerator = self.accelerator
        optimizer = self.optimizer
        model = self.model
        self.model_wrapped = self.model
        ref_policy = self.ref_policy
        reward_model = self.reward_model
        processing_class = self.processing_class
        dataloader = self.dataloader
        device = accelerator.device

        def repeat_generator():
            # Cycle over the dataloader indefinitely; the update loop below bounds how many batches are drawn.
            while True:
                yield from dataloader

        iter_dataloader = iter(repeat_generator())
        generation_config = GenerationConfig(
            max_new_tokens=args.response_length,
            temperature=(args.temperature + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        accelerator.print("===training policy===")
        start_time = time.time()
        # One slot per (ppo epoch, mini batch, micro batch) so every micro step feeds the logged averages.
        stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
        approxkl_stats = torch.zeros(stats_shape, device=device)
        pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
        pg_loss_stats = torch.zeros(stats_shape, device=device)
        vf_clipfrac_stats = torch.zeros(stats_shape, device=device)  # never written; kept for metric compatibility
        entropy_stats = torch.zeros(stats_shape, device=device)
        ratio_stats = torch.zeros(stats_shape, device=device)
        model.train()

        # trainer state initialization
        self.state.global_step = 0
        self.state.episode = 0
        # `global_step` is incremented exactly once per update below, so the run ends after
        # `num_total_batches` steps; progress reporting and ratio-based step settings must use the same total.
        self.state.max_steps = args.num_total_batches
        self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        for update in range(1, args.num_total_batches + 1):
            self.state.episode += args.batch_size
            data = next(iter_dataloader)
            with torch.no_grad():
                queries = data["input_ids"].to(device)
                # Tile the prompt batch `rloo_k` times: row `j * B + i` is the j-th sample of prompt i. The
                # `reshape(args.rloo_k, -1)` used for the baseline below relies on this layout.
                queries = queries.repeat(args.rloo_k, 1)
                context_length = queries.shape[1]
                responses = []
                postprocessed_responses = []
                logprobs = []
                ref_logprobs = []
                scores = []
                sequence_lengths = []

                # Generate responses and compute logprobs
                with unwrap_model_for_generation(
                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
                ) as unwrapped_model:
                    query_responses, logitss = batch_generation(
                        unwrapped_model,
                        queries,
                        args.local_rollout_forward_batch_size,
                        processing_class.pad_token_id,
                        generation_config,
                    )

                # Process responses in batches
                for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                    query = queries[i : i + args.local_rollout_forward_batch_size]
                    query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                    response = query_response[:, context_length:]
                    logits = logitss[i : i + args.local_rollout_forward_batch_size]
                    all_logprob = F.log_softmax(logits, dim=-1)
                    logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
                    del logits, all_logprob
                    torch.cuda.empty_cache()

                    # Reference log-probs use the same temperature scaling as generation.
                    ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                    ref_logits = ref_output.logits[:, context_length - 1 : -1]
                    ref_logits /= args.temperature + 1e-7
                    ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
                    ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
                    del ref_output, ref_logits, ref_all_logprob
                    torch.cuda.empty_cache()

                    # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                    postprocessed_response = response
                    if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        postprocessed_response = truncate_response(
                            args.stop_token_id, processing_class.pad_token_id, response
                        )

                    # Response Processing 2. run reward model on the truncated responses
                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                    # Position just before the first pad token of the truncated response.
                    sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                    _, score, _ = get_reward(
                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                    )

                    # Store batch results
                    responses.append(response)
                    postprocessed_responses.append(postprocessed_response)
                    logprobs.append(logprob)
                    ref_logprobs.append(ref_logprob)
                    sequence_lengths.append(sequence_length)
                    scores.append(score)

                # Concatenate all batched results
                responses = torch.cat(responses, 0)
                postprocessed_responses = torch.cat(postprocessed_responses, 0)
                logprobs = torch.cat(logprobs, 0)
                ref_logprobs = torch.cat(ref_logprobs, 0)
                sequence_lengths = torch.cat(sequence_lengths, 0)
                scores = torch.cat(scores, 0)
                del (logprob, ref_logprob, score)
                torch.cuda.empty_cache()
                gc.collect()

                # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
                # responses not passing that filter will receive a low (fixed) score
                # only query humans on responses that pass that filter
                contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
                if args.missing_eos_penalty is not None:
                    scores[~contain_eos_token] -= self.args.missing_eos_penalty

                # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
                # `padding_mask` is True for every response position after `sequence_lengths`.
                response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
                padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
                logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
                ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)

                # 4. compute rewards
                # Compute KL divergence (padding positions cancel to 0 since both sides hold the sentinel)
                kl = logprobs - ref_logprobs

                # Normalize rewards
                if args.normalize_reward:
                    scores = (scores - scores.mean()) / (scores.std() + 1e-8)
                    scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)

                # Compute total reward with KL penalty
                if args.token_level_kl:
                    # Token-level KL penalty: apply KL penalty per token
                    kl_reward = -args.kl_coef * kl

                    # Index of the last non-padded token of each sequence, i.e. the last `False` in
                    # `padding_mask`. (Searching for the last `True` would always pick the final column.)
                    last_valid_from_end = (~padding_mask).long().fliplr().argmax(dim=1, keepdim=True)
                    eos_indices = padding_mask.size(1) - 1 - last_valid_from_end
                    last_reward = torch.zeros_like(kl)
                    # Ensure scores has correct shape and type
                    scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
                    last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)

                    # Combine KL reward and last reward
                    non_score_reward = kl_reward.sum(1)  # Keep this for logging
                    reward = last_reward + kl_reward
                    rlhf_reward = reward.sum(1)  # Sum across sequence length
                else:
                    # Sequence-level KL penalty: sum KL across tokens first
                    sequence_kl = kl.sum(1)
                    non_score_reward = -args.kl_coef * sequence_kl
                    rlhf_reward = non_score_reward + scores

                # vectorized RLOO advantages implementation: row j of the (rloo_k, B) view holds the j-th sample
                # of every prompt; the baseline is the mean reward of the other `rloo_k - 1` samples.
                rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
                baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
                advantages = rlhf_reward - baseline
                advantages = advantages.flatten()

                # Normalize advantages
                if args.normalize_advantage:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                torch.cuda.empty_cache()

            # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
            for ppo_epoch_idx in range(args.num_ppo_epochs):
                b_inds = np.random.permutation(args.local_batch_size)
                minibatch_idx = 0
                for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                    mini_batch_end = mini_batch_start + args.local_mini_batch_size
                    mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                    gradient_accumulation_idx = 0
                    for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                        with accelerator.accumulate(model):
                            micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                            micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]

                            # Get batch data
                            mb_advantage = advantages[micro_batch_inds]
                            mb_responses = responses[micro_batch_inds]
                            mb_query_responses = query_responses[micro_batch_inds]
                            mb_logprobs = logprobs[micro_batch_inds]

                            # Forward pass
                            output = forward(model, mb_query_responses, processing_class.pad_token_id)
                            logits = output.logits[:, context_length - 1 : -1]
                            logits /= args.temperature + 1e-7

                            # Compute new logprobs
                            new_all_logprobs = F.log_softmax(logits, dim=-1)
                            new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
                            new_logprobs = torch.masked_fill(
                                new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                            )

                            # Compute probability ratios (token-level for logging, sequence-level for the loss)
                            new_ratio = (new_logprobs - mb_logprobs).exp()
                            new_logprobs = new_logprobs.sum(1)
                            mb_logprobs = mb_logprobs.sum(1)
                            logprobs_diff = new_logprobs - mb_logprobs
                            ratio = torch.exp(logprobs_diff)

                            # PPO clipped loss
                            pg_losses = -mb_advantage * ratio
                            pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                            pg_loss_max = torch.max(pg_losses, pg_losses2)
                            pg_loss = pg_loss_max.mean()

                            # Final loss
                            loss = pg_loss

                            # Optimization step
                            accelerator.backward(loss)
                            optimizer.step()
                            optimizer.zero_grad()

                            with torch.no_grad():
                                pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
                                prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                                entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                                approxkl = 0.5 * (logprobs_diff**2).mean()
                                approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                    pg_clipfrac
                                )
                                pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                                ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
                        gradient_accumulation_idx += 1
                    minibatch_idx += 1

                    # del everything and empty cache
                    # fmt: off
                    del (
                        output, logits, new_all_logprobs, new_logprobs,
                        logprobs_diff, ratio, pg_losses, pg_losses2,
                        pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
                        mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
                    )
                    # fmt: on
                    torch.cuda.empty_cache()

            # Compute metrics
            with torch.no_grad():
                mean_kl = kl.sum(1).mean()
                # Padding positions of `logprobs` hold the `INVALID_LOGPROB` sentinel; exclude them so they do not
                # each contribute `-INVALID_LOGPROB` to the sequence entropy estimate.
                mean_entropy = (-logprobs).masked_fill(padding_mask, 0.0).sum(1).mean()
                mean_non_score_reward = non_score_reward.mean()
                eps = int(self.state.episode / (time.time() - start_time))
                metrics = {}
                metrics["eps"] = eps
                metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
                metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
                metrics["objective/non_score_reward"] = (
                    self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
                )
                metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
                metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
                metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
                metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
                metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
                metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
                metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
                metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
                metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
                metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
                metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
                metrics["episode"] = self.state.episode
                self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len)  # used by self.log
                self.log(metrics)
            del kl, mean_kl, mean_entropy, scores

            self.lr_scheduler.step()
            self.state.global_step += 1
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)
            if self.control.should_save:
                self._save_checkpoint(model, trial=None)
                self.control = self.callback_handler.on_save(self.args, self.state, self.control)
            torch.cuda.empty_cache()
            gc.collect()

            if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
                self.generate_completions(sampling=True)

        # HF trainer specifics
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None, metrics=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def generate_completions(self, sampling: bool = False):
        """
        Generate completions for the eval prompts, score them, and print/log a table of the results.

        Does nothing when the trainer was built without an `eval_dataset`.

        Args:
            sampling (`bool`, *optional*, defaults to `False`):
                If `True`, only the first eval batch is processed (used for periodic samples during training).
        """
        if self.eval_dataloader is None:
            # No eval dataset was provided, so there are no prompts to sample completions for.
            return
        args = self.args
        processing_class = self.processing_class
        # Very low temperature: close to greedy decoding for qualitative inspection.
        generation_config = GenerationConfig(
            max_new_tokens=self.args.response_length,
            temperature=(0.01 + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        table = defaultdict(list)
        with unwrap_model_for_generation(
            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
        ) as unwrapped_model:
            for batch in self.eval_dataloader:
                query = batch["input_ids"]
                with torch.no_grad():
                    context_length = query.shape[1]
                    query_response, _ = batch_generation(
                        unwrapped_model,
                        query,
                        query.shape[0],
                        processing_class.pad_token_id,
                        generation_config,
                    )
                    response = query_response[:, context_length:]
                    postprocessed_response = response
                    if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        postprocessed_response = truncate_response(
                            args.stop_token_id, processing_class.pad_token_id, response
                        )
                    table["query"].extend(
                        gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
                    )
                    table["model response"].extend(
                        gather_object(processing_class.batch_decode(postprocessed_response))
                    )

                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                    _, score, _ = get_reward(
                        self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                    )
                    table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())

                if sampling:
                    break
        df = pd.DataFrame(table)

        if self.accelerator.is_main_process:
            print_rich_table(df.head(5))
            if "wandb" in args.report_to:
                import wandb

                if wandb.run is not None:
                    wandb.log({"completions": wandb.Table(dataframe=df)})

            if "comet_ml" in args.report_to:
                log_table_to_comet_experiment(
                    name="completions.csv",
                    table=df,
                )

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Only the world-process-zero writes the card, to `<output_dir>/README.md`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card. A passed-in list is not modified.
        """
        if not self.is_world_process_zero():
            return

        # Only reference a base model by hub id, not by a local directory path.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        # Normalize to a fresh list so appending below never mutates the caller's list.
        tags = [tags] if isinstance(tags, str) else list(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @inproceedings{ahmadian2024back,
        title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
        author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
        year = 2024,
        booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
        publisher = {Association for Computational Linguistics},
        pages = {12248--12267},
        editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="RLOO",
            trainer_citation=citation,
            paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
            paper_id="2402.14740",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))