From e3598765f44d345d2eb7a503ce69aefb9db49e59 Mon Sep 17 00:00:00 2001
From: duchenzhuang <15652580397@163.com>
Date: Tue, 19 Dec 2023 10:26:44 +0000
Subject: [PATCH] Update the CogVLM demo to support text-only conversation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 basic_demo/cli_demo_hf_cogvlm_chat.py | 102 ++++++++++++++++++++++++++
 1 file changed, 102 insertions(+)
 create mode 100644 basic_demo/cli_demo_hf_cogvlm_chat.py

diff --git a/basic_demo/cli_demo_hf_cogvlm_chat.py b/basic_demo/cli_demo_hf_cogvlm_chat.py
new file mode 100644
index 00000000..349d07fd
--- /dev/null
+++ b/basic_demo/cli_demo_hf_cogvlm_chat.py
@@ -0,0 +1,102 @@
+"""
+This is a demo for using CogAgent and CogVLM in the CLI.
+Make sure you have downloaded the vicuna-7b-v1.5 tokenizer (https://huggingface.co/lmsys/vicuna-7b-v1.5); the full vicuna-7b-v1.5 LLM checkpoint is not required.
+In this demo we use the 'chat' template; you can replace it with another one such as 'vqa'.
+We strongly suggest using a GPU with bfloat16 support; otherwise inference will be slow.
+Note that only one image can be used per conversation, i.e. you cannot replace or add another image during a conversation.
+"""
+
+import argparse
+import torch
+
+from PIL import Image
+from transformers import AutoModelForCausalLM, LlamaTokenizer
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--quant", choices=[4], type=int, default=4, help='quantization bits')
+parser.add_argument("--from_pretrained", type=str, default="THUDM/cogvlm-chat-hf", help='pretrained ckpt')
+parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
+parser.add_argument("--fp16", action="store_true")
+parser.add_argument("--bf16", action="store_true")
+
+args = parser.parse_args()
+MODEL_PATH = args.from_pretrained
+TOKENIZER_PATH = args.local_tokenizer
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
+if args.bf16:
+    torch_type = torch.bfloat16
+else:
+    torch_type = torch.float16
+
+print("========Use torch type as:{} with device:{}========\n".format(torch_type, DEVICE))
+
+if args.quant:
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        torch_dtype=torch_type,
+        low_cpu_mem_usage=True,
+        load_in_4bit=True,
+        trust_remote_code=True
+    ).eval()
+else:
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        torch_dtype=torch_type,
+        low_cpu_mem_usage=True,
+        load_in_4bit=args.quant is not None,
+        trust_remote_code=True
+    ).to(DEVICE).eval()
+
+text_only_template = ("A chat between a curious user and an artificial intelligence assistant. The assistant "
+                      "gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:")
+
+while True:
+    image_or_text = input("Chat with an image or text only? Please choose one from 'image' and 'text': ")
+    if image_or_text == 'image':
+        image_path = input("image path >>>>> ")
+        if image_path == "stop":
+            break
+        image = Image.open(image_path).convert('RGB')
+    else:
+        image = None
+    first_query = True
+    history = []
+
+    while True:
+        query = input("Human:")
+        if query == "clear":
+            break
+
+        if image is None:
+            if first_query:
+                query = text_only_template.format(query)
+                first_query = False
+            else:
+                query = "USER: {} ASSISTANT:".format(query)
+
+        if image is None:
+            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history)
+        else:
+            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
+
+        inputs = {
+            'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+            'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+            'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
+            'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]] if image is not None else None,
+        }
+        if 'cross_images' in input_by_model and input_by_model['cross_images']:
+            inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
+
+        # add any transformers generation params here.
+        gen_kwargs = {"max_length": 2048,
+                      "do_sample": False}
+        with torch.no_grad():
+            outputs = model.generate(**inputs, **gen_kwargs)
+            outputs = outputs[:, inputs['input_ids'].shape[1]:]
+            response = tokenizer.decode(outputs[0])
+            response = response.split("</s>")[0]
+            print("\nCog:", response)
+        history.append((query, response))
\ No newline at end of file
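
---
A minimal sketch of how the text-only path added by this patch can be exercised, assuming `torch`, `transformers`, `bitsandbytes`, and `accelerate` are installed and a CUDA-capable GPU is available; the flags are the ones defined by the argument parser above, and the model reply shown is a placeholder:

    $ python basic_demo/cli_demo_hf_cogvlm_chat.py --bf16
    ========Use torch type as:torch.bfloat16 with device:cuda========

    Chat with an image or text only? Please choose one from 'image' and 'text': text
    Human:What is CogVLM?

    Cog: <model reply>

Choosing 'image' instead prompts for an image path first (entering "stop" there exits the demo), and typing "clear" at the "Human:" prompt resets the history and returns to the mode-selection prompt.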