
Commit

add in basic support of activation collection and interchange intervention
frankaging committed Jul 5, 2024
1 parent 5c3e5ae commit bf07a70
Showing 4 changed files with 166 additions and 255 deletions.
2 changes: 1 addition & 1 deletion pyvene/__init__.py
@@ -1,6 +1,6 @@
# Generic APIs
from .data_generators.causal_model import CausalModel
- from .models.intervenable_base import IntervenableModel, IntervenableNdifModel
+ from .models.intervenable_base import IntervenableModel, IntervenableNdifModel, build_intervenable_model
from .models.configuration_intervenable_model import IntervenableConfig
from .models.configuration_intervenable_model import RepresentationConfig

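With this re-export, the factory function added further down in intervenable_base.py becomes importable from the package root, e.g. (illustrative, not part of the diff):

from pyvene import build_intervenable_model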
78 changes: 54 additions & 24 deletions pyvene/models/intervenable_base.py
@@ -1,8 +1,10 @@
import json, logging, torch, types
+ import nnsight
import numpy as np
from collections import OrderedDict
from typing import List, Optional, Tuple, Union, Dict, Any

+ from .constants import *
from .basic_utils import *
from .modeling_utils import *
from .intervention_utils import *
@@ -469,14 +471,14 @@ def _gather_intervention_output(
)

# gather based on intervention locations
- selected_output = original_output
- # gather_neurons(
- # original_output,
- # self.representations[
- # representations_key
- # ].unit,
- # unit_locations,
- # )
+ selected_output = gather_neurons(
+ original_output,
+ self.representations[
+ representations_key
+ ].unit,
+ unit_locations,
+ device=self.get_device()
+ )

return selected_output
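The gather call above selects the intervening neurons out of the raw module output. A minimal, self-contained sketch of what that gather does for position-level units, assuming a (batch, seq_len, hidden) activation tensor and made-up locations (illustrative, not from the commit):

import torch

batch, seq_len, hidden = 2, 5, 4
hidden_states = torch.randn(batch, seq_len, hidden)
unit_locations = torch.tensor([[0, 3], [1, 4]])        # per-example positions, (batch, num_units)

# broadcast the position index over the hidden dimension, then gather along dim=1
index = unit_locations.unsqueeze(-1).expand(-1, -1, hidden)
selected = torch.gather(hidden_states, 1, index)       # (batch, num_units, hidden)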

@@ -518,6 +520,7 @@ def _scatter_intervention_output(
self.model_type,
self.model_config,
self.use_fast,
+ device=self.get_device()
)

return original_output
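The scatter path is the mirror image: it writes intervened values back into the original output at the same locations. A rough counterpart to the sketch above, again assuming position-level units (illustrative only):

import torch

batch, seq_len, hidden = 2, 5, 4
base = torch.zeros(batch, seq_len, hidden)
replacement = torch.ones(batch, 2, hidden)             # one replacement vector per unit
unit_locations = torch.tensor([[0, 3], [1, 4]])

for b in range(batch):
    base[b, unit_locations[b]] = replacement[b]        # overwrite the selected positions in place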
@@ -675,8 +678,10 @@ class IntervenableNdifModel(BaseModel):

def __init__(self, config, model, **kwargs):
super().__init__(config, model, "ndif", **kwargs)
- logging.info(
- f"We currently have very limited intervention support for NdifModel type."
+ # this is not used for now.
+ self.remote = kwargs["remote"] if "remote" in kwargs else False
+ logging.warning(
+ f"We currently have very limited intervention support for ndif backend."
)

def save(
@@ -737,13 +742,21 @@ def _intervention_getter(
"""
handlers = []
for key_i, key in enumerate(keys):
- intervention, module_hook = self.interventions[key]
+ intervention, (module_hook, hook_type) = self.interventions[key]
if self._is_generation:
raise NotImplementedError("Generation is not implemented for ndif backend")

- # only the first item will be accessed.
- output = module_hook.output[0].save()
+ if hook_type == CONST_INPUT_HOOK:
+ output = module_hook.input
+ elif hook_type == CONST_OUTPUT_HOOK:
+ output = module_hook.output

+ # TODO: this could be faulty by assuming the types.
+ if isinstance(output.dtype, tuple) and isinstance(output.dtype[0], tuple):
+ output = output[0][0]
+ elif isinstance(output.dtype, tuple):
+ output = output[0]

if isinstance(intervention, SkipIntervention):
raise NotImplementedError("Skip intervention is not implemented for ndif backend")
else:
@@ -755,7 +768,7 @@ def _intervention_getter(
# WARNING: might be worth to check the below assertion at runtime,
# but commenting it out for now just to avoid confusion.
# assert key not in self.activations
- self.activations[key] = selected_output
+ self.activations[key] = selected_output.save()
else:
raise NotImplementedError("Stateful models are not supported for ndif backend")

@@ -769,13 +782,12 @@ def _intervention_setter(
subspaces,
) -> HandlerList:
"""
- Create a list of setter handlers that will set activations
+ Create a list of setter tracer that will set activations
"""
self._tidy_stateful_activations()

handlers = []
for key_i, key in enumerate(keys):
- intervention, module_hook = self.interventions[key]
+ intervention, (module_hook, hook_type) = self.interventions[key]
if unit_locations_base[0] is not None:
self._batched_setter_activation_select[key] = [
0 for _ in range(len(unit_locations_base[0]))
@@ -784,9 +796,17 @@
if self._is_generation:
raise NotImplementedError("Generation is not implemented for ndif backend")

- # only the first item will be accessed
- output = module_hook.output[0]
- # .clone().save()
+ if hook_type == CONST_INPUT_HOOK:
+ output = module_hook.input
+ elif hook_type == CONST_OUTPUT_HOOK:
+ output = module_hook.output

+ # TODO: this could be faulty by assuming the types.
+ if isinstance(output.dtype, tuple) and isinstance(output.dtype[0], tuple):
+ output = output[0][0]
+ elif isinstance(output.dtype, tuple):
+ output = output[0]

selected_output = self._gather_intervention_output(
output, key, unit_locations_base[key_i]
)
@@ -852,11 +872,11 @@ def _intervention_setter(

if isinstance(output, tuple):
_ = self._scatter_intervention_output(
- module_hook.output[0], intervened_representation, key, unit_locations_base[key_i]
+ output[0], intervened_representation, key, unit_locations_base[key_i]
)
else:
_ = self._scatter_intervention_output(
- module_hook.output, intervened_representation, key, unit_locations_base[key_i]
+ output, intervened_representation, key, unit_locations_base[key_i]
)

self._intervention_state[key].inc_setter_version()
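The setter relies on the same proxy mechanism in the other direction: assigning into a module's output inside a trace changes what every downstream layer sees, which is what lets the interchange intervention take effect. A hedged sketch of that write pattern (again, module path, positions, and API details are assumptions):

import nnsight

lm = nnsight.LanguageModel("gpt2")
with lm.trace("The quick brown fox"):
    hidden = lm.transformer.h[3].output[0]
    hidden[:, 0, :] = 0.0                    # overwrite position 0 at layer 3 in place
    logits = lm.output.logits.save()         # downstream computation sees the edit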
@@ -1023,8 +1043,7 @@ def forward(
self.interventions[key][0],
CollectIntervention
):
- print(self.activations[key])
- collected_activations += self.activations[key]
+ collected_activations += self.activations[key].clone()

except Exception as e:
raise e
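With the CollectIntervention branch above, a forward pass can now hand back captured activations. A heavily hedged usage sketch in the style of pyvene's tutorials rather than of this commit; the config fields, the unit_locations format, and the packing of the return value are assumptions that may differ across versions:

import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")

config = pv.IntervenableConfig(
    representations=[pv.RepresentationConfig(3, "block_output")],  # layer 3 block output
    intervention_types=pv.CollectIntervention,
)
pv_gpt2 = pv.build_intervenable_model(config, gpt2)

base = tokenizer("The capital of Spain is", return_tensors="pt")
outputs = pv_gpt2(base, unit_locations={"base": 3})   # collect at token position 3
# the collected activations come back bundled with the base outputs;
# the exact packing depends on the pyvene version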
@@ -2273,3 +2292,14 @@ def eval_alignment(
result = weighted_average(all_metrics, all_num_examples)

return result


+ def build_intervenable_model(config, model, **kwargs):
+ """
+ Factory design pattern for different types of intervenable models.
+ """
+ if isinstance(model, nnsight.LanguageModel):
+ return IntervenableNdifModel(config, model, **kwargs)
+ else:
+ return IntervenableModel(config, model, **kwargs)
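The factory simply dispatches on the type of the wrapped model. A short usage sketch (model names are illustrative; the config is a minimal assumption):

import pyvene as pv
import nnsight
from transformers import AutoModelForCausalLM

config = pv.IntervenableConfig(
    representations=[pv.RepresentationConfig(3, "block_output")]
)

hf_gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
pv_native = pv.build_intervenable_model(config, hf_gpt2)     # -> IntervenableModel

ndif_gpt2 = nnsight.LanguageModel("gpt2")
pv_ndif = pv.build_intervenable_model(config, ndif_gpt2)     # -> IntervenableNdifModel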

22 changes: 15 additions & 7 deletions pyvene/models/modeling_utils.py
@@ -162,6 +162,7 @@ def get_module_hook(model, representation, backend="native") -> nn.Module:
elif backend == "ndif":
# we assume the input v.s. output is handled outside
module_hook = module
+ return (module_hook, hook_type)

return module_hook

@@ -243,7 +244,7 @@ def output_to_subcomponent(output, component, model_type, model_config):
return subcomponent


- def gather_neurons(tensor_input, unit, unit_locations_as_list):
+ def gather_neurons(tensor_input, unit, unit_locations_as_list, device=None):
"""Gather intervening neurons.
:param tensor_input: tensors of shape (batch_size, sequence_length, ...) if
@@ -262,8 +263,10 @@ def gather_neurons(tensor_input, unit, unit_locations_as_list):

if "." in unit:
unit_locations = (
- torch.tensor(unit_locations_as_list[0], device=tensor_input.device),
- torch.tensor(unit_locations_as_list[1], device=tensor_input.device),
+ torch.tensor(unit_locations_as_list[0],
+ device=tensor_input.device if device is None else device),
+ torch.tensor(unit_locations_as_list[1],
+ device=tensor_input.device if device is None else device),
)
# we assume unit_locations is a tuple
head_unit_locations = unit_locations[0]
@@ -290,8 +293,9 @@ def gather_neurons(tensor_input, unit, unit_locations_as_list):
return tensor_output # b, num_unit (h), num_unit (pos), d
else:
unit_locations = torch.tensor(
- unit_locations_as_list, device=tensor_input.device
+ unit_locations_as_list, device="cpu"
)

tensor_output = torch.gather(
tensor_input,
1,
@@ -311,6 +315,7 @@ def scatter_neurons(
model_type,
model_config,
use_fast,
+ device=None
):
"""Replace selected neurons in `tensor_input` by `replacing_tensor_input`.
@@ -337,12 +342,15 @@
if "." in unit:
# extra dimension for multi-level intervention
unit_locations = (
- torch.tensor(unit_locations_as_list[0], device=tensor_input.device),
- torch.tensor(unit_locations_as_list[1], device=tensor_input.device),
+ torch.tensor(unit_locations_as_list[0],
+ device=tensor_input.device if device is None else device),
+ torch.tensor(unit_locations_as_list[1],
+ device=tensor_input.device if device is None else device),
)
else:
unit_locations = torch.tensor(
- unit_locations_as_list, device=tensor_input.device
+ unit_locations_as_list,
+ device=tensor_input.device if device is None else device
)

# if tensor is splitted, we need to get the start and end indices
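The recurring device fallback, tensor_input.device if device is None else device, follows the same pattern in gather_neurons and in both scatter_neurons branches: prefer an explicitly supplied device (the wrapper classes now pass self.get_device()), otherwise follow the tensor being indexed. In isolation, with a hypothetical helper name (illustrative only):

import torch

def location_tensor(locations, ref_tensor, device=None):
    # prefer an explicit device; otherwise follow the tensor we index into
    target = ref_tensor.device if device is None else device
    return torch.tensor(locations, device=target)

ref = torch.randn(2, 5, 4)
idx_default = location_tensor([[0, 3], [1, 4]], ref)                 # follows ref's device
idx_explicit = location_tensor([[0, 3], [1, 4]], ref, device="cpu")  # explicit override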
