forked from kduxin/firelang
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Xeadriel
committed
Mar 13, 2024
1 parent
e3cdeca
commit 0476961
Showing
8 changed files
with
9,973 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Binary file not shown.
Binary file added
BIN
+2.24 KB
scripts/CorpusPreProcessor/__pycache__/corpusPreprocessor.cpython-311.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import csv | ||
from nltk import word_tokenize | ||
import string | ||
|
||
def prepareMSRData(path):
    """
    Read the MSR paraphrase corpus and normalize each sentence pair.

    Makes the sentences lower case, strips leading and trailing whitespace,
    removes punctuation, and re-joins NLTK word tokens with single spaces.

    parameters:
        String : path: path of the MSR data file (tab-separated, with a header row)
    returns:
        List : sentencePairs: list of tuples that contain sentence pairs
        List : labels: integer values 1 or 0 that indicate whether the sentences are similar
    """
    labels = []
    sentencePairs = []

    # Build the punctuation-stripping table once instead of twice per row.
    punctTable = str.maketrans('', '', string.punctuation)

    with open(path, newline='', encoding="utf8") as csvfile:
        # There is a line that's broken by quotation marks; 'ß' never occurs
        # in this English corpus, so using it as quotechar disables quoting.
        reader = csv.reader(csvfile, delimiter='\t', quotechar='ß')
        next(reader)  # skip header

        for row in reader:
            # Column 0 holds the quality label; cast to int so the returned
            # labels match the documented contract (integers 1 or 0), not strings.
            labels.append(int(row[0]))

            sentence1 = row[3].strip().lower()
            sentence2 = row[4].strip().lower()

            sentence1 = sentence1.translate(punctTable)
            sentence2 = sentence2.translate(punctTable)

            sentence1 = " ".join(word_tokenize(sentence1))
            sentence2 = " ".join(word_tokenize(sentence2))

            sentencePairs.append((sentence1, sentence2))

    return sentencePairs, labels
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from typing import List, Mapping | ||
import argparse | ||
import os | ||
from collections.abc import Callable | ||
from collections import defaultdict, Counter | ||
from corpusit import Vocab | ||
import nltk | ||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
from scipy.stats import spearmanr | ||
import multiprocessing | ||
import tqdm | ||
from sklearn.decomposition import PCA | ||
from sklearn.cluster import DBSCAN | ||
from sklearn.metrics import accuracy_score | ||
|
||
from firelang.models import FireWord, FireTensor | ||
from firelang.utils.log import logger | ||
from firelang.utils.timer import Timer, elapsed | ||
from scripts.sentsim import sentsim_as_weighted_wordsim_cuda | ||
|
||
|
||
def main():
    """
    Load each FireWord checkpoint given on the command line onto the GPU
    and print the resulting model.

    Command-line arguments:
        --checkpointsMRS: one or more checkpoint directories to load
            (defaults to three wacky_mlplanardiv checkpoints).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpointsMRS",
        nargs="+",
        default=[
            "checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
            "checkpoints/wacky_mlplanardiv_d2_l4_k10",
            "checkpoints/wacky_mlplanardiv_d2_l8_k20",
        ],
    )

    args = parser.parse_args()

    for checkpoint in args.checkpointsMRS:
        # Bug fix: `firelang` itself is never imported as a module name, and
        # FireWord lives in firelang.models (imported at the top of the file),
        # so `firelang.modules.FireWord` raised NameError. Use the imported name.
        model = FireWord.from_pretrained(checkpoint).to("cuda")
        print(model)

    # 'benchmarks/MSR/msr_paraphrase_train.csv'
# Run the evaluation entry point only when executed as a script (not on import).
if __name__ == "__main__": | ||
main() |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.