Skip to content

Commit

Permalink
Implement FIREWord.most_similar
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Nov 7, 2022
1 parent f994edd commit 1c53d2e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
3 changes: 0 additions & 3 deletions firelang/function/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ def __mul__(self, other: Union[float, Functional, firelang.Measure]):
- if is `float` or `Functional`: generate a new Functional.
- if is `Measure`, compute the paired integral.
"""
if isinstance(other, StackingSlicing) or isinstance(other, firelang.Measure):
assert self.shape == other.shape

if isinstance(other, float) or isinstance(other, Functional):
return Functional(
locals_={"shape": self.shape},
Expand Down
42 changes: 40 additions & 2 deletions firelang/models/_fireword.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,40 @@ def field(
outputs = outputs.view(*meshx.shape)
return outputs

@torch.no_grad()
def most_similar(
    self, word: str, k: int = 10, p: float = 0.3,
) -> List[Tuple[str, float]]:
    """Return the `k` words most similar to `word`, together with their
    (frequency-adjusted) similarity scores.

    Args:
        word (str): the word whose most similar words are computed.
        k (int): the number of similar words to return. Values larger than
            the vocabulary size are clamped instead of raising.
        p (float, optional): an exponent controlling the strength of the
            frequency-based adjustment; pass ``None`` to disable the
            adjustment entirely. Defaults to 0.3.

    Returns:
        List[Tuple[str, float]]: the similar words and their
            frequency-adjusted similarity scores, in descending score order.
    """
    w = self[word]
    # Symmetrized similarity against every vocabulary entry: (vocab_size,)
    sims = self.funcs * w.measures + w.funcs * self.measures

    vocab = self.vocab
    if p is not None:
        # Scale scores by count**p so frequent words are favored; this
        # counteracts rare words dominating the raw similarity ranking.
        counts = torch.tensor(
            [vocab.i2count[self.rank2i[rank]] for rank in range(len(vocab))],
            dtype=torch.float32,
            device=sims.device,
        )
        sims = sims * (counts ** p)

    # torch.topk raises when k exceeds the number of entries; clamp instead.
    topk = sims.topk(min(k, sims.numel()))
    # `.data` is a deprecated access path; tensors under no_grad can be
    # moved to CPU / numpy directly.
    ranks = topk.indices.cpu().numpy()
    values = topk.values.cpu().numpy()

    words = [vocab.i2s[self.rank2i[rank]] for rank in ranks]
    return list(zip(words, values))

def loss_skipgram(
self, pairs: Tensor, labels: Tensor, args: Namespace = Namespace()
) -> Loss:
Expand Down Expand Up @@ -200,14 +234,18 @@ def __mul__(self, other: FIRETensor):
if id(other) == id(self):
return self.measures.integral(self.funcs) * 2
else:
return other.measures_other.integral(self.funcs) + self.measures.integral(other.funcs_other)
return other.measures_other.integral(self.funcs) + self.measures.integral(
other.funcs_other
)

def __matmul__(self, other: FIRETensor):
    """Compute the cross (pairwise) similarity matrix between `self` and `other`.

    The rendered diff duplicated the pre-commit `return` line on top of its
    reformatted replacement, leaving invalid syntax; only the post-commit
    form is kept here.

    Args:
        other (FIRETensor): the tensor to compare against. When `other` is
            the very same object as `self` (identity check), the integral is
            computed once and symmetrized instead of being evaluated twice.

    Returns:
        Tensor: a matrix of pairwise scores, symmetric in the self-self case.
    """
    if id(other) == id(self):
        # Self-similarity: one cross-integral, then symmetrize by adding
        # its transpose over the last two dimensions.
        mat = self.measures.integral(self.funcs, cross=True)
        return mat + torch.transpose(mat, -2, -1)
    else:
        return other.measures_other.integral(
            self.funcs, cross=True
        ) + torch.transpose(
            self.measures.integral(other.funcs_other, cross=True), -2, -1
        )

Expand Down

0 comments on commit 1c53d2e

Please sign in to comment.