
Commit

gitignore, checkpoints, corpuspreprocessor moved, first benchmark attempt, prints
Xeadriel committed Mar 28, 2024
1 parent 91fd295 commit d18b1d4
Showing 6 changed files with 280 additions and 9 deletions.
160 changes: 160 additions & 0 deletions .gitignore
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
Empty file.
Binary file not shown.
Binary file not shown.
120 changes: 115 additions & 5 deletions scripts/additionalBenchmark.py
@@ -19,8 +19,10 @@
from firelang.utils.log import logger
from firelang.utils.timer import Timer, elapsed
from scripts.sentsim import sentsim_as_weighted_wordsim_cuda
from scripts.corpusPreprocessor import *
from scripts.benchmark import sentence_simmat


@torch.no_grad()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -35,15 +37,123 @@ def main():

args = parser.parse_args()

for checkpoint in args.checkpointsMRS:
model = FireWord.from_pretrained(checkpoint).to("cuda")
print(model)
sifA = 0.001
device = "cpu"

# for checkpoint in args.checkpointsMRS:
checkpoint = args.checkpointsMRS[0]
model = FireWord.from_pretrained(checkpoint).to(device)

vocab: Vocab = model.vocab

pairs, labels = prepareMSRData('scripts/benchmarks/MSR/msr_paraphrase_train.csv')

assert len(pairs[0]) == len(pairs[1]) == len(labels)

# print(f"pairs: {len(pairs[0])}")
score, preds = benchmark_sentence_similarity(model, pairs, labels, sifA)

print(f"score: {score}")
print(f"(preds, labels):\n{np.array([preds, labels])}")


@torch.no_grad()
@Timer(elapsed, "sentsim")
def benchmark_sentence_similarity(
model: FireWord,
pairs,
labels,
sif_alpha=1e-3,
):
vocab: Vocab = model.vocab

counts = pd.Series(vocab.counts_dict())
probs = counts / counts.sum()
sif_weights: Mapping[str, float] = {
w: sif_alpha / (sif_alpha + prob) for w, prob in probs.items()
}

scores = 0

sents1 = pairs[0]
sents2 = pairs[1]
allsents = sents1 + sents2
allsents = [
[w for w in sent if w in sif_weights and w != vocab.unk]
for sent in allsents
]

""" similarity """
with Timer(elapsed, "similarity", sync_cuda=True):
simmat = sentence_simmat(model, allsents, sif_weights)

print("simmat")
print(simmat)
print(f"max: {max(max(x) for x in simmat)}")
print(f"min: {min(min(x) for x in simmat)}")

""" regularization: sim(i,j) <- sim(i,j) - 0.5 * (sim(i,i) + sim(j,j)) """
with Timer(elapsed, "regularization"):
diag = np.diag(simmat)
simmat = simmat - 0.5 * (diag.reshape(-1, 1) + diag.reshape(1, -1))
print("diag")
print(diag)
print("diag.reshape(-1, 1)")
print(diag.reshape(-1, 1))
print("diag.reshape(1, -1)")
print(diag.reshape(1, -1))
print("(diag.reshape(-1, 1) + diag.reshape(1, -1))*0.5")
print((diag.reshape(-1, 1) + diag.reshape(1, -1))/2)
print("simmat")
print(simmat)
print(f"max: {max(max(x) for x in simmat)}")
print(f"min: {min(min(x) for x in simmat)}")

""" rescaling (smoothing) and exponential """

def _simmat_rescale(simmat) -> np.ndarray:
scale = np.abs(simmat).mean(axis=1, keepdims=True)
simmat = simmat / (scale * scale.T) ** 0.5

print("scale")
print(scale)
print("(scale * scale.T)")
print((scale * scale.T))
print("(scale * scale.T)** 0.5")
print((scale * scale.T)** 0.5)
print("simmat")
print(simmat)
print(f"max: {max(max(x) for x in simmat)}")
print(f"min: {min(min(x) for x in simmat)}")

return simmat

with Timer(elapsed, "smooth"):
simmat = _simmat_rescale(simmat)
simmat = np.exp(simmat)
print("exp simmat")
print(simmat)
print(f"max: {max(max(x) for x in simmat)}")
print(f"min: {min(min(x) for x in simmat)}")

N = len(pairs[0])
preds = [simmat[i, i + N] for i in range(N)]
# print(simmat.shape)
# print(simmat)
# print(f"max: {max(max(x) for x in simmat)}")
# print(f"min: {min(min(x) for x in simmat)}")
# print(sum([x >= 1 for x in simmat]))
# print()

score = sum([int(preds[i] >= 0.5) == (labels[i]) for i in range(len(preds))]) / len(preds)

return score, np.array(preds)



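A toy sketch (not part of this commit) of the regularization and rescaling steps in benchmark_sentence_similarity above, applied to a small 2x2 matrix with assumed values; only numpy is required:

import numpy as np

# toy symmetric similarity matrix for two sentences (values are made up)
simmat = np.array([[2.0, 1.0],
                   [1.0, 4.0]])

# regularization: sim(i, j) <- sim(i, j) - 0.5 * (sim(i, i) + sim(j, j));
# the diagonal becomes 0 and off-diagonal entries are measured relative
# to the two self-similarities
diag = np.diag(simmat)
simmat = simmat - 0.5 * (diag.reshape(-1, 1) + diag.reshape(1, -1))
# simmat is now [[0., -2.], [-2., 0.]]

# rescaling (smoothing): normalise by the row-wise mean magnitudes,
# then exponentiate so the diagonal maps to exp(0) = 1
scale = np.abs(simmat).mean(axis=1, keepdims=True)
simmat = simmat / (scale * scale.T) ** 0.5
simmat = np.exp(simmat)
print(simmat)  # approximately [[1.0, 0.135], [0.135, 1.0]]

After these steps the diagonal is exactly 1, and main() thresholds the cross-pair entries simmat[i, i + N] at 0.5 to produce the paraphrase predictions.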
@@ -11,19 +11,19 @@ def prepareMSRData(path):
String : path: path of MSR data
returns:
List : sentencePairs: list of tuples that contain sentence pairs
List : sentencePairs: two parallel lists containing the first and second sentence of each pair
List : labels: integer values 1 or 0 that indicate whether the sentences are similar
"""
labels = []
sentencePairs = []
sentencePairs = [[], []]

with open(path, newline='', encoding="utf8") as csvfile:
# one row is broken by an unmatched quotation mark; 'ß' never occurs in the English sentences, so using it as the quote character effectively disables quoting
reader = csv.reader(csvfile, delimiter='\t', quotechar='ß')
next(reader) # skip header

for row in reader:
labels.append(row[0])
labels.append(int(row[0]))

sentence1 = row[3].strip().lower()
sentence2 = row[4].strip().lower()
@@ -34,7 +34,8 @@ def prepareMSRData(path):
sentence1 = " ".join(word_tokenize(sentence1))
sentence2 = " ".join(word_tokenize(sentence2))

sentencePairs.append((sentence1, sentence2))
sentencePairs[0].append(sentence1)
sentencePairs[1].append(sentence2)

return sentencePairs, labels

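A quick usage sketch (not part of this commit) of the structure prepareMSRData returns; the path is the one used in additionalBenchmark.py, and the indexed prints are placeholders:

pairs, labels = prepareMSRData("scripts/benchmarks/MSR/msr_paraphrase_train.csv")

# pairs[0][i] and pairs[1][i] are the tokenized, lower-cased sentences of pair i;
# labels[i] is 1 if they are paraphrases and 0 otherwise
print(pairs[0][0])
print(pairs[1][0])
print(labels[0])

# the two parallel lists line up index by index with the labels
assert len(pairs[0]) == len(pairs[1]) == len(labels)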