Skip to content

Commit

Permalink
added all benchmark tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Xeadriel committed May 2, 2024
1 parent c4b93fc commit 46e30c6
Show file tree
Hide file tree
Showing 37 changed files with 644,511 additions and 78 deletions.
636 changes: 567 additions & 69 deletions scripts/additionalBenchmark.py

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions scripts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def benchmark_word_similarity(
_cache[f"oov_reported/{bname}"] = True

""" similarity """
with Timer(elapsed, "similarity", sync_cuda=True):
with Timer(elapsed, "similarity", sync_cuda=False):
x1: FIRETensor = model.forward(pairs[..., 0])
x2: FIRETensor = model.forward(pairs[..., 1])
preds = x1.measures.integral(x2.funcs - x1.funcs) + x2.measures.integral(
Expand All @@ -266,7 +266,7 @@ def _estimate_mean_var(func, measure):
mean, std = sims.mean(dim=1), sims.std(dim=1)
return mean, std

with Timer(elapsed, "smooth", sync_cuda=True):
with Timer(elapsed, "smooth", sync_cuda=False):
allids = torch.cat([pairs[:, 0], pairs[:, 1]], dim=0)
xall: FIRETensor = model(allids)

Expand Down Expand Up @@ -376,13 +376,13 @@ def load_sentsim_benchmark(
tasks = [(x, y) for x, y in SENTSIM_BENCHMARKS_PATH if x == name]
allsentpair, allgs = [], []
for finput, fgs in tasks:
lines = open(f"{dirpath}/{finput}", "rt").readlines()
lines = open(f"{dirpath}/{finput}", "rt", errors= "ignore").readlines()
if lower:
lines = [line.lower() for line in lines]
sentpairs = [line.strip().split("\t") for line in lines]
allsentpair.extend(sentpairs)

lines = open(f"{dirpath}/{fgs}", "rt").readlines()
lines = open(f"{dirpath}/{fgs}", "rt", errors= "ignore").readlines()
gs = [float(line.strip()) if line.strip() else None for line in lines]
allgs.extend(gs)

Expand Down Expand Up @@ -432,7 +432,7 @@ def benchmark_sentence_similarity(
]

""" similarity """
with Timer(elapsed, "similarity", sync_cuda=True):
with Timer(elapsed, "similarity", sync_cuda=False):
simmat = sentence_simmat(model, allsents, sif_weights)

""" regularization: sim(i,j) <- sim(i,j) - 0.5 * (sim(i,i) + sim(j,j)) """
Expand Down Expand Up @@ -558,7 +558,7 @@ def benchmark_word_in_context(
]

""" similarity """
with Timer(elapsed, "similarity", sync_cuda=True):
with Timer(elapsed, "similarity", sync_cuda=False):
simmat = sentence_simmat(model, allsents, sif_weights)

""" regularization: sim(i,j) <- sim(i,j) - 0.5 * (sim(i,i) + sim(j,j)) """
Expand All @@ -577,6 +577,7 @@ def benchmark_word_in_context(
sentsims = np.array([simmat[i, i + N] for i in range(N)])
sentsims = np.exp(sentsims)


threshold = np.median(sentsims)
preds = sentsims > threshold
acc = (preds == benchmark.labels).mean()
Expand Down Expand Up @@ -672,7 +673,7 @@ def benchmark_wordsense_number(

if __name__ == "__main__":

device = "cuda"
device = "cpu"

parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down
Loading

0 comments on commit 46e30c6

Please sign in to comment.