Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
littlefive5 committed Mar 5, 2024
1 parent 0871ca0 commit a6a7c56
Showing 4 changed files with 49 additions and 23 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -11,3 +11,6 @@ __pycache__/
*~
.DS_Store
.env
*/__pycache__/
*ims
*/ims/
29 changes: 20 additions & 9 deletions acdc/knowledge/knowledge_dataset.py
Original file line number Diff line number Diff line change
@@ -71,7 +71,7 @@ class Relation(DataClassJsonMixin):

_domain: list[str] | None = None
_range: list[str] | None = None

reverse_prompt_templates: list[str] | None = None
@property
def domain(self) -> set[str]:
if self._domain is not None:
@@ -350,16 +350,27 @@ def get_and_filter_dataset(
index_name='city_in_country.pt',
data_path="./data",
tokenizer=None,
reverse=False,
):
paths=get_path(data_path,knowledge_type,relation_name)
relation = load_dataset(paths)[0]
prompt_template = relation.prompt_templates[0]
sentences = [
prompt_template.format(sample.subject) for sample in relation.samples
]
answers = [
sample.object for sample in relation.samples
]
if reverse:
print("reverse_relation")
prompt_template = relation.reverse_prompt_templates[0]
sentences = [
prompt_template.format(sample.object) for sample in relation.samples
]
answers = [
sample.subject for sample in relation.samples
]
else:
prompt_template = relation.prompt_templates[0]
sentences = [
prompt_template.format(sample.subject) for sample in relation.samples
]
answers = [
sample.object for sample in relation.samples
]
    # Each template yields two sentences, so there are twice as many answers
inputs = [f"{p} {l}" for p, l in zip(sentences, answers)]
# inputs = sentences
@@ -378,7 +389,7 @@ def get_and_filter_dataset(
#left padding
labels[input_ids == tokenizer.pad_token_id] = -100
# return input_ids, labels
pt_path = os.path.join(data_path, 'pt_gpt2_large',index_name)
pt_path = os.path.join(data_path, 'pt',index_name)
if os.path.exists(pt_path):
select_index = torch.load(pt_path)
assert select_index.shape[0] == input_ids.shape[0]
3 changes: 2 additions & 1 deletion acdc/knowledge/utils.py
Original file line number Diff line number Diff line change
@@ -124,7 +124,7 @@ def get_model(name, hf_model, tokenizer, device="cuda",local_path=None) -> Hooke
return tl_model


def get_all_knowledge_things(num_examples, device, model="gpt2", model_path="",knowledge_type="", data_path="", relation_name="",index_name="",data_seed=42, metric_name="match_nll", return_one_element=True) -> AllDataThings:
def get_all_knowledge_things(num_examples, device, model="gpt2", model_path="",knowledge_type="", data_path="", relation_name="",index_name="",reverse = False, data_seed=42, metric_name="match_nll", return_one_element=True) -> AllDataThings:
hf_model, tokenizer = load_model(model_path,fp16=False)
tl_model = get_model(name=model, hf_model=hf_model, tokenizer=tokenizer,device=device,local_path=model_path)
knowledge_data, knowledge_label = get_and_filter_dataset(
@@ -133,6 +133,7 @@ def get_all_knowledge_things(num_examples, device, model="gpt2", model_path="",k
relation_name=relation_name,
index_name=index_name,
data_path=data_path,
reverse=reverse,
)
default_data = knowledge_data.to(device)
labels = knowledge_label.to(device)
37 changes: 24 additions & 13 deletions acdc/main.py
Original file line number Diff line number Diff line change
@@ -164,6 +164,10 @@
parser.add_argument('--indices-mode', type=str, default="normal")
parser.add_argument('--names-mode', type=str, default="normal")
parser.add_argument('--device', type=str, default="cuda")
parser.add_argument('--specific-knowledge', type=str)
parser.add_argument('--num-examples', type=int)
parser.add_argument('--knowledge-type', type=str, default="factual")
parser.add_argument('--relation-reverse', type=str, required=False, default="False")
parser.add_argument('--reset-network', type=int, default=0, help="Whether to reset the network we're operating on before running interp on it")
parser.add_argument('--metric', type=str, default="kl_div", help="Which metric to use for the experiment")
parser.add_argument('--torch-num-threads', type=int, default=0, help="How many threads to use for torch (0=all)")
@@ -225,7 +229,7 @@
DEVICE = args.device
RESET_NETWORK = args.reset_network
SINGLE_STEP = True if args.single_step else False

specific_knowledge = args.specific_knowledge
#%% [markdown]
# <h2>Setup Task</h2>

@@ -288,15 +292,16 @@
num_examples=num_examples, metric_name=args.metric, device=DEVICE
)
elif TASK == "knowledge":
num_examples = 40
num_examples = args.num_examples
things = get_all_knowledge_things(
num_examples=num_examples, metric_name=args.metric, device=DEVICE,
model="gpt2-large",
model_path="/newdisk3/yunzhi/gpt2-large",
model="gpt2",
model_path="/newdisk3/yunzhi/gpt2",
data_path="../data",
knowledge_type="linguistic",
relation_name="adj_antonym.json",
index_name="adj_antonym.pt",
knowledge_type=args.knowledge_type,
relation_name=f"{specific_knowledge}.json",
index_name=f"{specific_knowledge}.pt",
reverse=True if args.relation_reverse == 'True' else False,
)
else:
raise ValueError(f"Unknown task {TASK}")
@@ -378,19 +383,25 @@

import datetime
exp_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

threshold = args.threshold
if args.relation_reverse == 'True':
fold_name = f"reverse_ims_{specific_knowledge}_{threshold}"
else:
fold_name = f"ims_{specific_knowledge}_{threshold}"
if not os.path.exists(fold_name):
os.makedirs(fold_name)
for i in range(args.max_num_epochs):
exp.step(testing=False)

show(
exp.corr,
f"ims/img_new_{i+1}.png",
f"{fold_name}/img_new_{i+1}.png",
show_full_index=False,
)

if IN_COLAB or ipython is not None:
# so long as we're not running this as a script, show the image!
display(Image(f"ims/img_new_{i+1}.png"))
display(Image(f"{fold_name}/img_new_{i+1}.png"))

print(i, "-" * 50)
print(exp.count_no_edges())
@@ -401,15 +412,15 @@
if exp.current_node is None or SINGLE_STEP:
show(
exp.corr,
f"ims/ACDC_img_{exp_time}.png",
f"{fold_name}/ACDC_img_{exp_time}.png",

)
break

exp.save_edges("another_final_edges.pkl")
exp.save_edges(f"{fold_name}/another_final_edges.pkl")

if USING_WANDB:
edges_fname = f"edges.pth"
edges_fname = f"{fold_name}/edges.pth"
exp.save_edges(edges_fname)
artifact = wandb.Artifact(edges_fname, type="dataset")
artifact.add_file(edges_fname)

0 comments on commit a6a7c56

Please sign in to comment.