Merge branch 'stanfordnlp:main' into main
aryamanarora authored Jun 17, 2024
2 parents 2219d82 + 3da8474 commit 3fccdaa
Showing 14 changed files with 558 additions and 79 deletions.
28 changes: 28 additions & 0 deletions pyvene/analyses/visualization.py
@@ -0,0 +1,28 @@
import seaborn
import torch

def rotation_token_heatmap(rotate_layer,
tokens,
token_size,
variables,
intervention_size):

W = rotate_layer.weight.data
in_dim, out_dim = W.shape

assert in_dim % token_size == 0
assert in_dim / token_size >= len(tokens)

assert out_dim % intervention_size == 0
assert out_dim / intervention_size >= len(variables)

heatmap = []
for j in range(len(variables)):
row = []
for i in range(len(tokens)):
row.append(torch.norm(W[i*token_size:(i+1)*token_size, j*intervention_size:(j+1)*intervention_size]))
        total = sum(row)  # normalize each row so the token contributions sum to 1
        heatmap.append([x / total for x in row])
return seaborn.heatmap(heatmap,
xticklabels=tokens,
yticklabels=variables)
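
A minimal usage sketch, not part of the commit (the layer sizes and token/variable names below are illustrative): a 4-token, 64-dimensional representation is rotated into an 8-dimensional subspace split across two intervention variables, and the heatmap shows how strongly each token position loads on each variable.

import torch
from pyvene.models.layers import LowRankRotateLayer
from pyvene.analyses.visualization import rotation_token_heatmap

embed_dim, n_tokens = 64, 4
rotate = LowRankRotateLayer(embed_dim * n_tokens, 8)  # in_dim = 256, out_dim = 8

ax = rotation_token_heatmap(
    rotate,
    tokens=["The", "cat", "sat", "down"],  # at most in_dim / token_size positions
    token_size=embed_dim,
    variables=["X", "Y"],                  # at most out_dim / intervention_size variables
    intervention_size=4,
)
ax.figure.savefig("rotation_heatmap.png")  # seaborn.heatmap returns a matplotlib Axes
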
21 changes: 10 additions & 11 deletions pyvene/data_generators/causal_model.py
@@ -35,9 +35,6 @@ def __init__(
assert variable in self.values
assert variable in self.children
assert variable in self.functions
assert len(inspect.getfullargspec(self.functions[variable])[0]) == len(
self.parents[variable]
)
if timesteps is not None:
assert variable in timesteps
for variable2 in copy.copy(self.variables):
@@ -79,6 +76,8 @@ def __init__(
self.equiv_classes = equiv_classes
else:
self.equiv_classes = {}

def generate_equiv_classes(self):
for var in self.variables:
if var in self.inputs or var in self.equiv_classes:
continue
@@ -113,7 +112,7 @@ def generate_timesteps(self):
def marginalize(self, target):
pass

def print_structure(self, pos=None):
def print_structure(self, pos=None, font=12, node_size=1000):
G = nx.DiGraph()
G.add_edges_from(
[
@@ -123,7 +122,7 @@ def print_structure(self, pos=None):
]
)
plt.figure(figsize=(10, 10))
nx.draw_networkx(G, with_labels=True, node_color="green", pos=self.pos)
nx.draw_networkx(G, with_labels=True, node_color="green", pos=self.pos, font_size=font, node_size=node_size)
plt.show()

def find_live_paths(self, intervention):
@@ -149,12 +148,9 @@ def find_live_paths(self, intervention):
del paths[1]
return paths

def print_setting(self, total_setting, display=None):
labeler = lambda var: var + ": " + str(total_setting[var]) \
if display is None or display[var] \
else var
def print_setting(self, total_setting, font=12, node_size=1000):
relabeler = {
var: labeler(var) for var in self.variables
var: var + ": " + str(total_setting[var]) for var in self.variables
}
G = nx.DiGraph()
G.add_edges_from(
@@ -170,7 +166,7 @@ def print_setting(self, total_setting, display=None):
if self.pos is not None:
for var in self.pos:
newpos[relabeler[var]] = self.pos[var]
nx.draw_networkx(G, with_labels=True, node_color="green", pos=newpos)
nx.draw_networkx(G, with_labels=True, node_color="green", pos=newpos, font_size=font, node_size=node_size)
plt.show()

def run_forward(self, intervention=None):
@@ -233,11 +229,14 @@ def sample_input(self, mandatory=None):

def sample_input_tree_balanced(self, output_var=None, output_var_value=None):
assert output_var is not None or len(self.outputs) == 1
self.generate_equiv_classes()

if output_var is None:
output_var = self.outputs[0]
if output_var_value is None:
output_var_value = random.choice(self.values[output_var])


def create_input(var, value, input={}):
parent_values = random.choice(self.equiv_classes[var][value])
for parent in parent_values:
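
A hedged end-to-end sketch of the behavior changed above; the CausalModel construction follows the pyvene tutorials and is an assumption, not part of this diff. print_structure now takes font and node_size, and sample_input_tree_balanced builds the equivalence classes on demand instead of requiring a prior call to generate_equiv_classes.

from pyvene import CausalModel  # assumed top-level export

variables = ["A", "B", "O"]
values = {var: [0, 1] for var in variables}
parents = {"A": [], "B": [], "O": ["A", "B"]}
functions = {"A": lambda: 0, "B": lambda: 0, "O": lambda a, b: int(a and b)}

model = CausalModel(variables, values, parents, functions)
model.print_structure(font=10, node_size=1500)  # new keyword arguments
sampled = model.sample_input_tree_balanced(output_var="O", output_var_value=1)  # equiv classes built lazily
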
22 changes: 12 additions & 10 deletions pyvene/models/gemma/modelings_intervenable_gemma.py
@@ -20,41 +20,43 @@
"mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
"attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
"query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
"head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
}


gemma_type_to_dimension_mapping = {
    "n_head": ("num_attention_heads",),
    "n_kv_head": ("num_key_value_heads",),
    "block_input": ("hidden_size",),
    "block_output": ("hidden_size",),
    "mlp_activation": ("intermediate_size",),
    "mlp_output": ("hidden_size",),
    "mlp_input": ("hidden_size",),
    "attention_value_output": ("hidden_size",),
    "head_attention_value_output": ("hidden_size/num_attention_heads",),
    "head_attention_value_output": ("head_dim",),
    "attention_output": ("hidden_size",),
    "attention_input": ("hidden_size",),
    "query_output": ("hidden_size",),
    "key_output": ("hidden_size",),
    "value_output": ("hidden_size",),
    "head_query_output": ("hidden_size/num_attention_heads",),
    "head_key_output": ("hidden_size/num_attention_heads",),
    "head_value_output": ("hidden_size/num_attention_heads",),
    "head_query_output": ("head_dim",),
    "head_key_output": ("head_dim",),
    "head_value_output": ("head_dim",),
}


"""gemma model with LM head"""
gemma_lm_type_to_module_mapping = {}
for k, v in gemma_type_to_module_mapping.items():
gemma_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])
gemma_lm_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


gemma_lm_type_to_dimension_mapping = gemma_type_to_dimension_mapping
@@ -63,7 +65,7 @@
"""gemma model with classifier head"""
gemma_classifier_type_to_module_mapping = {}
for k, v in gemma_type_to_module_mapping.items():
gemma_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])
gemma_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


gemma_classifier_type_to_dimension_mapping = gemma_type_to_dimension_mapping
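
A small self-contained sketch (using illustrative strings rather than the real hook constants) of why the LM/classifier remapping above changed: head-level entries now carry a third (split_head_and_permute, "n_head") element, and rebuilding each entry as a fixed 2-tuple silently dropped it, while prefixing the first element and concatenating v[1:] keeps every trailing element. The same fix recurs in the gpt_neo, gpt_neox, llama, and llava mappings below.

entry = ("layers[%s].self_attn.q_proj", "OUTPUT_HOOK", ("split_head_and_permute", "n_head"))

old_style = (f"model.{entry[0]}", entry[1])     # drops the head-splitting spec
new_style = (f"model.{entry[0]}",) + entry[1:]  # keeps every trailing element

assert len(old_style) == 2 and len(new_style) == 3
assert new_style[2] == ("split_head_and_permute", "n_head")
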
2 changes: 1 addition & 1 deletion pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py
@@ -58,7 +58,7 @@
"""gpt_neo model with LM head"""
gpt_neo_lm_type_to_module_mapping = {}
for k, v in gpt_neo_type_to_module_mapping.items():
gpt_neo_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", v[1])
gpt_neo_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:]


gpt_neo_lm_type_to_dimension_mapping = gpt_neo_type_to_dimension_mapping
4 changes: 2 additions & 2 deletions pyvene/models/gpt_neox/modelings_intervenable_gpt_neox.py
@@ -33,7 +33,7 @@


gpt_neox_type_to_dimension_mapping = {
"n_head": "num_attention_heads",
"n_head": ("num_attention_heads",),
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": (
@@ -58,7 +58,7 @@
"""gpt_neox model with LM head"""
gpt_neox_lm_type_to_module_mapping = {}
for k, v in gpt_neox_type_to_module_mapping.items():
gpt_neox_lm_type_to_module_mapping[k] = (f"gpt_neox.{v[0]}", v[1])
gpt_neox_lm_type_to_module_mapping[k] = (f"gpt_neox.{v[0]}", ) + v[1:]


gpt_neox_lm_type_to_dimension_mapping = gpt_neox_type_to_dimension_mapping
9 changes: 5 additions & 4 deletions pyvene/models/intervenable_base.py
@@ -366,13 +366,14 @@ def disable_intervention_gradients(self):
# Freeze all intervention weights
pass

def set_device(self, device):
def set_device(self, device, set_model=True):
"""
Set device of interventions and the model
"""
for k, v in self.interventions.items():
v[0].to(device)
self.model.to(device)
if set_model:
self.model.to(device)

def get_device(self):
"""
@@ -1739,8 +1740,8 @@ def _batch_process_unit_location(self, inputs):

return batched_location_dict

def train(self):
self.model.train()
def train(self, mode=True):
self.model.train(mode=mode)

def eval(self):
self.model.eval()
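
A hedged usage sketch of the two relaxed methods; the construction of the intervenable model follows the pyvene README and is an assumption, not part of this diff. set_model=False moves only the intervention parameters, which is useful when the wrapped model is sharded or its placement is managed elsewhere, and train now forwards a mode flag like torch.nn.Module.train.

import pyvene as pv

_, tokenizer, gpt2 = pv.create_gpt2()  # assumed helper, as in the pyvene README
pv_gpt2 = pv.IntervenableModel({"layer": 0, "component": "mlp_output"}, model=gpt2)

pv_gpt2.set_device("cuda:0", set_model=False)  # move interventions only; gpt2 stays where it is
pv_gpt2.set_device("cuda:0")                   # default still moves the wrapped model as before
pv_gpt2.train(mode=False)                      # forwards the flag, mirroring nn.Module.train(mode)
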
10 changes: 6 additions & 4 deletions pyvene/models/layers.py
@@ -33,11 +33,12 @@ def forward(self, x):
class LowRankRotateLayer(torch.nn.Module):
"""A linear transformation with orthogonal initialization."""

def __init__(self, n, m):
def __init__(self, n, m, init_orth=True):
super().__init__()
# n > m
self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)
torch.nn.init.orthogonal_(self.weight)
if init_orth:
torch.nn.init.orthogonal_(self.weight)

def forward(self, x):
return torch.matmul(x.to(self.weight.dtype), self.weight)
@@ -46,11 +47,12 @@ def forward(self, x):
class SubspaceLowRankRotateLayer(torch.nn.Module):
"""A linear transformation with orthogonal initialization with subspace."""

def __init__(self, n, m):
def __init__(self, n, m, init_orth=True):
super().__init__()
# n > m
self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)
torch.nn.init.orthogonal_(self.weight)
if init_orth:
torch.nn.init.orthogonal_(self.weight)

def forward(self, x, l, r):
return torch.matmul(x.to(self.weight.dtype), self.weight[:, l:r])
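
A brief sketch of when to flip the new init_orth switch: skip the orthogonal initialization when the weight is about to be overwritten anyway, e.g. when restoring a trained rotation (the saved tensor below is a stand-in for a checkpointed weight).

import torch
from pyvene.models.layers import LowRankRotateLayer

saved = torch.linalg.qr(torch.randn(768, 4)).Q       # stand-in for a saved rotation
layer = LowRankRotateLayer(768, 4, init_orth=False)  # leaves the torch.empty values untouched
with torch.no_grad():
    layer.weight.copy_(saved)
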
39 changes: 22 additions & 17 deletions pyvene/models/llama/modelings_intervenable_llama.py
@@ -20,19 +20,21 @@
"mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
"attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
"query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
"head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
}


llama_type_to_dimension_mapping = {
"n_head": ("num_attention_heads",),
"n_kv_head": ("num_key_value_heads",),
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": ("intermediate_size",),
@@ -54,7 +56,7 @@
"""llama model with LM head"""
llama_lm_type_to_module_mapping = {}
for k, v in llama_type_to_module_mapping.items():
llama_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])
llama_lm_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


llama_lm_type_to_dimension_mapping = llama_type_to_dimension_mapping
@@ -63,25 +65,28 @@
"""llama model with classifier head"""
llama_classifier_type_to_module_mapping = {}
for k, v in llama_type_to_module_mapping.items():
llama_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])
llama_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


llama_classifier_type_to_dimension_mapping = llama_type_to_dimension_mapping


def create_llama(
name="sharpbai/alpaca-7b-merged", cache_dir=None, dtype=torch.bfloat16
name="sharpbai/alpaca-7b-merged", cache_dir=None, dtype=torch.bfloat16, config=None
):
"""Creates a LLaMA Causal LM model, config, and tokenizer from the given name and revision"""
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

config = LlamaConfig.from_pretrained(name, cache_dir=cache_dir)
tokenizer = LlamaTokenizer.from_pretrained(name, cache_dir=cache_dir)
llama = LlamaForCausalLM.from_pretrained(
name,
config=config,
cache_dir=cache_dir,
torch_dtype=dtype, # save memory
)
if config is None:
config = LlamaConfig.from_pretrained(name, cache_dir=cache_dir)
llama = LlamaForCausalLM.from_pretrained(
name,
config=config,
cache_dir=cache_dir,
torch_dtype=dtype, # save memory
)
tokenizer = LlamaTokenizer.from_pretrained(name, cache_dir=cache_dir)
else:
llama = LlamaForCausalLM(config)
tokenizer = LlamaTokenizer.from_pretrained(name, cache_dir=cache_dir)
print("loaded model")
return config, tokenizer, llama
return config, tokenizer, llama
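
A hedged sketch of the new config path (the tiny config values are illustrative): passing a config builds a randomly initialized LLaMA from it with no pretrained-weight download, which is convenient for unit tests, while the default path is unchanged; the tokenizer is still fetched from the given name in both branches.

from transformers import LlamaConfig
from pyvene.models.llama.modelings_intervenable_llama import create_llama

tiny = LlamaConfig(hidden_size=128, intermediate_size=256,
                   num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)
config, tokenizer, llama = create_llama(config=tiny)  # untrained model from config
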
14 changes: 8 additions & 6 deletions pyvene/models/llava/modelings_intervenable_llava.py
@@ -19,19 +19,21 @@
"mlp_output": ("language_model.model.layers[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("language_model.model.layers[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("language_model.model.layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("language_model.model.layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("language_model.model.layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("language_model.model.layers[%s].self_attn", CONST_OUTPUT_HOOK),
"attention_input": ("language_model.model.layers[%s].self_attn", CONST_INPUT_HOOK),
"query_output": ("language_model.model.layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("language_model.model.layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("language_model.model.layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("language_model.model.layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"head_key_output": ("language_model.model.layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"head_value_output": ("language_model.model.layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("language_model.model.layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_key_output": ("language_model.model.layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
"head_value_output": ("language_model.model.layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
}


llava_type_to_dimension_mapping = {
"n_head": ("text_config.num_attention_heads",),
"n_kv_head": ("text_config.num_key_value_heads",),
"block_input": ("text_config.hidden_size",),
"block_output": ("text_config.hidden_size",),
"mlp_activation": ("text_config.intermediate_size",),
@@ -53,7 +55,7 @@
"""llava model with LM head"""
llava_lm_type_to_module_mapping = {}
for k, v in llava_type_to_module_mapping.items():
llava_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])
llava_lm_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


llava_lm_type_to_dimension_mapping = llava_type_to_dimension_mapping
@@ -62,7 +64,7 @@
"""llava model with classifier head"""
llava_classifier_type_to_module_mapping = {}
for k, v in llava_type_to_module_mapping.items():
llava_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])
llava_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


llava_classifier_type_to_dimension_mapping = llava_type_to_dimension_mapping