Skip to content

Commit

Permalink
added new benchmark file
Browse files Browse the repository at this point in the history
  • Loading branch information
Xeadriel committed Mar 13, 2024
1 parent e3cdeca commit 0476961
Show file tree
Hide file tree
Showing 8 changed files with 9,973 additions and 0 deletions.
Empty file.
Binary file not shown.
Binary file not shown.
40 changes: 40 additions & 0 deletions scripts/CorpusPreProcessor/corpusPreprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import csv
from nltk import word_tokenize
import string

def prepareMSRData(path):
"""
Makes the sentences lower case, strips leading and trailing whitespace and removes punctuation.
Returns the sentence pairs with their labels.
parameters:
String : path: path of MSR data
returns:
List : sentencePairs: list of tuples that contain sentence pairs
List : labels: integer values 1 or 0 that indicate whether the sentences are similar
"""
labels = []
sentencePairs = []

with open(path, newline='', encoding="utf8") as csvfile:
# there is a line that's broken by quotation marks. ß is never used in English so there is that
reader = csv.reader(csvfile, delimiter='\t', quotechar='ß')
next(reader) # skip header

for row in reader:
labels.append(row[0])

sentence1 = row[3].strip().lower()
sentence2 = row[4].strip().lower()

sentence1 = sentence1.translate(str.maketrans('', '', string.punctuation))
sentence2 = sentence2.translate(str.maketrans('', '', string.punctuation))

sentence1 = " ".join(word_tokenize(sentence1))
sentence2 = " ".join(word_tokenize(sentence2))

sentencePairs.append((sentence1, sentence2))

return sentencePairs, labels

53 changes: 53 additions & 0 deletions scripts/addtionalBenchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import List, Mapping
import argparse
import os
from collections.abc import Callable
from collections import defaultdict, Counter
from corpusit import Vocab
import nltk
import numpy as np
import pandas as pd
import torch
from scipy.stats import spearmanr
import multiprocessing
import tqdm
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.metrics import accuracy_score

from firelang.models import FireWord, FireTensor
from firelang.utils.log import logger
from firelang.utils.timer import Timer, elapsed
from scripts.sentsim import sentsim_as_weighted_wordsim_cuda


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpointsMRS",
nargs="+",
default=[
"checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/wacky_mlplanardiv_d2_l4_k10",
"checkpoints/wacky_mlplanardiv_d2_l8_k20",
],
)

args = parser.parse_args()

for checkpoint in args.checkpointsMRS:
model = firelang.modules.FireWord.from_pretrained(checkpoint).to("cuda")
print(model)

# 'benchmarks/MSR/msr_paraphrase_train.csv'









if __name__ == "__main__":
main()
1,726 changes: 1,726 additions & 0 deletions scripts/benchmarks/MSR/msr_paraphrase_test.txt

Large diffs are not rendered by default.

4,077 changes: 4,077 additions & 0 deletions scripts/benchmarks/MSR/msr_paraphrase_train.csv

Large diffs are not rendered by default.

4,077 changes: 4,077 additions & 0 deletions scripts/benchmarks/MSR/msr_paraphrase_train.txt

Large diffs are not rendered by default.

0 comments on commit 0476961

Please sign in to comment.