
Commit

zRzRzRzRzRzRzR committed Dec 26, 2023
2 parents 04a46b6 + 54f28da commit 48dc3c6
Showing 3 changed files with 33 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -14,7 +14,7 @@ dialogue with images, GUI Agent, Grounding**, and more.
🌟 **Jump to detailed introduction: [Introduction to CogVLM](#introduction-to-cogvlm)
🆕 [Introduction to CogAgent](#introduction-to-cogagent)**

📔 For more detailed usage information, please refer to: [CogVLM technical documentation (Chinese only)](https://zhipu-ai.feishu.cn/wiki/LXQIwqo1OiIVTykMh9Lc3w1Fn7g)
📔 For more detailed usage information, please refer to: [CogVLM & CogAgent's technical documentation (in Chinese)](https://zhipu-ai.feishu.cn/wiki/LXQIwqo1OiIVTykMh9Lc3w1Fn7g)

<table>
<tr>
10 changes: 4 additions & 6 deletions README_zh.md
@@ -1,15 +1,13 @@
# CogVLM & CogAgent

📗 [Chinese README](./README_zh.md)
- 🔥🔥🔥 **News**: ```2023/12/26```: We have released the [CogVLM-SFT-311K](dataset_zh.md) dataset, which contains more than 150,000 samples used to train **CogVLM v1.0 (that model only)**. You are welcome to follow and use it.
🔥🔥🔥 🆕 ```2023/12/15```: CogAgent is officially released! CogAgent is an image-understanding model built on CogVLM. It provides a vision-based GUI
Agent capability and further improves image understanding. It supports image input at 1120*1120 resolution and offers capabilities including multi-turn dialogue with images, GUI
Agent, Grounding, and more.
📗 [README in English](./README.md)
- 🔥🔥🔥 🆕: ```2023/12/26```: We have released the [CogVLM-SFT-311K](dataset_zh.md) dataset,
which contains more than 150,000 samples used to train **CogVLM v1.0 (that model only)**. You are welcome to follow and use it.

🌟 **Jump to detailed introduction: [Introduction to CogVLM](#introduction-to-cogvlm)
🆕 [Introduction to CogAgent](#introduction-to-cogagent)**

📔 For more detailed usage information, please refer to: [CogVLM technical documentation](https://zhipu-ai.feishu.cn/wiki/LXQIwqo1OiIVTykMh9Lc3w1Fn7g)
📔 For more detailed usage information, please refer to: [CogVLM & CogAgent technical documentation](https://zhipu-ai.feishu.cn/wiki/LXQIwqo1OiIVTykMh9Lc3w1Fn7g)

<table>
<tr>
36 changes: 28 additions & 8 deletions basic_demo/cli_demo_hf.py
@@ -49,31 +49,51 @@
    trust_remote_code=True
).to(DEVICE).eval()

text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"

while True:
    image_path = input("image path >>>>> ")
    if image_path == "stop":
        break

    if image_path == '':
        print('You did not enter image path, the following will be a plain text conversation.')
        image = None
        text_only_first_query = True
    else:
        image = Image.open(image_path).convert('RGB')

    history = []

    while True:
        query = input("Human:")
        if query == "clear":
            break

        if image is None:
            if text_only_first_query:
                query = text_only_template.format(query)
                text_only_first_query = False
            else:
                old_prompt = ''
                for _, (old_query, response) in enumerate(history):
                    old_prompt += old_query + " " + response + "\n"
                query = old_prompt + "USER: {} ASSISTANT:".format(query)

        if image is None:
            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base')
        else:
            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])

        inputs = {
            'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
            'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
            'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
            'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]] if image is not None else None,
        }
        if 'cross_images' in input_by_model and input_by_model['cross_images']:
            inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]

        # add any transformers params here.
        gen_kwargs = {"max_length": 2048,
                      "do_sample": False}  # "temperature": 0.9
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
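The text-only branch added by this commit assembles prompts by hand: the first turn is wrapped in `text_only_template`, and later turns replay the `(query, response)` history before appending the new query. That assembly logic can be exercised on its own, without loading the model. In the sketch below, `build_text_only_prompt` is a hypothetical helper, and it uses an empty history as the first-turn signal instead of the script's `text_only_first_query` flag:

```python
# Prompt assembly for the text-only mode of cli_demo_hf.py, reproduced standalone.
# build_text_only_prompt is a hypothetical helper for illustration; the script
# inlines this logic and tracks the first turn with a text_only_first_query flag.
text_only_template = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's "
    "questions. USER: {} ASSISTANT:"
)

def build_text_only_prompt(query, history):
    """Build the prompt for a text-only turn from a list of (query, response) pairs."""
    if not history:
        # First turn: wrap the raw query in the full chat template.
        return text_only_template.format(query)
    # Later turns: replay earlier exchanges, then append the new query.
    old_prompt = ""
    for old_query, response in history:
        old_prompt += old_query + " " + response + "\n"
    return old_prompt + "USER: {} ASSISTANT:".format(query)

first = build_text_only_prompt("Hello", [])
followup = build_text_only_prompt("And in French?", [(first, "Hi!")])
```

Note that in the script itself the templated query is what gets stored in `history`, so replayed turns already carry the system preamble from the first exchange.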
