Skip to content

Commit

Permalink
Benchmark with FIRETensor
Browse files (browse the repository at this point in the history)
  • Loading branch information
kduxin committed Nov 7, 2022
1 parent 69cff40 commit f994edd
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions scripts/benchmark.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,8 @@
from sklearn.cluster import DBSCAN
from sklearn.metrics import accuracy_score

from firelang.models import FIREWord, PFIREWord
from firelang.models import FIREWord, FIRETensor
from firelang.models import PFIREWord
from firelang.utils.log import logger
from firelang.utils.timer import Timer, elapsed
from scripts.sentsim import sentsim_as_weighted_wordsim_cuda
Expand Down Expand Up @@ -247,30 +248,29 @@ def benchmark_word_similarity(

""" similarity """
with Timer(elapsed, "similarity", sync_cuda=True):
func1, measure1 = model.forward(pairs[..., 0])
func2, measure2 = model.forward(pairs[..., 1])
preds = measure1.integral(func2 - func1) + measure2.integral(func1 - func2)
x1: FIRETensor = model.forward(pairs[..., 0])
x2: FIRETensor = model.forward(pairs[..., 1])
preds = x1.measures.integral(x2.funcs - x1.funcs) + x2.measures.integral(x1.funcs - x2.funcs)

""" smoothing by standardization """

def _estimate_mean_var(func, measure):
sims = (
measall.integral(func, cross=True)
+ measure.integral(funcall, cross=True).T
xall.measures.integral(func, cross=True)
+ torch.transpose(measure.integral(xall.funcs, cross=True), -2, -1)
)
sims = sims - (
measure.integral(func).reshape(-1, 1)
+ measall.integral(funcall).reshape(1, -1)
+ xall.measures.integral(xall.funcs).reshape(1, -1)
)
mean, std = sims.mean(dim=1), sims.std(dim=1)
return mean, std

with Timer(elapsed, "smooth", sync_cuda=True):
allids = torch.cat([pairs[:, 0], pairs[:, 1]], dim=0)
funcall, measall = model(allids)
xall: FIRETensor = model(allids)

sims1mean, sims1std = _estimate_mean_var(func1, measure1)
sims2mean, sims2std = _estimate_mean_var(func2, measure2)
sims1mean, sims1std = _estimate_mean_var(x1.funcs, x1.measures)
sims2mean, sims2std = _estimate_mean_var(x2.funcs, x2.measures)

preds = (preds - sims1mean / 2 - sims2mean / 2) / (
sims1std * sims2std
Expand Down Expand Up @@ -518,11 +518,11 @@ def _simmat_rescale(simmat) -> np.ndarray:

@torch.no_grad()
def batched_cross_selfsim(model, words, col_batch_size=100):
_, measures = model[words]
x: FIRETensor = model[words]
wordsim = np.zeros((len(words), len(words)), dtype=np.float32)
for i in range(0, len(words), col_batch_size):
funcs, _ = model[words[i : i + col_batch_size]]
_wordsim = measures.integral(funcs, cross=True).data.cpu().numpy()
xbatch = model[words[i : i + col_batch_size]]
_wordsim = x.measures.integral(xbatch.funcs, cross=True).data.cpu().numpy()
wordsim[i : i + col_batch_size, :] += _wordsim
wordsim[:, i : i + col_batch_size] += _wordsim.T
return wordsim
Expand Down Expand Up @@ -661,9 +661,9 @@ def detect_num_clusters_DBSCAN(relwordpos, eps=0.5, minfreq=0.00):

def get_relwordpos(allmeasures, model, centerword, k=1000, pca=False):

func, measure = model[centerword]
x = model[centerword]

strength = allmeasures.integral(func, cross=True).squeeze()
strength = allmeasures.integral(x.funcs, cross=True).squeeze()
potrank = torch.argsort(strength, descending=True)
relwordids = potrank[:k]

Expand Down Expand Up @@ -693,15 +693,15 @@ def benchmark_wordsense_number(
freqwords = [key for key, val in wordcounts[:nfreqwords]]

""" polysemy prediction """
allfuncs, allmeasures = model[freqwords]
xall = model[freqwords]
with ctx.Pool(num_workers) as pool:
labels, blabels, preds, bpreds = [], [], [], []
async_results = []
for word in tqdm.tqdm(w2nsense):
label = w2nsense[word]
word = word.lower()
with Timer(elapsed, "get_relwordpos"):
relwordpos = get_relwordpos(allmeasures, model, word, k=1000, pca=True)
relwordpos = get_relwordpos(xall.measures, model, word, k=1000, pca=True)
asy = pool.apply_async(
detect_num_clusters_DBSCAN,
kwds={"relwordpos": relwordpos, "eps": eps, "minfreq": 0.000},
Expand Down Expand Up @@ -729,16 +729,16 @@ def benchmark_wordsense_number(
"--checkpoints_for_similarity",
nargs="+",
default=[
"checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/wacky_mlplanardiv_d2_l4_k10",
"checkpoints/wacky_mlplanardiv_d2_l8_k20",
"checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/wacky_mlplanardiv_d2_l4_k10",
"checkpoints/wacky_mlplanardiv_d2_l8_k20",
],
)
parser.add_argument(
"--checkpoints_for_polysemy",
nargs="+",
default=[
"checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
],
)
args = parser.parse_args()
Expand Down

0 comments on commit f994edd

Please sign in to comment.