Skip to content

Commit

Permalink
Fixed a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Nov 2, 2022
1 parent 0a516b9 commit 48f71dc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 15 deletions.
19 changes: 4 additions & 15 deletions firelang/models/_fireword.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from firelang.measure import Measure
from firelang.function import Functional
from firelang.stack import StackingSlicing
from firelang.measure import DiracMixture, metrics
from firelang.measure import DiracMixture
from firelang.utils.timer import Timer, elapsed
from firelang.utils.optim import Loss

Expand Down Expand Up @@ -139,7 +139,9 @@ def field(
return outputs

def loss_skipgram(
self, pairs: Tensor, labels: Tensor, args: Namespace = Namespace()
self,
pairs: Tensor,
labels: Tensor,
) -> Loss:
"""Noise contrastive estimation loss for the SkipGram task.
Expand All @@ -162,19 +164,6 @@ def loss_skipgram(
)
loss.add("sim", loss_sim)

if hasattr(args, "sinkhorn_weight") and args.sinkhorn_weight > 0.0:
s = metrics.sinkhorn(
measure1,
measure2,
reg=args.sinkhorn_reg,
max_iter=args.sinkhorn_max_iter,
p=args.sinkhorn_p,
tau=args.sinkhorn_tau,
stop_threshold=args.sinkhorn_stop_threshold,
) # (n,)
s[~labels] = -s[~labels]
loss.add("sinkhorn", s * args.sinkhorn_weight)

return loss


Expand Down
1 change: 1 addition & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from typing import List
import argparse
import os
Expand Down

0 comments on commit 48f71dc

Please sign in to comment.