forked from THUDM/CogVLM
Commit e359876 (1 parent acb52b1): 1 changed file, 102 additions, 0 deletions.
""" | ||
This is a demo for using CogAgent and CogVLM in CLI | ||
Make sure you have installed vicuna-7b-v1.5 tokenizer model (https://huggingface.co/lmsys/vicuna-7b-v1.5), full checkpoint of vicuna-7b-v1.5 LLM is not required. | ||
In this demo, We us chat template, you can use others to replace such as 'vqa'. | ||
Strongly suggest to use GPU with bfloat16 support, otherwise, it will be slow. | ||
Mention that only one picture can be processed at one conversation, which means you can not replace or insert another picture during the conversation. | ||
""" | ||

import argparse
import torch

from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--quant", choices=[4], type=int, default=4, help='quantization bits')
parser.add_argument("--from_pretrained", type=str, default="THUDM/cogvlm-chat-hf", help='pretrained ckpt')
parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")

args = parser.parse_args()
MODEL_PATH = args.from_pretrained
TOKENIZER_PATH = args.local_tokenizer
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
if args.bf16:
    torch_type = torch.bfloat16
else:
    torch_type = torch.float16

print("========Use torch type as:{} with device:{}========\n".format(torch_type, DEVICE))

if args.quant:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch_type,
        low_cpu_mem_usage=True,
        load_in_4bit=True,
        trust_remote_code=True
    ).eval()
else:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch_type,
        low_cpu_mem_usage=True,
        load_in_4bit=args.quant is not None,
        trust_remote_code=True
    ).to(DEVICE).eval()
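# NOTE: with the default of --quant 4, the 4-bit branch above is always taken;
# the else branch only runs if you change the default to None. 4-bit loading
# relies on bitsandbytes, which places the quantized weights on the GPU during
# loading, which is why that branch skips the explicit .to(DEVICE).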

text_only_template = "A chat between a curious user and an artificial intelligence assistant. \
The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
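# The template above is the Vicuna v1.5 conversation format, matching the
# lmsys/vicuna-7b-v1.5 tokenizer used by these checkpoints.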

# Outer loop picks a mode per conversation; the inner loop below handles the turns.
while True:
    image_or_text = input("Chat with an image or text only? Please choose one from 'image' and 'text': ")
    if image_or_text == 'image':
        image_path = input("image path >>>>> ")
        if image_path == "stop":
            break
        image = Image.open(image_path).convert('RGB')
    else:
        image = None
        first_query = True
    history = []

    while True:
        query = input("Human:")
        if query == "clear":
            break

        if image is None:
            if first_query:
                query = text_only_template.format(query)
                first_query = False
            else:
                query = "USER: {} ASSISTANT:".format(query)

        if image is None:
            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history)
        else:
            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])

        inputs = {
            'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
            'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
            'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
            'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]] if image is not None else None,
        }
        if 'cross_images' in input_by_model and input_by_model['cross_images']:
            inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
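        # NOTE: 'cross_images' is only produced by CogAgent checkpoints, which feed a
        # high-resolution copy of the image through a separate cross-attention branch;
        # plain CogVLM checkpoints do not return it.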

        # add any transformers params here.
        gen_kwargs = {"max_length": 2048,
                      "do_sample": False}
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            response = tokenizer.decode(outputs[0])
            response = response.split("</s>")[0]
            print("\nCog:", response)
        history.append((query, response))
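# To exit the demo entirely, choose 'image' at the mode prompt and enter "stop"
# as the image path (or press Ctrl-C); "clear" only ends the current chat.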