fix bug zjunlp#298
pengzju committed Jun 13, 2024
1 parent 3f8d523 commit 38c5c34
Showing 6 changed files with 23 additions and 26 deletions.
13 changes: 9 additions & 4 deletions easyeditor/evaluate/evaluate_utils.py
@@ -83,15 +83,17 @@ def test_prediction_acc(model, tok, hparams, prompts, targets, device, locality=
             prompts, targets = [prompts, ], [targets, ]
         results = []
         for prompt, target_new in zip(prompts, targets):
-            target_new_tokens = tok.encode(' ' + target_new, add_special_tokens=False)
+            target_new_tokens = tok.encode(target_new, add_special_tokens=False)
             prompt_tok = tok(
                 prompt,
                 return_tensors="pt",
             ).to(device)
             gen_token = model.generate(
                 input_ids=prompt_tok['input_ids'],
                 attention_mask=prompt_tok['attention_mask'],
-                max_new_tokens=len(target_new_tokens)
+                max_new_tokens=len(target_new_tokens),
+                pad_token_id=tok.eos_token_id,
+                use_cache=False,
             )
             if locality:
                 results.append(gen_token.detach().cpu().numpy().tolist()[0][-len(target_new_tokens):])
@@ -444,15 +446,18 @@ def kl_loc_loss(pre, post, mask=None):
 
 def F1(model, tok, hparams, prompts, targets, device, locality=False, vanilla_generation=True):
     if vanilla_generation:
-        target_new_tokens = tok.encode(' ' + targets, add_special_tokens=False)
+        target_new_tokens = tok.encode(targets, add_special_tokens=False)
         prompt_tok = tok(
             prompts,
             return_tensors="pt",
         ).to(device)
         gen_token = model.generate(
             input_ids=prompt_tok['input_ids'],
             attention_mask=prompt_tok['attention_mask'],
-            max_new_tokens=len(target_new_tokens)
+            max_new_tokens=len(target_new_tokens),
+            pad_token_id=tok.eos_token_id,
+            use_cache=False,
+
         )
         return f1_score(target_new_tokens, gen_token.detach().cpu().numpy().tolist()[0][-len(target_new_tokens):], average='macro')
     if isinstance(prompts, str):
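
Both hunks apply the same two-part fix to vanilla-generation evaluation: the target is now encoded without a prepended space, and generate() is pinned to exactly len(target_new_tokens) new tokens, with pad_token_id set to the EOS id. A minimal sketch of why the leading space matters, using GPT-2's tokenizer for illustration (an assumption; Llama-style SentencePiece tokenizers show the same effect):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative stand-in tokenizer

target = "Paris"
with_space = tok.encode(" " + target, add_special_tokens=False)
without_space = tok.encode(target, add_special_tokens=False)

# ' Paris' and 'Paris' usually map to different ids and can even differ
# in length; since the generated suffix is sliced to the last
# len(target_new_tokens) ids, a length mismatch shifts the comparison
# window and deflates the reported accuracy.
print(with_space, without_space)
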
12 changes: 7 additions & 5 deletions easyeditor/models/grace/GRACE.py
@@ -185,12 +185,14 @@ def forward(self, *args):
             # print(self.__dict__)
             return layer_out
         else:
-            if not self.training and not self.ensure_replace_token_loc and self.key_id == -1:
-                token_to_edit = args[0].shape[1]-1
-                self.key_id = args[0].shape[1]-1
-                self.ensure_replace_token_loc = True
+            if not self.training:
+                if self.key_id == -1:
+                    token_to_edit = args[0].shape[1] - 1
+                    self.key_id = args[0].shape[1] - 1
+                else:
+                    token_to_edit = min(self.key_id, args[0].shape[1] - 1)
             else:
-                token_to_edit = min(self.key_id, args[0].shape[1]-1) # args[0].shape[1] - 1 is sequence length
+                token_to_edit = min(self.key_id, args[0].shape[1] - 1) # args[0].shape[1] - 1 is sequence length
             query = args[0][:, token_to_edit, :] # Just use activation for last token
             if self.config.val_init == "cold":
                 new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
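
The rewrite drops the ensure_replace_token_loc flag in favor of a plain rule: on the first eval-time pass (self.key_id == -1) edit the last token and cache its position; on later passes reuse the cached position, clamped to the current sequence length. A standalone sketch of that control flow (an illustrative helper, not the EasyEdit API):

def select_token_to_edit(key_id: int, seq_len: int, training: bool) -> tuple:
    """Return (token_to_edit, updated_key_id), mirroring the branch above."""
    if not training:
        if key_id == -1:
            # First eval-time call: edit the last token and cache where it was.
            return seq_len - 1, seq_len - 1
        # Later calls: reuse the cached key, clamped to the sequence length.
        return min(key_id, seq_len - 1), key_id
    # Training: the key position is already known; just clamp it.
    return min(key_id, seq_len - 1), key_id

assert select_token_to_edit(-1, 10, training=False) == (9, 9)  # first pass caches position 9
assert select_token_to_edit(9, 4, training=False) == (3, 9)    # shorter sequence clamps, cache kept
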
2 changes: 1 addition & 1 deletion easyeditor/models/grace/grace_hparams.py
@@ -29,7 +29,7 @@ class GraceHyperParams(HyperParams):
     model_name: str
 
     # Defaults
-    batch_size: int = 128
+    batch_size: int = 1
     max_length: int = 30
     model_parallel: bool = False

12 changes: 1 addition & 11 deletions easyeditor/models/grace/grace_main.py
@@ -21,22 +21,12 @@ def apply_grace_to_model(
     request = requests[0]
     if copy:
         model = deepcopy(model)
-    weights_copy = {}
     device = torch.device(f'cuda:{hparams.device}')
     editor = GRACE(model=model, config=hparams, device=device)
     tokens = tokenize(request, tokenizer=tok, device=device)
     editor.edit(config=hparams, tokens=tokens,edit_id=request['target_new'])
     # editor.rolllback(request['target_new'])
-
-
-    with torch.no_grad():
-        for w_name in hparams.inner_params:
-            w_name=w_name.replace("[", ".").replace("]", "")
-            w = nethook.get_parameter(editor.model, w_name)
-            weights_copy[w_name]=w
-
-    if keep_original_weight:
-        weights_copy = editor.reset_layer
+    weights_copy = editor.reset_layer
 
 
     return editor, weights_copy
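
Since GRACE keeps the edit in a wrapper around a single layer, the dict of copied weight tensors was dead weight; rollback now always hands back editor.reset_layer. A toy, self-contained sketch of the contract this implies for callers (the class below is a stand-in, not EasyEdit code; treating reset_layer as a bound method that restores the wrapped layer is an assumption):

class EditorSketch:
    """Stand-in for GRACE: wraps a layer and can restore the original."""
    def __init__(self):
        self.layer = "GRACE-wrapped layer"
    def reset_layer(self):
        self.layer = "original layer"

editor = EditorSketch()
weights_copy = editor.reset_layer  # what apply_grace_to_model now returns
# ... evaluate the edited model ...
weights_copy()                     # rolling back == resetting the adapter
assert editor.layer == "original layer"
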
8 changes: 4 additions & 4 deletions edit.py
@@ -1219,10 +1219,10 @@ def test_Llama2():
     train_ds = ZsreDataset('./data/zsre_mend_train.json', size=10000)
     # encode_ike_facts(sentence_model, train_ds, hparams)
     # hparams = ROMEHyperParams.from_hparams('./hparams/ROME/llama-7b.yaml')
-    hparams = R_ROMEHyperParams.from_hparams('./hparams/R-ROME/llama-7b.yaml')
+    # hparams = R_ROMEHyperParams.from_hparams('./hparams/R-ROME/llama-7b.yaml')
     # hparams = MEMITHyperParams.from_hparams('./hparams/MEMIT/llama-7b.yaml')
     # hparams = SERACHparams.from_hparams('./hparams/SERAC/llama-7b.yaml')
-    # hparams = GraceHyperParams.from_hparams('./hparams/GRACE/llama-7B.yaml')
+    hparams = GraceHyperParams.from_hparams('./hparams/GRACE/llama-7B.yaml')
     # hparams = KNHyperParams.from_hparams('./hparams/KN/gpt2-xl.yaml')
 
     editor = BaseEditor.from_hparams(hparams)
@@ -2721,7 +2721,7 @@ def main():
     # test_ROME_LlaMA()
     # test_ROME_DEMO()
     # ROME_DEMO_2()
-    # test_Llama2()
+    test_Llama2()
     # test_ROME_Baichuan()
     # test_MEND_Baichuan()
     # test_MEMIT_Baichuan()
@@ -2748,7 +2748,7 @@ def main():
     # test_ROME_Internlm()
     # test_FT_Qwen()
     # test_KN_Qwen()
-    test_ROME_Qwen()
+    # test_ROME_Qwen()
     # test_IKE_Qwen()
     # test_MEMIT_Qwen()
     # test_MEND_Train_Qwen()
2 changes: 1 addition & 1 deletion hparams/GRACE/llama-7B.yaml
@@ -6,7 +6,7 @@ inner_params:
 - model.layers[27].mlp.down_proj.weight
 
 edit_lr: 1.0
-n_iter: 100
+n_iter: 50
 eps: 1.0
 dist_fn: euc # euc, mmd, cos
 val_init: cold # cold, warm
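
n_iter bounds the inner loop that optimizes a value vector for each edit, with edit_lr as its step size, so halving it from 100 to 50 trades per-edit fitting time for speed. A toy, runnable sketch of that role (the quadratic loss is a stand-in for the real edit loss; this is not the GRACE source):

import torch

n_iter, edit_lr = 50, 1.0                      # mirrors the YAML above
value = torch.rand(1, 16, requires_grad=True)  # random start, as in the val_init == "cold" branch shown earlier
target = torch.ones(1, 16)                     # stand-in target activation
opt = torch.optim.Adam([value], lr=edit_lr)

for _ in range(n_iter):                        # n_iter caps how long each edit is optimized
    loss = torch.nn.functional.mse_loss(value, target)
    opt.zero_grad()
    loss.backward()
    opt.step()
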
