Skip to content

Commit

Permalink
Fixed bug in FIRETensor.__mul__ and __matmul__
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Nov 7, 2022
1 parent f7ee0a8 commit 92b1d93
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions firelang/models/_fireword.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,17 @@ def __mul__(self, other: FIRETensor):
if id(other) == id(self):
return self.measures.integral(self.funcs) * 2
else:
return other.measures_other.integral(self.funcs) + self.measures.integral(
other.funcs_other
return other.measures.integral(self.funcs) + self.measures.integral(
other.funcs
)

def __matmul__(self, other: FIRETensor):
if id(other) == id(self):
mat = self.measures.integral(self.funcs, cross=True)
return mat + torch.transpose(mat, -2, -1)
else:
return other.measures_other.integral(
self.funcs, cross=True
) + torch.transpose(
self.measures.integral(other.funcs_other, cross=True), -2, -1
return other.measures.integral(self.funcs, cross=True) + torch.transpose(
self.measures.integral(other.funcs, cross=True), -2, -1
)

def __repr__(self):
Expand Down

0 comments on commit 92b1d93

Please sign in to comment.