Skip to content

Commit

Permalink
fix: trainer
Browse files — browse the repository at this point in the history
  • Loading branch information
sigridjineth committed Aug 24, 2024
1 parent bb5d87c commit 17b889a
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions src/tevatron/reranker/driver/train.py
Original file line number Diff line number Diff line change
@@ -7,32 +7,23 @@
HfArgumentParser,
set_seed,
)
from transformers import TrainingArguments

from tevatron.reranker.arguments import ModelArguments, DataArguments, \
TevatronTrainingArguments
from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
from tevatron.reranker.modeling import RerankerModel
from tevatron.reranker.dataset import RerankerTrainDataset
from tevatron.reranker.collator import RerankerTrainCollator
from tevatron.reranker.trainer import RerankerTrainer
from tevatron.reranker.gc_trainer import GradCacheTrainer

logger = logging.getLogger(__name__)

def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args, tevatron_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, tevatron_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
tevatron_args: TevatronTrainingArguments

# Combine TrainingArguments and TevatronTrainingArguments
for key, value in vars(tevatron_args).items():
setattr(training_args, key, value)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if (
os.path.exists(training_args.output_dir)
@@ -60,7 +51,6 @@ def main():
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("MODEL parameters %s", model_args)
logger.info("Tevatron parameters %s", tevatron_args)

set_seed(training_args.seed)

@@ -81,7 +71,9 @@
train_dataset = RerankerTrainDataset(data_args)
train_collator = RerankerTrainCollator(data_args, tokenizer)

trainer = RerankerTrainer(
# Choose the appropriate trainer based on the grad_cache flag
trainer_cls = GradCacheTrainer if training_args.grad_cache else RerankerTrainer
trainer = trainer_cls(
model=model,
args=training_args,
train_dataset=train_dataset,
Expand Down

0 comments on commit 17b889a

Please sign in to comment.