diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index 0331023ae54cdb..f8d326c6170615 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -216,12 +216,16 @@ def __post_init__(self): elif self.source_lang is None or self.target_lang is None: raise ValueError("Need to specify the source language and the target language.") + # accepting both json and jsonl file extensions, as + # many jsonlines files actually have a .json extension + valid_extensions = ["json", "jsonl"] + if self.train_file is not None: extension = self.train_file.split(".")[-1] - assert extension == "json", "`train_file` should be a json file." + assert extension in valid_extensions, "`train_file` should be a jsonlines file." if self.validation_file is not None: extension = self.validation_file.split(".")[-1] - assert extension == "json", "`validation_file` should be a json file." + assert extension in valid_extensions, "`validation_file` should be a jsonlines file." if self.val_max_target_length is None: self.val_max_target_length = self.max_target_length