
Commit

fix typo
1049451037 committed Dec 14, 2023
1 parent 6948a38 commit 4d53d99
Showing 4 changed files with 33 additions and 15 deletions.
basic_demo/cli_demo_hf.py (40 changes: 29 additions & 11 deletions)
@@ -10,26 +10,44 @@
 import torch
 from PIL import Image
 from transformers import AutoModelForCausalLM, LlamaTokenizer
+import argparse
 
-MODEL_PATH = 'your path of CogAgent or CogVLM'
-TOKENIZER_PATH = 'your path of vicuna'
+parser = argparse.ArgumentParser()
+parser.add_argument("--quant", choices=[4], type=int, default=None, help='quantization bits')
+parser.add_argument("--from_pretrained", type=str, default="THUDM/cogagent-chat-hf", help='pretrained ckpt') # TODO
+parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path') #TODO
+parser.add_argument("--fp16", action="store_true")
+parser.add_argument("--bf16", action="store_true")
+
+args = parser.parse_args()
+MODEL_PATH = args.from_pretrained
+TOKENIZER_PATH = args.local_tokenizer
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
-if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
-    torch_type = torch.float16
+if args.bf16:
+    torch_type = torch.bfloat16
 else:
     torch_type = torch.float16
-    warnings.warn("Your GPU does not support bfloat16 type, use fp16 instead")
 
 print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))
 
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_PATH,
-    torch_dtype=torch_type,
-    low_cpu_mem_usage=True,
-    trust_remote_code=True
-).to(DEVICE).eval()
+if args.quant:
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        torch_dtype=torch_type,
+        low_cpu_mem_usage=True,
+        load_in_4bit=True,
+        trust_remote_code=True
+    ).eval()
+else:
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        torch_dtype=torch_type,
+        low_cpu_mem_usage=True,
+        load_in_4bit=args.quant is not None,
+        trust_remote_code=True
+    ).to(DEVICE).eval()
 
 while True:
     image_path = input("image path >>>>> ")
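With this change the HF demo no longer relies on hard-coded placeholder paths: checkpoint, tokenizer, precision, and 4-bit quantization are all chosen on the command line. A minimal usage sketch, assuming the script is launched from the repository root and the default hub IDs (THUDM/cogagent-chat-hf, lmsys/vicuna-7b-v1.5) are downloadable:

    # run in bfloat16; without --bf16 the script falls back to float16
    python basic_demo/cli_demo_hf.py --bf16

    # load the weights 4-bit quantized (load_in_4bit=True); in this branch the
    # model is left where the quantized loader places it, not moved with .to(DEVICE)
    python basic_demo/cli_demo_hf.py --quant 4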
basic_demo/cli_demo_sat.py (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@ def main():
     parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
 
     parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt') # TODO
-    parser.add_argument("--local_tokenizer", type=str, default="/share/official_pretrains/hf_home/vicuna-7b-v1.5", help='tokenizer path') #TODO
+    parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path') #TODO
     parser.add_argument("--fp16", action="store_true")
     parser.add_argument("--bf16", action="store_true")
     parser.add_argument("--stream_chat", action="store_true")
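The SAT demo keeps the same flag, but its default tokenizer now points at the public lmsys/vicuna-7b-v1.5 ID instead of an internal /share path. A sketch of overriding it with a local copy (the local path below is hypothetical):

    python basic_demo/cli_demo_sat.py --from_pretrained cogagent-chat --local_tokenizer /path/to/vicuna-7b-v1.5 --bf16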
finetune_demo/evaluate_cogvlm.sh (4 changes: 2 additions & 2 deletions)
@@ -10,7 +10,7 @@ script_dir=$(dirname $script_path)
 main_dir=$(dirname $script_dir)
 MODEL_TYPE="cogvlm-base-490"
 VERSION="base"
-MODEL_ARGS="--from_pretrained ./checkpoints/merged_lora \
+MODEL_ARGS="--from_pretrained ./checkpoints/merged_lora_490 \
     --max_length 1288 \
     --lora_rank 10 \
     --use_lora \
@@ -52,7 +52,7 @@ gpt_options=" \
 
 
 
-run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} evaluate_demo.py ${gpt_options}"
+run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} evaluate_cogvlm_demo.py ${gpt_options}"
 echo ${run_cmd}
 eval ${run_cmd}
 
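After this fix the evaluation launcher references evaluate_cogvlm_demo.py and the merged_lora_490 checkpoint directory. A launch sketch, assuming a merged LoRA checkpoint has already been exported to ./checkpoints/merged_lora_490 and the DeepSpeed hostfile referenced by ${HOST_FILE_PATH} is valid:

    bash finetune_demo/evaluate_cogvlm.sh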
finetune_demo/finetune_cogvlm_lora.sh (2 changes: 1 addition & 1 deletion)
@@ -52,7 +52,7 @@ gpt_options=" \
 
 
 
-run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_demo.py ${gpt_options}"
+run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_cogvlm_demo.py ${gpt_options}"
 echo ${run_cmd}
 eval ${run_cmd}
 
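Likewise, the LoRA fine-tuning launcher now runs finetune_cogvlm_demo.py, matching the CogVLM-specific naming used elsewhere in finetune_demo/. A launch sketch under the same assumptions:

    bash finetune_demo/finetune_cogvlm_lora.sh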
