Skip to content

Commit

Permalink
Fixed bugs in firelang library
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Oct 16, 2022
1 parent f75b7f5 commit 3cea0af
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 14 deletions.
2 changes: 1 addition & 1 deletion firelang/function/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __mul__(self, other: Union[float, Functional, firelang.Measure]):
operator=op.mul,
)
elif isinstance(other, firelang.Measure):
return other.integral(self)
return other.integral(self, sum=False)
else:
raise TypeError(
f"`other` must be a float or Functional or Measure object, not {type(other)}."
Expand Down
32 changes: 22 additions & 10 deletions firelang/measure/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,34 @@


class Measure(StackingSlicing):
def integral(self, func, cross=False, batch_size=None, sum=True):
"""
cross=False:
func: (stack, K, input_dim) |-> (stack, K)
integral: |-> (stack, )
cross=True:
func: (measure_stack, n_component, input_dim) |-> (func_stack, measure_stack, K)
integral: |-> (func_stack, measure_stack)
def integral(self, func: firelang.Functional, cross: bool = False, batch_size: int = None, sum: bool = True):
""" Compute integral $\int f d\mu$
Args:
- func (firelang.Functional)
- cross (bool, optional): _description_. Defaults to False.
- batch_size (int, optional): _description_. Defaults to None.
- sum (bool, optional): whether sum the integral values at separate locations. Defaults to True.
Returns:
- if cross == False:
func should accept Tensor of shape (stack, K, input_dim) and produces (stack, K)
- if sum == True: returns (stack,)
- else: returns (stack, K)
- if cross == True:
func should be able to accept both (measure_stack, n_component, input_dim)
and (func_stack, measure_stack, n_component, input_dim), and produces (func_stack, measure_stack, K)
- if sum == True: returns (func_stack, measure_stack)
- else: returns (func_stack, measure_stack, K)
"""
raise NotImplementedError

def __mul__(self, other: firelang.Functional):
"""Returns: Tensor of shape (stack, K)"""
assert isinstance(other, firelang.Functional)
assert self.stack_size == other.stack_size
return self.integral(other)
return self.integral(other, sum=False)

def __matmul__(self, other: firelang.Functional):
"""Returns: Tensor of shape (func_stack, measure_stack)"""
assert isinstance(other, firelang.Functional)
return self.integral(other, cross=True).T
return self.integral(other, cross=True).T
4 changes: 2 additions & 2 deletions firelang/models/word.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ def __mul__(self, other: FIREWordSlice):
funcs, measures = self
funcs_other, measures_other = other
if id(other) == id(self):
return measures.integral(funcs) * 2
return measures.integral(funcs, sum=False) * 2
else:
return measures_other.integral(funcs) + measures.integral(funcs_other)
return measures_other.integral(funcs, sum=False) + measures.integral(funcs_other, sum=False)

def __matmul__(self, other: FIREWordSlice):
funcs, measures = self
Expand Down
2 changes: 1 addition & 1 deletion firelang/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def restack(
while len(_cache[tag]) > max_cached_copies:
_cache[tag].popitem(last=False) # pop the earliest

return new
return new.to(self.detect_device())

stack = restack

Expand Down

0 comments on commit 3cea0af

Please sign in to comment.