Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
MorenoLaQuatra committed Mar 3, 2023
2 parents 8868d2e + e4715da commit e4e94f4
Showing 1 changed file with 40 additions and 16 deletions.
56 changes: 40 additions & 16 deletions ic_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@ def parse_cmd_line_params():
required=False)
parser.add_argument(
"--model",
help="model to use -- choose one of: facebook/wav2vec2-large-xlsr-53, \
facebook/wav2vec2-xls-r-300m, facebook/wav2vec2-xls-r-1b, \
facebook/wav2vec2-xls-r-2b, jonatasgrosman/wav2vec2-large-xlsr-53-italian",
help="model to use -- choose one of: \
facebook/wav2vec2-large-xlsr-53, \
facebook/wav2vec2-xls-r-300m, \
facebook/wav2vec2-xls-r-1b, \
facebook/wav2vec2-xls-r-2b, \
jonatasgrosman/wav2vec2-large-xlsr-53-italian",
default="facebook/wav2vec2-xls-r-300m",
type=str,
required=True)
parser.add_argument(
"--dataset_name",
help="name of the dataset to use",
default="rita-nlp/italic-easy",
default="RiTA-nlp/italic-easy",
type=str,
required=True)
parser.add_argument(
Expand All @@ -58,13 +61,15 @@ def parse_cmd_line_params():

""" Define model and feature extractor """
def define_model(model_checkpoint, num_labels, label2id, id2label):
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_checkpoint
)
model = AutoModelForAudioClassification.from_pretrained(
model_checkpoint,
num_labels=num_labels,
label2id=label2id,
id2label=id2label
)
)
return feature_extractor, model


Expand Down Expand Up @@ -93,7 +98,7 @@ def define_model(model_checkpoint, num_labels, label2id, id2label):
dataset = load_dataset(
dataset_name,
use_auth_token=True if use_auth_token else None
)
)
ds_train = dataset["train"]
ds_validation = dataset["validation"]

Expand All @@ -106,27 +111,46 @@ def define_model(model_checkpoint, num_labels, label2id, id2label):
num_labels = len(id2label)

## Model & Feature Extractor
model_checkpoint = parse_cmd_line_params().model
model_name = model_checkpoint.split("/")[-1]
feature_extractor, model = define_model(model_checkpoint, num_labels, label2id, id2label)
feature_extractor, model = define_model(
model_checkpoint,
num_labels,
label2id,
id2label
)

## Train & Validation Datasets
train_dataset = Dataset(ds_train, feature_extractor, label2id, max_duration, device)
val_dataset = Dataset(ds_validation, feature_extractor, label2id, max_duration, device)
train_dataset = Dataset(
ds_train,
feature_extractor,
label2id,
max_duration,
device
)

val_dataset = Dataset(
ds_validation,
feature_extractor,
label2id,
max_duration,
device
)

## Training Arguments and Class Weights
training_arguments = define_training_args(output_dir, batch_size, num_epochs, gradient_accumulation_steps=gradient_accumulation_steps)
# class_weights = compute_class_weights(ds_train, label2id)
training_arguments = define_training_args(
output_dir,
batch_size,
num_epochs,
gradient_accumulation_steps=gradient_accumulation_steps
)

## Trainer
trainer = WeightedTrainer(
# class_weights=class_weights,
model=model,
args=training_arguments,
train_dataset=train_dataset,
eval_dataset=val_dataset,
compute_metrics=compute_metrics
)
)

## Train and Evaluate
trainer.train()
Expand Down

0 comments on commit e4e94f4

Please sign in to comment.