import os, random, argparse, sys, pickle, time
import torch
from tqdm import tqdm, trange
import numpy as np
import pandas as pd
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from datasets import Dataset
from torch.utils.data import DataLoader
import wandb
from dataclasses import dataclass, field


def count_parameters(model):
    """Return the number of trainable (requires_grad=True) parameters in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

"""
This code is designed for alignment search for large models, i.e., >1B parameters.
We test it with Alpaca 7B, which is based on the LLaMA 7B model, but it should be
extensible to larger models as well, provided the computational resources allow.
"""

CACHE_DIR = "../.cache/"


class AlpacaAligner(object):
    """Trainer/evaluator for Boundless-DAS alignment search on Alpaca-7B.

    Expects an intervenable model exposing `model.rotate_layer`,
    `model.intervention_boundaries`, and `model.temperature` (reached via
    `model.module.model.*` when wrapped for multi-GPU). Drives three phases:
    pre-alignment evaluation (`prealign_eval`), alignment training with
    periodic dev evaluation (`train`), and an end-of-training test pass.
    Metrics go to Weights & Biases when `args.is_wandb` is set, otherwise to
    CSV-style text files under the output directory.
    """

    def __init__(
        self,
        model,
        is_master,
        logger,
        args,
        lr=5e-5,
        apex_enable=False,
        n_gpu=1,
        gpu_id=0,
        early_stopping=5,
        do_statistic=False,
        model_name="",
        device="cuda"
    ):
        """Store run configuration and (on the master process) start a wandb run.

        Args:
            model: the intervenable Alpaca model to align.
            is_master: True on the main process; gates logging and checkpointing.
            logger: stored on `self.logger`; NOTE(review): the methods below
                actually log through the module-level `logger` — confirm intent.
            args: namespace providing at least `is_wandb`, `task_name`,
                and `wandb_username`.
            lr: learning rate (stored only; the optimizer is built by the caller).
            apex_enable, gpu_id, do_statistic: accepted but unused in this class.
            n_gpu: when >1, checkpoint access goes through `model.module`.
            early_stopping: stored; not consulted anywhere in this class.
            model_name: used as the wandb run name.
            device: device inputs are moved to before every forward pass.
        """
        self.model = model
        num_params = count_parameters(model)
        logger.info(f'Number of Alpaca-7B model params: {num_params}')
        self.is_master = is_master
        self.logger = logger
        self.is_wandb = args.is_wandb
        self.model_name = model_name
        self.lr = lr
        self.n_gpu = n_gpu
        self.device = device
        self.early_stopping = early_stopping

        if args.is_wandb and is_master:
            # Local re-import is redundant (wandb is imported at module level).
            import wandb
            run = wandb.init(
                project=f"Boundless-DAS-{args.task_name}",
                entity=args.wandb_username,
                name=model_name,
            )
            wandb.config.update(args)

    def save_model(self, output_dir, model_name):
        """Save only the alignment-specific tensors, not the base LM weights.

        Persists the rotation layer state dict, the learned intervention
        boundaries, and the current boundary temperature to
        `output_dir/model_name`.
        """
        if self.n_gpu > 1:
            # Multi-GPU wrappers (DataParallel/DDP) hide the model behind `.module`.
            torch.save({
                'rotate_layer': self.model.module.model.rotate_layer.state_dict(),
                'intervention_boundaries': self.model.module.model.intervention_boundaries,
                'temperature': self.model.module.model.temperature
            }, os.path.join(output_dir, model_name))
        else:
            torch.save({
                'rotate_layer': self.model.model.rotate_layer.state_dict(),
                'intervention_boundaries': self.model.model.intervention_boundaries,
                'temperature': self.model.model.temperature
            }, os.path.join(output_dir, model_name))

    def prealign_eval(self, prealign_dataloader, output_dir):
        """Evaluate the model on the task BEFORE any alignment intervention.

        Computes last-token prediction accuracy over `prealign_dataloader`
        (no source/intervention inputs are passed). A high score here is a
        prerequisite for alignment search to be meaningful — hence the loud
        log message. Results go to wandb (step 0) or `prealign_log.txt`.
        """
        total_count = 0
        correct_count = 0
        self.model.eval()
        with torch.no_grad():
            for step, inputs in enumerate(prealign_dataloader):
                for k, v in inputs.items():
                    if v is not None and isinstance(v, torch.Tensor):
                        inputs[k] = v.to(self.device)

                # aligning forward!
                outputs = self.model(
                    input_ids=inputs['input_ids'],
                    labels=inputs['labels']
                )

                # Accuracy is judged only on the final-position token.
                # NOTE(review): assumes the answer token sits at position -1
                # of `labels` — confirm against the dataset builder.
                actual_test_labels = inputs['labels'][:, -1]
                pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)
                correct_labels = (actual_test_labels==pred_test_labels)

                total_count += len(correct_labels)
                correct_count += correct_labels.sum().tolist()

        current_acc = round(correct_count/total_count, 2)
        logger.info(f"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}")

        if self.is_master and not self.is_wandb:
            # buffering=1 -> line-buffered so the result hits disk immediately.
            log_prealign = open(os.path.join(output_dir, 'prealign_log.txt'), 'w', buffering=1)
            print(f'prealign_accuracy,{current_acc}', file=log_prealign)
            log_prealign.close()
        elif self.is_wandb:
            wandb.log(
                {
                    "eval/prealign_accuracy": current_acc
                }, step=0
            )

    def train(
        self,
        train_dataloader,
        dev_dataloader,
        test_dataloader,
        optimizer,
        scheduler,
        output_dir,
        log_step,
        valid_steps,
        epochs,
        gradient_accumulation_steps,
    ):
        """Run alignment training with periodic dev evaluation and a final test pass.

        Each training step does two forwards: one on the source inputs to
        collect `rotated_hidden_states`, and one on the base inputs with those
        states swapped in (the interchange intervention). The boundary
        temperature is annealed linearly from 50.0 to 0.1 over the whole run.
        The best dev checkpoint is saved as 'pytorch-rotate-best.bin' and the
        final state as 'pytorch-rotate-last.bin' (alignment tensors only).

        Args:
            train_dataloader / dev_dataloader / test_dataloader: batches with
                'input_ids', 'source_input_ids', 'intervention_ids', 'labels'.
            optimizer, scheduler: built by the caller; stepped here.
            output_dir: destination for logs and checkpoints.
            log_step: training-metric logging period (in steps, master only).
            valid_steps: dev-evaluation period (in steps).
            epochs: number of passes over `train_dataloader`.
            gradient_accumulation_steps: loss-scaling/step period; see
                NOTE(review) at the update block below.
        """
        if self.is_master and not self.is_wandb:
            # Truncate and write CSV headers once; later writes re-open in append mode.
            log_train = open(os.path.join(output_dir, 'train_log.txt'), 'w', buffering=1)
            log_eval = open(os.path.join(output_dir, 'eval_log.txt'), 'w', buffering=1)
            print('step,loss,accuracy', file=log_train)
            print('step,accuracy', file=log_eval)
            log_train.close()
            log_eval.close()

        # NOTE: it is unclear whether alignment should be searched in train or
        # eval mode; both may be worth trying, but only train mode is used here.
        self.model.train()
        train_iterator = trange(
            0, int(epochs), desc="Epoch"
        )
        total_step = 0
        total_log_step = 0
        best_eval_acc = -1
        target_total_step = len(train_dataloader) * int(epochs)

        # Linear temperature annealing for the soft intervention boundaries.
        temperature_start = 50.0
        temperature_end = 0.1
        temperature_schedule = torch.linspace(temperature_start, temperature_end, target_total_step).to(torch.bfloat16)
        self.model.model.temperature.data = temperature_schedule[total_step]

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True)
            for step, inputs in enumerate(epoch_iterator):
                for k, v in inputs.items():
                    if v is not None and isinstance(v, torch.Tensor):
                        inputs[k] = v.to(self.device)

                # aligning forward!
                # Pass 1: rotated hidden states from the source example only.
                source_hidden_states = self.model(
                    input_ids=inputs['source_input_ids'],
                    output_rotated_hidden_states_only=True
                ).rotated_hidden_states
                # Pass 2: base example with the source states intervened in.
                outputs = self.model(
                    input_ids=inputs['input_ids'],
                    source_hidden_states=source_hidden_states,
                    intervention_ids=inputs['intervention_ids'],
                    labels=inputs['labels']
                )

                # Multi-GPU returns per-device losses; reduce to a scalar.
                loss = outputs.loss.mean() if self.n_gpu > 1 else outputs.loss

                # Step accuracy on the final-position token only.
                actual_test_labels = inputs['labels'][:, -1]
                pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)
                correct_labels = (actual_test_labels==pred_test_labels)
                step_accuracy = correct_labels.sum() / correct_labels.shape[0]
                step_accuracy = step_accuracy.tolist()

                if self.is_master and total_step % log_step == 0:
                    if self.is_wandb:
                        intervention_boundaries = torch.clamp(self.model.model.intervention_boundaries, 1e-3, 1)
                        wandb.log(
                            {
                                "train/loss": loss.item(),
                                "train/step_accuracy": step_accuracy,
                                "train/temperature": self.model.model.temperature.data,
                                "train/unified_boundary": intervention_boundaries.data[0],
                                "train/unified_boundary (dummy)": intervention_boundaries.data[1],
                            }, step=total_step
                        )
                    else:
                        log_train = open(os.path.join(output_dir, 'train_log.txt'), 'a', buffering=1)
                        print('{},{},{}'.format(
                                total_step, loss.item(), step_accuracy
                            ), file=log_train
                        )
                        log_train.close()

                    # Periodic dev evaluation.
                    # NOTE(review): this sits inside the `log_step` gate, so
                    # validation only fires when `valid_steps` is also a
                    # multiple of `log_step` — confirm that is intended.
                    if total_step != 0 and total_step % valid_steps == 0:
                        total_count = 0
                        correct_count = 0
                        self.model.eval()
                        with torch.no_grad():
                            # NOTE(review): `step`/`inputs` shadow the outer
                            # training-loop variables; harmless only because
                            # the outer values are not reused afterwards.
                            for step, inputs in enumerate(dev_dataloader):
                                for k, v in inputs.items():
                                    if v is not None and isinstance(v, torch.Tensor):
                                        inputs[k] = v.to(self.device)

                                # aligning forward!
                                source_hidden_states = self.model(
                                    input_ids=inputs['source_input_ids'],
                                    output_rotated_hidden_states_only=True
                                ).rotated_hidden_states
                                outputs = self.model(
                                    input_ids=inputs['input_ids'],
                                    source_hidden_states=source_hidden_states,
                                    intervention_ids=inputs['intervention_ids'],
                                    labels=inputs['labels']
                                )
                                actual_test_labels = inputs['labels'][:, -1]
                                pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)
                                correct_labels = (actual_test_labels==pred_test_labels)
                                total_count += len(correct_labels)
                                correct_count += correct_labels.sum().tolist()
                        current_acc = round(correct_count/total_count, 2)
                        if self.is_wandb:
                            wandb.log(
                                {
                                    "eval/accuracy": current_acc
                                }, step=total_step
                            )
                        else:
                            log_eval = open(os.path.join(output_dir, 'eval_log.txt'), 'a', buffering=1)
                            print('{},{}'.format(total_step, current_acc), file=log_eval)
                            log_eval.close()
                        # Checkpoint on dev improvement.
                        if current_acc > best_eval_acc:
                            best_eval_acc = current_acc
                            if self.is_master:
                                self.save_model(output_dir, 'pytorch-rotate-best.bin')
                        # Back to train mode after evaluation.
                        self.model.train()
                    total_log_step += 1

                loss_str = round(loss.item(), 2)
                epoch_iterator.set_postfix({'loss': loss_str})

                # Optimizer update with (nominal) gradient accumulation.
                # NOTE(review): `loss.backward()` only runs on steps where
                # total_step % gradient_accumulation_steps == 0, so with
                # gradient_accumulation_steps > 1 the losses of intermediate
                # steps never contribute gradients — this looks unintended
                # (only gradient_accumulation_steps == 1 behaves normally).
                # Confirm against the upstream implementation before changing.
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps
                if total_step % gradient_accumulation_steps == 0:
                    if not (gradient_accumulation_steps > 1 and total_step == 0):
                        loss.backward()
                        optimizer.step()
                        scheduler.step()
                        self.model.zero_grad()
                        # Anneal temperature; only advanced on optimizer steps.
                        self.model.model.temperature.data = temperature_schedule[total_step]
                total_step += 1

        logger.info("Training is finished ...")

        ###############################
        # End of training evaluation on the held-out test set (master only).
        if self.is_master:
            total_count = 0
            correct_count = 0
            self.model.eval()
            with torch.no_grad():
                for step, inputs in enumerate(test_dataloader):
                    for k, v in inputs.items():
                        if v is not None and isinstance(v, torch.Tensor):
                            inputs[k] = v.to(self.device)

                    # aligning forward!
                    source_hidden_states = self.model(
                        input_ids=inputs['source_input_ids'],
                        output_rotated_hidden_states_only=True
                    ).rotated_hidden_states
                    outputs = self.model(
                        input_ids=inputs['input_ids'],
                        source_hidden_states=source_hidden_states,
                        intervention_ids=inputs['intervention_ids'],
                        labels=inputs['labels']
                    )
                    actual_test_labels = inputs['labels'][:, -1]
                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)
                    correct_labels = (actual_test_labels==pred_test_labels)
                    total_count += len(correct_labels)
                    correct_count += correct_labels.sum().tolist()
            current_acc = round(correct_count/total_count, 2)
            if self.is_wandb:
                wandb.log(
                    {
                        "test/accuracy": current_acc
                    }, step=total_step
                )
                wandb.finish()
            else:
                # Test accuracy is appended to the same eval log file.
                log_eval = open(os.path.join(output_dir, 'eval_log.txt'), 'a', buffering=1)
                print('{},{}'.format(total_step, current_acc), file=log_eval)
                log_eval.close()
        ###############################

        if self.is_master:
            self.save_model(output_dir, 'pytorch-rotate-last.bin')