Skip to content

Commit

Permalink
Merge remote-tracking branch 'Plasma-Blue/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Dec 5, 2023
2 parents 0bac106 + 038d993 commit 24f059a
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import itertools
import multiprocessing
import os
import traceback
Expand Down Expand Up @@ -549,20 +550,11 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'

num_predictons = 2 ** len(mirror_axes)
if 0 in mirror_axes:
prediction += torch.flip(self.network(torch.flip(x, (2,))), (2,))
if 1 in mirror_axes:
prediction += torch.flip(self.network(torch.flip(x, (3,))), (3,))
if 2 in mirror_axes:
prediction += torch.flip(self.network(torch.flip(x, (4,))), (4,))
if 0 in mirror_axes and 1 in mirror_axes:
prediction += torch.flip(self.network(torch.flip(x, (2, 3))), (2, 3))
if 0 in mirror_axes and 2 in mirror_axes:
prediction += torch.flip(self.network(torch.flip(x, (2, 4))), (2, 4))
if 1 in mirror_axes and 2 in mirror_axes:
prediction += torch.flip(self.network(torch.flip(x, (3, 4))), (3, 4))
if 0 in mirror_axes and 1 in mirror_axes and 2 in mirror_axes:
prediction += torch.flip(self.network(torch.flip(x, (2, 3, 4))), (2, 3, 4))
axes_combinations = [
c for i in range(len(mirror_axes)) for c in itertools.combinations([m + 2 for m in mirror_axes], i + 1)
]
for axes in axes_combinations:
prediction += torch.flip(self.network(torch.flip(x, (*axes,))), (*axes,))
prediction /= num_predictons
return prediction

Expand Down

0 comments on commit 24f059a

Please sign in to comment.