-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
497 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import os | ||
|
||
import hydra | ||
|
||
from src.classifier.torch_helpers.load_pretrained import load_torch_model | ||
from src.classifier.torch_helpers.torch_dataloader import get_dataloader | ||
from src.classifier.torch_helpers.eval_torch import evaluate | ||
from src.classifier.non_torch.save_and_load_model import load_model | ||
from src.classifier.non_torch.eval_non_torch import evaluate_model | ||
|
||
|
||
def evaluate_on_test_set(cfg, X_test, Y_test, texts_test):
    """Evaluate a previously stored classifier on the test set.

    The model path is taken from ``cfg.classifier_mode.model_path.majority``
    when ``cfg.dev_settings.annotation == "majority"`` and from
    ``cfg.classifier_mode.model_path.unanimous`` otherwise. Evaluation output
    goes to ``<results_path>/<classifier name>``, which is created if needed.

    Args:
        cfg: Config object; reads ``dev_settings.annotation``,
            ``classifier_mode.model_path``, ``classifier_mode.results_path``,
            ``classifier_mode.batch_size``, ``classifier.name`` and
            ``embedding.name``.
        X_test: Test features, forwarded to the evaluation helpers.
        Y_test: Test labels; for torch models the set of distinct labels is
            passed on as ``classes``.
        texts_test: Test texts, forwarded to the evaluation helpers.

    Returns:
        None. Results are produced by the evaluation helpers.
    """
    if cfg.dev_settings.annotation == "majority":
        model_path = cfg.classifier_mode.model_path.majority
    else:
        model_path = cfg.classifier_mode.model_path.unanimous
    model_path = hydra.utils.to_absolute_path(model_path)
    results_path = hydra.utils.to_absolute_path(cfg.classifier_mode.results_path)

    dest = os.path.join(results_path, cfg.classifier.name)
    os.makedirs(dest, exist_ok=True)

    # Tree-based models are not torch modules: "rf" must take the same
    # non-torch path as "xgb" instead of falling through to the torch loader.
    if cfg.classifier.name in ("xgb", "rf"):
        model = load_model(model_path)
        evaluate_model(
            cfg.embedding.name,
            cfg.classifier.name,
            model,
            X_test,
            Y_test,
            texts_test,
            dest,
        )
        return

    model = load_torch_model(model_path, cfg.classifier.name, logger=None)
    # Evaluation is pinned to the CPU, with the model in eval mode.
    model.to("cpu")
    model.eval()

    test_loader = get_dataloader(
        X_test, Y_test, cfg.classifier_mode.batch_size, shuffle=False
    )

    name_str = f"{cfg.embedding.name}_{cfg.classifier.name}"
    # The helper's three return values are not needed here.
    evaluate(
        model,
        test_loader,
        texts_test,
        classes=set(Y_test),
        name_str=name_str,
        output_path=dest,
        plot=True,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import torch | ||
import numpy as np | ||
import xgboost as xgb | ||
from sklearn.ensemble import RandomForestClassifier | ||
|
||
from src.classifier.lstm.lstm_classifier import RegardLSTM | ||
from src.classifier.sent_transformer.sbert_classifier import RegardBERT | ||
|
||
|
||
def get_classifier(model_params, model_type, n_embed, weight_vector=None, classes=None):
    """Instantiate an untrained classifier of the requested type.

    Args:
        model_params: Config node with the hyperparameters of the chosen
            model type (e.g. ``n_estimators``/``max_depth`` for "rf").
        model_type: One of "rf", "xgb", "lstm" or "transformer".
        n_embed: Forwarded as ``n_embed`` to the torch models; ignored for
            "rf" and "xgb".
        weight_vector: Optional class weights, forwarded to the torch models.
        classes: Unused; kept so existing callers keep working.

    Returns:
        The constructed classifier instance.

    Raises:
        ValueError: If ``model_type`` is not one of the supported types.
            (Previously this case printed a hint and then crashed with an
            UnboundLocalError on the return statement.)
    """
    if model_type == "rf":
        return RandomForestClassifier(
            n_estimators=model_params.n_estimators,
            max_depth=model_params.max_depth,
            random_state=42,
        )
    if model_type == "xgb":
        return xgb.XGBClassifier(
            n_estimators=model_params.n_estimators,
            learning_rate=model_params.learning_rate,
            max_depth=model_params.max_depth,
            random_state=42,
        )
    if model_type == "lstm":
        return RegardLSTM(
            n_embed=n_embed,
            n_hidden=model_params.n_hidden,
            n_hidden_lin=model_params.n_hidden_lin,
            n_output=model_params.n_output,
            n_layers=model_params.n_layers,
            lr=model_params.lr,
            weight_vector=weight_vector,
            bidirectional=model_params.bidir,
            gru=model_params.unit,
            drop_p=model_params.dropout,
            drop_p_gru=model_params.dropout_gru,
        )
    if model_type == "transformer":
        return RegardBERT(
            n_embed=n_embed,
            n_hidden_lin=model_params.n_hidden_lin,
            n_hidden_lin_2=model_params.n_hidden_lin_2,
            n_output=model_params.n_output,
            lr=model_params.lr,
            weight_vector=weight_vector,
            drop_p=model_params.dropout,
        )
    raise ValueError(
        f"Unknown classifier type {model_type!r}. "
        "Please choose a classifier type that is implemented: "
        "rf for RandomForest, xgb for XGBoost, lstm, or transformer."
    )
|
||
|
||
def compute_weight_vector(Y, use_torch=True):
    """Compute per-class weights from the label frequencies in ``Y``.

    Each weight is ``len(Y) / (number of distinct labels * count of the
    label)``, with the counts coming from ``np.bincount(Y)``.

    Args:
        Y: Labels, as accepted by ``np.bincount``.
        use_torch: If true, return a ``torch.FloatTensor`` placed on "cuda"
            when available and on "cpu" otherwise; if false, return the
            NumPy array.

    Returns:
        The weight vector, indexed by label value.
    """
    n_samples = len(Y)
    n_classes = len(set(Y))
    class_counts = np.bincount(Y)
    weights = n_samples / (n_classes * class_counts)
    if not use_torch:
        return weights
    target = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.FloatTensor(weights).to(target)
89 changes: 89 additions & 0 deletions
89
src/classifier/classifier_training/incremental_training.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import os | ||
|
||
import hydra.utils | ||
import pandas as pd | ||
|
||
from src.classifier.classifier_training.training import train_classifier | ||
|
||
|
||
def _get_data_increments(data_len, percentage): | ||
# creates list of upperbounds to index different amounts of data | ||
increments = [] | ||
for p in range(percentage, 100 + percentage, percentage): | ||
upper_bound = int(data_len * (p / 100)) | ||
increments.append(upper_bound) | ||
return increments | ||
|
||
|
||
def train_on_increments(
    cfg, X_dev, Y_dev, X_test, Y_test, texts_test, logger, num_runs=5
):
    """Train on growing portions of the dev data and store the scores.

    With ``cfg.classifier_mode.mult_seeds`` set, the incremental training is
    repeated ``num_runs`` times, each run with its own seed. Otherwise it is
    run once and the collected results are printed. In both cases the results
    are written to ``incr_train_data.csv`` inside ``cfg.run_mode.plot_path``.

    Args:
        cfg: Config object; reads ``run_mode.plot_path`` and
            ``classifier_mode.mult_seeds`` (further fields are read by the
            per-run helper).
        X_dev: Development features.
        Y_dev: Development labels.
        X_test: Test features.
        Y_test: Test labels.
        texts_test: Test texts.
        logger: Logger forwarded to the per-run helper.
        num_runs: Number of seeded repetitions when ``mult_seeds`` is set.

    Returns:
        None. Results are written to disk.
    """
    data = pd.DataFrame(columns=["id", "step", "metric"])
    plot_path = hydra.utils.to_absolute_path(cfg.run_mode.plot_path)
    os.makedirs(plot_path, exist_ok=True)
    counter = 0
    if cfg.classifier_mode.mult_seeds:
        for run in range(num_runs):
            # Seeds follow 21, 42, 84, 168, 336, ...; deriving them from the
            # run index keeps the first five unchanged and avoids an
            # IndexError when num_runs exceeds five.
            seed = 21 * 2 ** run
            print(f"---- RUN {run} ----")
            data, counter = _run(
                cfg,
                X_dev,
                X_test,
                Y_dev,
                Y_test,
                texts_test,
                counter,
                data,
                logger,
                run,
                seed,
            )
    else:
        data, _ = _run(
            cfg, X_dev, X_test, Y_dev, Y_test, texts_test, counter, data, logger
        )
        print(data)

    data.to_csv(os.path.join(plot_path, "incr_train_data.csv"))
    print(
        "Stored results at",
        plot_path,
        "\nCreate a plot in " "incremental_training_exploration.ipynb",
    )
|
||
|
||
def _run(
    cfg,
    X_dev,
    X_test,
    Y_dev,
    Y_test,
    texts_test,
    counter,
    data,
    logger,
    run=None,
    seed=42,
):
    """Train once per data increment and append the scores to ``data``.

    For each upper bound from ``_get_data_increments`` a classifier is trained
    on the first ``upper_bound`` dev samples. With
    ``cfg.classifier_mode.cv_folds`` set, the returned scores are iterated and
    one row per fold is written (``id`` is the fold index); otherwise a single
    row is written with ``id`` set to ``run``.

    Args:
        cfg: Config object; reads ``classifier_mode.percentage`` and
            ``classifier_mode.cv_folds``.
        X_dev, Y_dev: Development features and labels, sliced per increment.
        X_test, Y_test, texts_test: Test data forwarded to training.
        counter: Next free row index in ``data``.
        data: DataFrame with columns "id", "step" and "metric"; modified in
            place.
        logger: Logger forwarded to training.
        run: Identifier of the current run, used as row id without cv folds.
        seed: Seed forwarded to training.

    Returns:
        Tuple of the DataFrame and the updated row counter.
    """
    print("Full size", len(X_dev))
    step_pct = cfg.classifier_mode.percentage
    for n_points in _get_data_increments(len(X_dev), step_pct):
        print(f"{n_points} data points")
        X_part, Y_part = X_dev[:n_points], Y_dev[:n_points]
        result = train_classifier(
            cfg, X_part, Y_part, X_test, Y_test, texts_test, logger, seed
        )
        # Normalise both cases to (row id, metric) pairs.
        if cfg.classifier_mode.cv_folds:
            rows = list(enumerate(result))
        else:
            rows = [(run, result)]
        for row_id, metric in rows:
            data.loc[counter, "id"] = row_id
            data.loc[counter, "step"] = n_points
            data.loc[counter, "metric"] = metric
            counter += 1
    return data, counter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import numpy as np | ||
|
||
from src.classifier.torch_helpers.torch_training import train_torch_model | ||
from src.classifier.non_torch.non_torch_training import train_sklearn | ||
|
||
|
||
def train_classifier(
    cfg, X_dev_emb, Y_dev, X_test_emb, Y_test, texts_test, logger, seed=42
):
    """Train the configured classifier and return its score.

    Classifier names starting with "lstm" or "transformer" are trained with
    the torch routine; every other name goes to the sklearn routine.

    Args:
        cfg: Config object; reads ``classifier.name`` and is forwarded to the
            training routine.
        X_dev_emb: Development features.
        Y_dev: Development labels; the set of distinct labels is passed to
            the torch routine.
        X_test_emb: Test features.
        Y_test: Test labels.
        texts_test: Test texts.
        logger: Logger; only forwarded to the sklearn routine.
        seed: Seed; only forwarded to the torch routine.

    Returns:
        Whatever the selected training routine returns.
    """
    label_set = set(Y_dev)
    uses_torch = cfg.classifier.name.startswith(("lstm", "transformer"))

    if uses_torch:
        return train_torch_model(
            cfg, X_dev_emb, X_test_emb, Y_dev, Y_test, label_set, texts_test, seed
        )

    return train_sklearn(
        cfg, X_dev_emb, X_test_emb, Y_dev, Y_test, logger, texts_test
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from sklearn.ensemble import RandomForestClassifier | ||
import xgboost | ||
|
||
|
||
def suggest_xgb(model_params, trial, xgb=None):
    """Build an XGBoost classifier from hyperparameters suggested by a trial.

    Args:
        model_params: Search-space config; ``n_estimators`` and ``max_depth``
            each provide ``name``, ``lower``, ``upper`` and ``step``, and
            ``learning_rate`` provides ``name``, ``lower`` and ``upper``.
        trial: Trial object exposing ``suggest_int`` and ``suggest_float``.
        xgb: Unused; kept so existing callers keep working.

    Returns:
        An ``xgboost.XGBClassifier`` configured with the suggested values.
    """
    # `step` is passed by keyword: newer Optuna releases no longer accept it
    # positionally, while the keyword form works on old and new versions.
    n_estimators = trial.suggest_int(
        model_params.n_estimators.name,
        model_params.n_estimators.lower,
        model_params.n_estimators.upper,
        step=model_params.n_estimators.step,
    )
    # The learning rate is sampled on a log scale.
    lr = trial.suggest_float(
        model_params.learning_rate.name,
        model_params.learning_rate.lower,
        model_params.learning_rate.upper,
        log=True,
    )
    max_depth = trial.suggest_int(
        model_params.max_depth.name,
        model_params.max_depth.lower,
        model_params.max_depth.upper,
        step=model_params.max_depth.step,
    )

    # NOTE(review): training is pinned to GPU 0 via tree_method="gpu_hist";
    # confirm this is intended for every environment the search runs in.
    classifier = xgboost.XGBClassifier(
        n_estimators=n_estimators,
        learning_rate=lr,
        max_depth=max_depth,
        random_state=42,
        use_label_encoder=False,
        tree_method="gpu_hist",
        gpu_id=0,
    )
    return classifier
|
||
|
||
def suggest_rf(model_params, trial):
    """Build a random forest from hyperparameters suggested by a trial.

    Args:
        model_params: Search-space config; ``n_estimators`` and ``max_depth``
            each provide ``name``, ``lower``, ``upper`` and ``step``.
        trial: Trial object exposing ``suggest_int``.

    Returns:
        A ``RandomForestClassifier`` with the suggested ``n_estimators`` and
        ``max_depth`` and a fixed ``random_state`` of 42.
    """
    # `step` is passed by keyword: newer Optuna releases no longer accept it
    # positionally, while the keyword form works on old and new versions.
    n_estimators = trial.suggest_int(
        model_params.n_estimators.name,
        model_params.n_estimators.lower,
        model_params.n_estimators.upper,
        step=model_params.n_estimators.step,
    )
    max_depth = trial.suggest_int(
        model_params.max_depth.name,
        model_params.max_depth.lower,
        model_params.max_depth.upper,
        step=model_params.max_depth.step,
    )

    classifier = RandomForestClassifier(
        n_estimators=n_estimators, max_depth=max_depth, random_state=42
    )
    return classifier
Oops, something went wrong.