Commit
Showing 8 changed files with 573 additions and 0 deletions.
120 changes: 120 additions & 0 deletions
src/classifier/data_processing/annotate/annotate_sentences.py
import os

import numpy as np
import pandas as pd
import hydra

from src.classifier.data_processing.annotate.metrics import (
    fleiss_kappa,
    get_all_pairwise_kappas,
)

def create_combined_df(data_dir):
    """Merge per-annotator CSVs into one DataFrame with a label column per annotator."""
    data_dir = hydra.utils.to_absolute_path(data_dir)
    annotations = pd.DataFrame()
    annotator_names = []
    for i, annotation in enumerate(os.listdir(data_dir)):
        # file names are expected to end in the annotator name, e.g. "sentences_anna.csv"
        annotator = annotation.split("_")[-1].split(".")[0]
        annotator_names += [annotator]
        data = pd.read_csv(os.path.join(data_dir, annotation), index_col=0)
        # 98 is the sentinel for missing/unsure labels throughout this module
        annotations[annotator] = data["Label"].fillna(98)
        if "Unsicher" in data.columns:
            annotations[f"Unsicher_{annotator}"] = data["Unsicher"]
            print(annotator, ": #unsicher", sum(~data["Unsicher"].isna()))
            # mark sentences this annotator flagged as unsure with the sentinel
            annotations.loc[
                ~annotations[f"Unsicher_{annotator}"].isna(), annotator
            ] = 98
        annotations[annotator] = annotations[annotator].astype("int32")
        if i == 0:
            annotations["Text"] = data["Text"]
            annotations["Gender"] = data["Gender"]
    return annotations, annotator_names

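# A minimal sketch of the expected input, assuming per-annotator CSVs named
# like "sentences_anna.csv" (hypothetical name and rows; the columns match
# those read above):
#
#     ,Text,Gender,Label,Unsicher
#     0,"Die Person ist hilfsbereit.",N,1.0,
#     1,"Die Person wirkt kompetent.",N,,x
#
# create_combined_df then returns one int label column per annotator plus the
# shared "Text" and "Gender" columns, with 98 marking missing/unsure labels.
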
def clean_uncertain_labels(remove_uncertain, annotations, annotator_names):
    # "all": drop a sentence if any annotator was unsure; otherwise require
    # at least two unsure annotators before dropping
    if remove_uncertain == "all":
        min_uncertain = 1
    else:
        min_uncertain = 2

    rm_cases = annotations.loc[
        np.sum(annotations[annotator_names] == 98, axis=1) >= min_uncertain,
        annotator_names,
    ].index
    annotations_cleaned = annotations.drop(rm_cases)

    annotations_cleaned = annotations_cleaned.replace(98, np.nan)
    print(f"Dropping {len(rm_cases)} cases.")
    return annotations_cleaned

def label_with_aggregate_annotation(
    annotation,
    label_col,
    annotations,
    annotator_names,
    force_majority=False,
):
    if annotation == "majority" or force_majority:
        return_df = _get_majority_label(
            annotations,
            annotator_names,
            label_col,
            for_stratification_only=force_majority,
        )
    else:
        # keep only sentences on which all annotators agree
        not_all_equal_idcs = []
        for i, row in annotations[annotator_names].iterrows():
            if not _all_equal(row):
                not_all_equal_idcs += [i]
        all_equal_indcs = list(
            set(annotations.index.values.tolist()) - set(not_all_equal_idcs)
        )
        return_df = _get_majority_label(
            annotations.loc[all_equal_indcs, :],
            annotator_names,
            label_col,
            for_stratification_only=force_majority,
        )

        print(
            f"Removed {len(not_all_equal_idcs)} with varying votes. "
            f"{len(all_equal_indcs)} unanimously labeled sentences remain."
        )

    # Check inter-rater reliability
    fleiss_kappa(return_df, annotator_names)
    get_all_pairwise_kappas(return_df, annotator_names)

    return return_df

def _all_equal(iterator):
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == x for x in iterator)

def _get_majority_label(
    annotations,
    annotator_names,
    label_col,
    for_stratification_only,
):
    annotations[label_col] = annotations[annotator_names].mode(axis="columns")[0]
    if for_stratification_only and (annotations[label_col] == 98).any():
        from random import choice

        # replace the "unsicher" sentinel with a randomly chosen valid label
        options = annotations.loc[annotations[label_col] != 98, label_col]
        annotations.loc[annotations[label_col] == 98, label_col] = choice(
            options.tolist()
        )

    return annotations
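A minimal end-to-end sketch of how these helpers compose, assuming a directory of per-annotator CSVs as above (the path and option values are hypothetical):

from src.classifier.data_processing.annotate.annotate_sentences import (
    create_combined_df,
    clean_uncertain_labels,
    label_with_aggregate_annotation,
)

# build the combined annotation table, drop unsure cases, aggregate labels
annotations, annotator_names = create_combined_df("data/annotations")
annotations = clean_uncertain_labels("all", annotations, annotator_names)
labeled = label_with_aggregate_annotation(
    "majority", "Label", annotations, annotator_names
)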
28 changes: 28 additions & 0 deletions
src/classifier/data_processing/annotate/metrics.py
import pandas as pd
from nltk import agreement
from sklearn.metrics import cohen_kappa_score


def fleiss_kappa(data, annotator_names):
    # nltk's AnnotationTask expects (coder, item, label) triples
    formatted_codes = []

    for j, annotator in enumerate(annotator_names):
        formatted_codes += [[j, i, val] for i, val in enumerate(data[annotator])]

    ratingtask = agreement.AnnotationTask(data=formatted_codes)

    print("Fleiss' Kappa:", ratingtask.multi_kappa())


def get_all_pairwise_kappas(data, annotator_names, anonymize=True):
    a_names_cl = annotator_names
    if anonymize:
        annotator_names = [f"Annotator_{i}" for i, _ in enumerate(annotator_names)]
    results = pd.DataFrame()
    for i, a in enumerate(annotator_names):
        for j, b in enumerate(annotator_names):
            if j > i:
                results.loc[a, b] = cohen_kappa_score(
                    data[a_names_cl[i]], data[a_names_cl[j]]
                )
    print("Pairwise Cohen Kappa\n", results)
82 changes: 82 additions & 0 deletions
src/classifier/data_processing/data_augmentation/gendered_prompts.py
import random

from src import constants


def replace_with_gendered_pronouns(augment, text_col, df):
    # every row must be tagged female ("F"), male ("M"), or neutral ("N")
    assert len(set(df.Gender.unique()) - set(["F", "M", "N"])) == 0
    if augment == "single_gender":
        df = replace_with_single_option(text_col, df)
    elif augment == "list_gender":
        df = replace_from_list(text_col, df)
    else:
        raise SystemExit("Asking for non-specified augmentation option")

    return df

def replace_from_list(text_col, df):
    # For all sentences with female indication, replace the neutral subject
    # "Die Person" with a randomly drawn female pronoun/subject
    df.loc[df["Gender"] == "F", text_col] = df.loc[df["Gender"] == "F", text_col].apply(
        lambda text: text.replace(
            "Die Person",
            random.choice(constants.FEMALE_LIST),
        )
    )

    print(df.loc[df["Gender"] == "F", text_col][:5])

    # For all sentences with male indication, replace it with a randomly
    # drawn male pronoun/subject
    df.loc[df["Gender"] == "M", text_col] = df.loc[df["Gender"] == "M", text_col].apply(
        lambda text: text.replace(
            "Die Person",
            random.choice(constants.MALE_LIST),
        )
    )

    print(df.loc[df["Gender"] == "M", text_col][:5])

    # For all sentences without any gender indication, gender randomly
    df.loc[df["Gender"] == "N", text_col] = df.loc[df["Gender"] == "N", text_col].apply(
        lambda text: text.replace(
            "Die Person",
            random.choice(
                [
                    random.choice(constants.FEMALE_LIST),
                    random.choice(constants.MALE_LIST),
                ]
            ),
        )
    )

    print(df.loc[df["Gender"] == "N", text_col][:20])

    return df

def replace_with_single_option(text_col, df):
    # For all sentences with female indication, use the single fixed female subject
    df.loc[df["Gender"] == "F", text_col] = df.loc[
        df["Gender"] == "F", text_col
    ].str.replace("Die Person", constants.FEMALE_SINGLE)

    print(df.loc[df["Gender"] == "F", text_col][:5])

    # For all sentences with male indication, use the single fixed male subject
    df.loc[df["Gender"] == "M", text_col] = df.loc[
        df["Gender"] == "M", text_col
    ].str.replace("Die Person", constants.MALE_SINGLE)

    print(df.loc[df["Gender"] == "M", text_col][:5])

    # For all sentences without any gender indication, gender randomly
    df.loc[df["Gender"] == "N", text_col] = df.loc[df["Gender"] == "N", text_col].apply(
        lambda text: text.replace(
            "Die Person",
            random.choice([constants.FEMALE_SINGLE, constants.MALE_SINGLE]),
        )
    )

    print(df.loc[df["Gender"] == "N", text_col][:20])

    return df
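A minimal sketch of the augmentation call, assuming src.constants provides the gendered substitution lists (the sample row and list contents are hypothetical):

import pandas as pd
from src.classifier.data_processing.data_augmentation.gendered_prompts import (
    replace_with_gendered_pronouns,
)

# e.g. constants.FEMALE_LIST = ["Die Frau", "Meine Kollegin"], MALE_LIST analogous
df = pd.DataFrame({"Text": ["Die Person ist hilfsbereit."], "Gender": ["F"]})
df = replace_with_gendered_pronouns("list_gender", "Text", df)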
79 changes: 79 additions & 0 deletions
src/classifier/data_processing/splitting/ (exact file name not shown)
import os
import pickle

from sklearn.model_selection import train_test_split, StratifiedKFold
from src.classifier.data_processing.annotate.annotate_sentences import (
    label_with_aggregate_annotation,
)

def create_or_load_indcs(dcfg, label_col, df, annotator_names):
    # majority labels are only needed here to stratify the split
    if dcfg.balance_on_majority or dcfg.k_fold:
        labels_for_strat = label_with_aggregate_annotation(
            dcfg.annotation,
            label_col,
            df,
            annotator_names,
            force_majority=True,
        )[label_col]
    else:
        labels_for_strat = None
    if dcfg.k_fold:
        # k-fold mode: dev_set/test_set are lists with one DataFrame per fold
        dest = os.path.join(dcfg.paths.dev_test_indcs, f"num_folds_{dcfg.k_fold}")
        dev_set, test_set = [], []
        if not os.path.isdir(dest) or not os.listdir(dest):
            os.makedirs(dest, exist_ok=True)
            skf = StratifiedKFold(dcfg.k_fold, random_state=42, shuffle=True)
            for fold, (dev_indices, test_indices) in enumerate(
                skf.split(df.index, labels_for_strat)
            ):
                fold_dest = os.path.join(dest, f"fold_{fold}")
                os.makedirs(fold_dest)
                dump_dev_test(fold_dest, dev_indices, test_indices)
                dev_set.append(df.iloc[dev_indices])
                test_set.append(df.iloc[test_indices])
        else:
            for fold in range(dcfg.k_fold):
                fold_dest = os.path.join(dest, f"fold_{fold}")
                dev_indices, test_indices = load_dev_test(fold_dest)
                dev_set.append(df.iloc[dev_indices])
                test_set.append(df.iloc[test_indices])
    else:
        # single-split mode: dev_set/test_set are plain DataFrames
        # (note: iloc below assumes df has a positional RangeIndex)
        test_size = dcfg.test_split
        dest = os.path.join(dcfg.paths.dev_test_indcs, f"test_size_{test_size}")
        if not os.path.isdir(dest) or not os.listdir(dest):
            os.makedirs(dest, exist_ok=True)
            dev_indices, test_indices = train_test_split(
                df.index,
                test_size=test_size,
                shuffle=True,
                stratify=labels_for_strat,
                random_state=42,
            )
            dump_dev_test(dest, dev_indices, test_indices)
        else:
            dev_indices, test_indices = load_dev_test(dest)

        dev_set, test_set = df.iloc[dev_indices], df.iloc[test_indices]
    return dev_set, test_set

def load_dev_test(dest):
    with open(os.path.join(dest, "dev_indices.pkl"), "rb") as d:
        dev_indices = pickle.load(d)
    with open(os.path.join(dest, "test_indices.pkl"), "rb") as t:
        test_indices = pickle.load(t)
    return dev_indices, test_indices


def dump_dev_test(dest, dev_indices, test_indices):
    with open(os.path.join(dest, "dev_indices.pkl"), "wb") as d:
        pickle.dump(dev_indices, d)
    with open(os.path.join(dest, "test_indices.pkl"), "wb") as t:
        pickle.dump(test_indices, t)

def get_data_splits(dcfg, label_col, df, annotator_names):
    dev_set, test_set = create_or_load_indcs(dcfg, label_col, df, annotator_names)

    return dev_set, test_set
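A hedged sketch of calling the splitter, reusing df and annotator_names from the annotation step; the config object mirrors the attributes accessed above, but the concrete values are hypothetical:

from types import SimpleNamespace

dcfg = SimpleNamespace(
    annotation="majority",
    balance_on_majority=True,
    k_fold=0,  # falsy, so a single stratified train/test split is made
    test_split=0.2,
    paths=SimpleNamespace(dev_test_indcs="cache/splits"),
)
dev_set, test_set = get_data_splits(dcfg, "Label", df, annotator_names)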
23 changes: 23 additions & 0 deletions
src/classifier/data_processing/splitting/load_datasplits.py
import os
import pickle


def load_cached(cfg, file_name, logger):
    """Return cached embedded splits if present, else None."""
    cache_file = os.path.join(cfg.data_cache, file_name)
    if os.path.exists(cache_file):
        logger.info(f"Loading {file_name}")
        with open(cache_file, "rb") as f:
            embedded_splits = pickle.load(f)
        X_train_emb = embedded_splits["X_train_emb"]
        X_test_emb = embedded_splits["X_test_emb"]
        Y_train = embedded_splits["Y_train"]
        Y_test = embedded_splits["Y_test"]

        return X_train_emb, X_test_emb, Y_train, Y_test

    return None


def load_dev_test(cfg):
    # placeholder, not yet implemented
    pass
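For reference, a minimal sketch of writing a cache file in the dict format load_cached expects (file name and arrays are hypothetical stand-ins):

import pickle
import numpy as np

embedded_splits = {
    "X_train_emb": np.zeros((10, 300)),  # embedded training sentences
    "X_test_emb": np.zeros((3, 300)),
    "Y_train": np.zeros(10),
    "Y_test": np.zeros(3),
}
with open("cache/splits_w2v.pkl", "wb") as f:
    pickle.dump(embedded_splits, f)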
51 changes: 51 additions & 0 deletions
src/classifier/data_processing/text_embedding/embedding.py
import os

import hydra.utils
from sentence_transformers import SentenceTransformer
from gensim.models.keyedvectors import KeyedVectors
from gensim.models.fasttext import (
    load_facebook_vectors,
    load_facebook_model,
)


def get_embedding(cfg, X=None):
    # local embedding files need an absolute path; transformer names are
    # resolved by sentence-transformers itself
    if cfg.embedding.name != "transformer":
        emb_path = hydra.utils.to_absolute_path(cfg.embedding.path)
    else:
        emb_path = cfg.embedding.path

    if cfg.embedding.name == "w2v":
        embedding = KeyedVectors.load_word2vec_format(
            emb_path, binary=False, no_header=cfg.embedding.no_header
        )

    elif cfg.embedding.name == "fastt":
        if cfg.run_mode.name == "data" and cfg.pre_processing.tune:
            dest_path = os.path.join(
                cfg.embedding.tuned_path, f"{cfg.pre_processing.epochs}_epochs"
            )
            os.makedirs(dest_path, exist_ok=True)
            dest_file = os.path.join(dest_path, "model.bin")

            # if not os.path.isfile(dest_file):
            print(
                f"Tuning {cfg.embedding.name} for {cfg.pre_processing.epochs} epochs."
            )
            embedding = load_facebook_model(emb_path)
            embedding.build_vocab(
                X, update=True
            )  # adds previously unseen words to vocab
            embedding.train(X, total_examples=len(X), epochs=cfg.pre_processing.epochs)
            # embedding.save(dest_file)
            # print(f"Saved finetuned {cfg.embedding.name} as {dest_file}.")

        else:
            embedding = load_facebook_vectors(emb_path)
    elif cfg.embedding.name == "transformer":
        embedding = SentenceTransformer(emb_path)
    else:
        raise SystemExit(f"{cfg.embedding.name} not implemented.")

    return embedding
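A hedged usage sketch; the config layout mirrors the fields read above (model name and values are hypothetical examples):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "embedding": {
            "name": "transformer",
            "path": "distiluse-base-multilingual-cased-v1",
        },
        "run_mode": {"name": "train"},
        "pre_processing": {"tune": False, "epochs": 5},
    }
)
model = get_embedding(cfg)
vectors = model.encode(["Die Person ist hilfsbereit."])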