Commit

Update the cogvlm demo to support text-only conversation
duchenzhuang committed Dec 19, 2023
1 parent acb52b1 commit e359876
Showing 1 changed file with 102 additions and 0 deletions.
102 changes: 102 additions & 0 deletions basic_demo/cli_demo_hf_cogvlm_chat.py
@@ -0,0 +1,102 @@
"""
This is a demo for using CogAgent and CogVLM in the CLI.
Make sure you have installed the vicuna-7b-v1.5 tokenizer (https://huggingface.co/lmsys/vicuna-7b-v1.5); the full vicuna-7b-v1.5 LLM checkpoint is not required.
This demo uses the 'chat' template; you can replace it with others such as 'vqa'.
We strongly suggest using a GPU with bfloat16 support; otherwise inference will be slow.
Note that only one image can be processed per conversation, which means you cannot replace or insert another image during the conversation.
"""

import argparse
import torch

from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--quant", choices=[4], type=int, default=4, help='quantization bits')
parser.add_argument("--from_pretrained", type=str, default="THUDM/cogvlm-chat-hf", help='pretrained ckpt')
parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")

args = parser.parse_args()
MODEL_PATH = args.from_pretrained
TOKENIZER_PATH = args.local_tokenizer
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
if args.bf16:
    torch_type = torch.bfloat16
else:
    torch_type = torch.float16
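# Note: --fp16 is accepted but has no effect beyond the default; float16 is already
# the fallback dtype, so only --bf16 changes torch_type here.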

print("========Use torch type as:{} with device:{}========\n".format(torch_type, DEVICE))

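# When --quant is set (default 4), the checkpoint is loaded in 4-bit through
# bitsandbytes (load_in_4bit=True) to reduce GPU memory; this branch skips the
# explicit .to(DEVICE) call, since quantized weights are placed on the GPU at load time.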
if args.quant:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch_type,
        low_cpu_mem_usage=True,
        load_in_4bit=True,
        trust_remote_code=True
    ).eval()
else:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch_type,
        low_cpu_mem_usage=True,
        load_in_4bit=args.quant is not None,
        trust_remote_code=True
    ).to(DEVICE).eval()

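# Vicuna-style prompt used to wrap the first text-only query so the chat checkpoint
# can hold a pure-text conversation without an image.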
text_only_template = "A chat between a curious user and an artificial intelligence assistant. \
The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"

while True:
    image_or_text = input("Chat with an image or text only? Please choose 'image' or 'text': ")
    if image_or_text == 'image':
        image_path = input("image path >>>>> ")
        if image_path == "stop":
            break
        image = Image.open(image_path).convert('RGB')
    else:
        image = None
    first_query = True
    history = []

    while True:
        query = input("Human:")
        if query == "clear":
            break

        if image is None:
            if first_query:
                query = text_only_template.format(query)
                first_query = False
            else:
                query = "USER: {} ASSISTANT:".format(query)

        if image is None:
            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history)
        else:
            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])

        inputs = {
            'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
            'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
            'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
            'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]] if image is not None else None,
        }
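        # 'cross_images' is only returned by CogAgent checkpoints, which route a separate
        # high-resolution image through a cross-attention branch; plain CogVLM does not produce it.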
        if 'cross_images' in input_by_model and input_by_model['cross_images']:
            inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]

        # Add any other transformers generation params here.
        gen_kwargs = {"max_length": 2048,
                      "do_sample": False}
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            # generate() returns the prompt followed by the new tokens, so strip the prompt
            # and cut the reply at the end-of-sequence marker.
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            response = tokenizer.decode(outputs[0])
            response = response.split("</s>")[0]
            print("\nCog:", response)
        history.append((query, response))
