Skip to content

Commit

Permalink
add support for dict outputs (select the first value)
Browse files Browse the repository at this point in the history
  • Loading branch information
leogagnon committed Jul 29, 2024
1 parent 9591f47 commit 25a3473
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,8 @@ def _gather_intervention_output(
# data structure casting
if isinstance(output, tuple):
original_output = output[0].clone()
elif isinstance(output, dict):
original_output = output[list(output.keys())[0]].clone()
else:
original_output = output.clone()
# for non-sequence models, there is no concept of
Expand Down Expand Up @@ -502,6 +504,8 @@ def _scatter_intervention_output(
# data structure casting
if isinstance(output, tuple):
original_output = output[0]
elif isinstance(output, dict):
original_output = output[list(output.keys())[0]]
else:
original_output = output
# for non-sequence-based models, we simply replace
Expand Down

0 comments on commit 25a3473

Please sign in to comment.