Commit e8c2ec8: add test code

littlefive5 committed Jan 21, 2024
1 parent 53597b7 commit e8c2ec8

Showing 3 changed files with 1,713 additions and 6 deletions.
8 changes: 5 additions & 3 deletions acdc/acdc_utils.py
@@ -80,7 +80,6 @@ def negative_log_probs(
     if last_seq_element_only:
         logprobs = logprobs[:, -1, :]
     if shift:
-        labels = labels[..., 1:].contiguous()
         logprobs = logprobs[..., :-1, :].contiguous()
     # Subtract a baseline for each element -- which could be 0 or the NLL of the base_model_logprobs
     nll_all = (
@@ -125,11 +124,13 @@ def __init__(
             assert logprobs.ndim == 2
         else:
             assert logprobs.ndim == 3
+        self.shift = False
         if shift:
-            labels = labels[..., 1:].contiguous()
+            self.shift = True
+            self.labels = labels[..., 1:].contiguous()
             logprobs = logprobs[..., :-1, :].contiguous()
         self.base_nll_unreduced = F.nll_loss(
-            logprobs.view(-1, logprobs.size(-1)), labels.view(-1), reduction="none"
+            logprobs.view(-1, logprobs.size(-1)), self.labels.view(-1), reduction="none"
         ).view(logprobs.size()[:-1])
         if mask_repeat_candidates is not None:
             assert self.base_nll_unreduced.shape == mask_repeat_candidates.shape
@@ -144,6 +145,7 @@ def __call__(self, logits: torch.Tensor) -> torch.Tensor:
             baseline=self.base_nll_unreduced,
             last_seq_element_only=self.last_seq_element_only,
             return_one_element=self.return_one_element,
+            shift=self.shift,
         )

 def logit_diff_metric(logits, correct_labels, wrong_labels, return_one_element: bool=True) -> torch.Tensor:
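Note on the shift logic above: the shift flag implements the standard next-token alignment, where the labels are advanced by one position and the final prediction is dropped before the NLL is computed, so that position t of the log-probs is scored against the token at position t+1. A minimal, self-contained sketch of that convention (the function name shifted_nll is ours, not part of the repo):

import torch
import torch.nn.functional as F

def shifted_nll(logprobs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logprobs: (batch, seq, vocab) log-probabilities; labels: (batch, seq) token ids.
    # Position t of the model output predicts token t+1, so drop the last
    # prediction and the first label to line the two tensors up.
    labels = labels[..., 1:].contiguous()          # (batch, seq-1)
    logprobs = logprobs[..., :-1, :].contiguous()  # (batch, seq-1, vocab)
    return F.nll_loss(
        logprobs.view(-1, logprobs.size(-1)), labels.view(-1), reduction="none"
    ).view(logprobs.size()[:-1])

The commit caches the shifted labels once in __init__ (self.labels) and forwards shift=self.shift on every call, so negative_log_probs itself only needs to shift the logprobs, not the labels.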
7 changes: 4 additions & 3 deletions acdc/knowledge/utils.py
@@ -114,7 +114,7 @@ def get_model(name, hf_model, tokenizer, device="cuda",local_path=None) -> HookedTransformer:
     tl_model = HookedTransformer.from_pretrained(name, hf_model=hf_model, tokenizer=tokenizer,local_path=local_path)
     tl_model = tl_model.to(device)
     tl_model.set_use_attn_result(True)
-    tl_model.set_use_split_qkv_input(False)
+    tl_model.set_use_split_qkv_input(True)  # After changing this, the plotting code later will probably raise an error
     if "use_hook_mlp_in" in tl_model.cfg.to_dict():
         tl_model.set_use_hook_mlp_in(True)
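The switch to set_use_split_qkv_input(True) gives each attention head separate hook points for its query, key, and value inputs instead of a single shared residual-stream input, which lets ACDC ablate q/k/v edges independently; the committed comment warns that the downstream plotting code may break once this is on. A minimal sketch of the resulting configuration (the flag calls are taken from the diff; the hook names in the comments are our assumption about TransformerLens, not something this commit shows):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
model.set_use_attn_result(True)       # expose per-head attention outputs
model.set_use_split_qkv_input(True)   # expose separate q/k/v input hooks
if "use_hook_mlp_in" in model.cfg.to_dict():
    model.set_use_hook_mlp_in(True)   # expose the MLP input hook

_, cache = model.run_with_cache("test input")
# With split qkv inputs on, each block should expose hook points such as
# blocks.0.hook_q_input / hook_k_input / hook_v_input in the cache.
print([name for name in cache.keys() if "q_input" in name][:2])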
@@ -124,12 +124,13 @@
     return tl_model


-def get_all_knowledge_things(num_examples, seq_len, device, model="gpt2", model_path="", data_seed=42, metric_name="kl_div", return_one_element=True) -> AllDataThings:
+def get_all_knowledge_things(num_examples, device, model="gpt2", model_path="", data_path="",data_seed=42, metric_name="match_nll", return_one_element=True) -> AllDataThings:
     hf_model, tokenizer = load_model(model_path,fp16=False)
-    tl_model = get_model(name=model, hf_model=hf_model, tokenizer=tokenizer,device=device)
+    tl_model = get_model(name=model, hf_model=hf_model, tokenizer=tokenizer,device=device,local_path=model_path)
     knowledge_data, knowledge_label = get_and_filter_dataset(
         tokenizer=tokenizer,
         knowledge_type="factual",
+        data_path=data_path,
     )
     default_data = knowledge_data.to(device)
     labels = knowledge_label.to(device)
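With the new signature, callers pass the dataset location explicitly via data_path, the model path is forwarded to get_model as local_path, and the default metric moves from kl_div to match_nll. A hypothetical call site (values and paths are placeholders, not from the repo):

things = get_all_knowledge_things(
    num_examples=100,                            # placeholder value
    device="cuda",
    model="gpt2",
    model_path="/path/to/checkpoint",            # also reused as local_path
    data_path="/path/to/knowledge_data.json",    # new: dataset location
    data_seed=42,
    metric_name="match_nll",                     # new default (was kl_div)
    return_one_element=True,
)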