From 1c53d2ef7a8a196fa0ef393a04056345a1f66af5 Mon Sep 17 00:00:00 2001 From: duxin Date: Mon, 7 Nov 2022 16:22:59 +0900 Subject: [PATCH] Implement FIREWord.most_similar --- firelang/function/base.py | 3 --- firelang/models/_fireword.py | 42 ++++++++++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/firelang/function/base.py b/firelang/function/base.py index 9272b40..54ec056 100644 --- a/firelang/function/base.py +++ b/firelang/function/base.py @@ -50,9 +50,6 @@ def __mul__(self, other: Union[float, Functional, firelang.Measure]): - if is `float` or `Functional`: generate a new Functional. - if is `Measure`, compute the paired integral. """ - if isinstance(other, StackingSlicing) or isinstance(other, firelang.Measure): - assert self.shape == other.shape - if isinstance(other, float) or isinstance(other, Functional): return Functional( locals_={"shape": self.shape}, diff --git a/firelang/models/_fireword.py b/firelang/models/_fireword.py index 01ca9e6..37d5512 100644 --- a/firelang/models/_fireword.py +++ b/firelang/models/_fireword.py @@ -139,6 +139,40 @@ def field( outputs = outputs.view(*meshx.shape) return outputs + @torch.no_grad() + def most_similar( + self, word: str, k: int = 10, p: float = 0.3, + ) -> List[Tuple[str, float]]: + """Return the most similar `k` words to `word`, as well as the (frequency-adjusted) similarity scores. + + Args: + word (str): the word of which the most similar words are computed. + k (int): the number of similar words to return. + p (float, optional): a exponent controlling the strength of frequency-based adjustment. Defaults to 0.3. + + Returns: + List[Tuple[str, float]]: the similar words and their frequency-adjusted similar scores. + """ + w = self[word] + sims = self.funcs * w.measures + w.funcs * self.measures # (vocab_size,) + + # adjust with word frequency + vocab = self.vocab + if p is not None: + counts = torch.tensor( + [vocab.i2count[self.rank2i[rank]] for rank in range(len(vocab))], + dtype=torch.float32, + device=sims.device, + ) + sims = sims * (counts**p) + + topk = sims.topk(k) + ranks = topk.indices.data.cpu().numpy() + values = topk.values.data.cpu().numpy() + + words = [vocab.i2s[self.rank2i[rank]] for rank in ranks] + return list(zip(words, values)) + def loss_skipgram( self, pairs: Tensor, labels: Tensor, args: Namespace = Namespace() ) -> Loss: @@ -200,14 +234,18 @@ def __mul__(self, other: FIRETensor): if id(other) == id(self): return self.measures.integral(self.funcs) * 2 else: - return other.measures_other.integral(self.funcs) + self.measures.integral(other.funcs_other) + return other.measures_other.integral(self.funcs) + self.measures.integral( + other.funcs_other + ) def __matmul__(self, other: FIRETensor): if id(other) == id(self): mat = self.measures.integral(self.funcs, cross=True) return mat + torch.transpose(mat, -2, -1) else: - return other.measures_other.integral(self.funcs, cross=True) + torch.transpose( + return other.measures_other.integral( + self.funcs, cross=True + ) + torch.transpose( self.measures.integral(other.funcs_other, cross=True), -2, -1 )