diff --git a/firelang/measure/dirac.py b/firelang/measure/dirac.py index 7023cff..bf5f3b9 100644 --- a/firelang/measure/dirac.py +++ b/firelang/measure/dirac.py @@ -71,7 +71,7 @@ def integral( func = func.view(*func.shape[:-1], func_size, 1, 1) size = func_size * self.k - nrow_per_batch = (batch_size - size + 1) // size + nrow_per_batch = (batch_size + size - 1) // size fx = [] for i in range(0, measure_size, nrow_per_batch): _x = x[..., i : i + nrow_per_batch, :, :]