Skip to content

Commit

Permalink
Merge branch 'main' of github.com:stanfordnlp/pyvene into zen/ndif
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Jul 4, 2024
2 parents 39741be + 5b35936 commit d41388f
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 8 deletions.
1 change: 1 addition & 0 deletions pyvene/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .models.interventions import SourcelessIntervention
from .models.interventions import NoiseIntervention
from .models.interventions import SigmoidMaskIntervention
from .models.interventions import AutoencoderIntervention


# Utils
Expand Down
6 changes: 3 additions & 3 deletions pyvene/models/configuration_intervenable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
"max_number_of_units "
"low_rank_dimension intervention_type intervention "
"subspace_partition group_key intervention_link_key moe_key "
"source_representation hidden_source_representation",
"source_representation hidden_source_representation latent_dim",
defaults=(
0, "block_output", "pos", 1, None, None,
None, None, None, None, None, None, None),
None, None, None, None, None, None, None, None),
)


Expand Down Expand Up @@ -134,4 +134,4 @@ def __repr__(self):
return f"IntervenableConfig\n{_repr_string}"

def __str__(self):
return self.__repr__()
return self.__repr__()
4 changes: 2 additions & 2 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(self, config, model, backend, **kwargs):
all_metadata["embed_dim"] = component_dim
all_metadata["use_fast"] = self.use_fast
intervention = intervention_function(
**all_metadata
**all_metadata
)

if representation.intervention_link_key in self._intervention_pointers:
Expand Down Expand Up @@ -2272,4 +2272,4 @@ def eval_alignment(
all_num_examples += [b_s]
result = weighted_average(all_metrics, all_num_examples)

return result
return result
31 changes: 28 additions & 3 deletions pyvene/models/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from abc import ABC, abstractmethod

from .layers import RotateLayer, LowRankRotateLayer, SubspaceLowRankRotateLayer
from .layers import RotateLayer, LowRankRotateLayer, SubspaceLowRankRotateLayer, AutoencoderLayer
from .basic_utils import sigmoid_boundary
from .intervention_utils import _can_use_fast, _do_intervention_by_swap

Expand Down Expand Up @@ -54,7 +54,8 @@ def set_source_representation(self, source_representation):
self.register_buffer('source_representation', source_representation)

def set_interchange_dim(self, interchange_dim):
if isinstance(interchange_dim, int):
if not isinstance(interchange_dim, torch.Tensor):
# Convert integer or list into torch.Tensor.
self.interchange_dim = torch.tensor(interchange_dim)
else:
self.interchange_dim = interchange_dim
Expand Down Expand Up @@ -556,4 +557,28 @@ def forward(self, base, source=None, subspaces=None):

def __str__(self):
return f"NoiseIntervention()"



class AutoencoderIntervention(TrainableIntervention):
    """Intervene in the latent space of an autoencoder.

    Both ``base`` and ``source`` are encoded with a shared autoencoder,
    the latent dimensions selected by ``self.interchange_dim`` are copied
    from the source latent into the base latent, and the patched latent
    is decoded back to the model's hidden space.
    """

    def __init__(self, **kwargs):
        """Build the intervention.

        Required kwargs:
            latent_dim: width of the autoencoder's latent space.
            embed_dim: width of the hidden representation being intervened on.

        Raises:
            ValueError: if either required kwarg is missing.
        """
        super().__init__(**kwargs)
        if "latent_dim" not in kwargs:
            raise ValueError('Missing latent_dim in kwargs.')
        if "embed_dim" not in kwargs:
            # Fail fast: without embed_dim no autoencoder can be built and
            # forward() would later fail with an opaque AttributeError on
            # self.autoencoder. (intervenable_base always supplies embed_dim
            # via all_metadata, so real callers are unaffected.)
            raise ValueError('Missing embed_dim in kwargs.')
        self.embed_dim = torch.tensor(kwargs["embed_dim"])
        self.autoencoder = AutoencoderLayer(
            self.embed_dim, kwargs["latent_dim"])

    def forward(self, base, source, subspaces=None):
        """Swap the selected latent dims of ``base`` with those of ``source``."""
        base_dtype = base.dtype
        # Match the autoencoder's parameter dtype before encoding.
        base = base.to(self.autoencoder.encoder[0].weight.dtype)
        base_latent = self.autoencoder.encode(base)
        source_latent = self.autoencoder.encode(source)
        # Patch only the dimensions chosen for interchange.
        base_latent[..., self.interchange_dim] = source_latent[..., self.interchange_dim]
        inv_output = self.autoencoder.decode(base_latent)
        # Return in the caller's original dtype.
        return inv_output.to(base_dtype)

    def __str__(self):
        return f"AutoencoderIntervention()"
45 changes: 45 additions & 0 deletions pyvene/models/layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from abc import ABCMeta, abstractmethod

import torch


Expand Down Expand Up @@ -56,3 +58,46 @@ def __init__(self, n, m, init_orth=True):

def forward(self, x, l, r):
return torch.matmul(x.to(self.weight.dtype), self.weight[:, l:r])


class AutoencoderLayerBase(torch.nn.Module, metaclass=ABCMeta):
    """Interface for autoencoder modules: map inputs to a latent and back."""

    @abstractmethod
    def encode(self, x):
        """Map input ``x`` to its latent representation."""

    @abstractmethod
    def decode(self, latent):
        """Reconstruct an input from ``latent``."""


class AutoencoderLayer(AutoencoderLayerBase):
    """An autoencoder with a single-layer encoder and single-layer decoder."""

    def __init__(self, input_dim, latent_dim, **kwargs):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        # Encoder: affine map + ReLU; decoder: plain affine map back.
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, latent_dim, bias=True),
            torch.nn.ReLU())
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, input_dim, bias=True))

    def encode(self, x):
        """Return the latent code for ``x``.

        The decoder bias is subtracted from the input before encoding
        (pre-bias centering, so decode(encode(x)) re-adds it).
        """
        centered = x.to(self.encoder[0].weight.dtype) - self.decoder[0].bias
        return self.encoder(centered)

    def decode(self, latent):
        """Reconstruct an input vector from ``latent``."""
        return self.decoder(latent)

    def forward(self, base, return_latent=False):
        """Run a full encode/decode pass over ``base``.

        Returns the reconstruction cast back to ``base``'s dtype, or — when
        ``return_latent`` is set — a dict with both the latent and the
        reconstruction (left in the autoencoder's dtype, as before).
        """
        original_dtype = base.dtype
        latent = self.encode(base.to(self.encoder[0].weight.dtype))
        reconstruction = self.decode(latent)
        if return_latent:
            return {'latent': latent, 'output': reconstruction}
        return reconstruction.to(original_dtype)

0 comments on commit d41388f

Please sign in to comment.