attention weights done - working on cross attention weights
notvenky committed Mar 4, 2024
1 parent a9eba24 commit b462113
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions basic_demo/hf_rep.py
@@ -31,7 +31,7 @@
 ).to(DEVICE).eval()
 
 image_path = "../../Downloads/vlm/IMG_0864.jpg"
-command_list_txt = ["What is the distance between tennis ball and bottle of disinfectant", "Where is the tennis ball and bottle of disinfectant"]
+command_list_txt = ["What is the distance between tennis ball and bottle of disinfectant"] #, "Where is the tennis ball and bottle of disinfectant"]
 
 image = Image.open(image_path).convert('RGB')
 history = []
@@ -51,15 +51,21 @@
 gen_kwargs = {"max_length": 2048,
               "do_sample": False} # "temperature": 0.9
 with torch.no_grad():
-    model_outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
+    model_outputs = model(**inputs, output_hidden_states=True, output_attentions=True, extract_intermediate_representation=True, return_dict=True)
     encoder_hidden_states = model_outputs.hidden_states
-    print(dir(model_outputs))
+    # intermediate_representation = model_outputs.intermediate_representations
+    # print(dir(model_outputs))
     print(f"Hidden states shape: {encoder_hidden_states[-1].shape}")
-    if model_outputs.attentions is not None:
-        encoder_attentions = model_outputs.attentions
-        print(f"Attention shape: {encoder_attentions[-1].shape}")
-    else:
-        print("Attention weights are not available.")
+    intermediate_representation = model_outputs.intermediate_representations
+    print(intermediate_representation)
+
+
+
+    # if model_outputs.attentions is not None:
+    #     encoder_attentions = model_outputs.attentions
+    #     print(f"Attention shape: {encoder_attentions[-1].shape}")
+    # else:
+    #     print("Attention weights are not available.")
     intr_outputs = model.generate(**inputs, **gen_kwargs)
     # print("Methods in Model:", dir(model))
     test_outputs = model.get_output_embeddings()
@@ -71,11 +77,9 @@
 response = response.split("</s>")[0]
 
 print("Query:", query)
 # print("Image Shape:", inputs['images'][0][0].shape if inputs['images'] else None)
 # print("Cross Image Shape:", inputs['cross_images'][0][0].shape if inputs['cross_images'] else None)
 # print("Shapes of Inputs: input_ids", inputs['input_ids'].shape, "token_type_ids", inputs['token_type_ids'].shape, "attention_mask", inputs['attention_mask'].shape)
 print("Shapes of Intrinsic Outputs:", intr_outputs.shape)
-print("Intrinsic Output:", intr_outputs)
+# print("Intrinsic Output:", intr_outputs)
 print("Number of Non-Zero Items:", (intr_outputs != 0).sum().item())
 print("Shape of Reply Outputs:", reply_outputs.shape)
 print("Response:", response)
