Skip to content

Commit

Permalink
Fixed a bug where the prompt was duplicated after a retry (修复了重试后提示词重复的bug)
Browse files Browse the repository at this point in the history
  • Loading branch information
zRzRzRzRzRzRzR committed Jan 2, 2024
1 parent 1920a97 commit d420c9b
Show file tree
Hide file tree
Showing 6 changed files with 1,002 additions and 8 deletions.
5 changes: 3 additions & 2 deletions basic_demo/cli_demo_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

parser = argparse.ArgumentParser()
parser.add_argument("--quant", choices=[4], type=int, default=None, help='quantization bits')
parser.add_argument("--from_pretrained", type=str, default="THUDM/cogagent-chat-hf", help='pretrained ckpt')
parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
parser.add_argument("--from_pretrained", type=str, default="/share/home/zyx/Models/cogagent-chat-hf", help='pretrained ckpt')
parser.add_argument("--local_tokenizer", type=str, default="/share/official_pretrains/hf_home/vicuna-7b-v1.5", help='tokenizer path')
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")

Expand Down Expand Up @@ -96,6 +96,7 @@
"do_sample": False} # "temperature": 0.9
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
breakpoint()
outputs = outputs[:, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(outputs[0])
response = response.split("</s>")[0]
Expand Down
3 changes: 2 additions & 1 deletion composite_demo/demo_agent_cogagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def main(
user_conversation = Conversation(
role=Role.USER,
translate=translate,
content_show=postprocess_text(template=template, text=prompt_text.strip()),
content_show=prompt_text.strip() if retry else postprocess_text(template=template,
text=prompt_text.strip()),
image=image_input
)
append_conversation(user_conversation, history)
Expand Down
3 changes: 2 additions & 1 deletion composite_demo/demo_chat_cogagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def main(
user_conversation = Conversation(
role=Role.USER,
translate=translate,
content_show=postprocess_text(template=template, text=prompt_text.strip()),
content_show=prompt_text.strip() if retry else postprocess_text(template=template,
text=prompt_text.strip()),
image=image_input
)
append_conversation(user_conversation, history)
Expand Down
3 changes: 2 additions & 1 deletion composite_demo/demo_chat_cogvlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def main(
user_conversation = Conversation(
role=Role.USER,
translate=translate,
content_show=postprocess_text(template=template, text=prompt_text.strip()),
content_show=prompt_text.strip() if retry else postprocess_text(template=template,
text=prompt_text.strip()),
image=image_input
)
append_conversation(user_conversation, history)
Expand Down
7 changes: 4 additions & 3 deletions composite_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ class Mode(str, Enum):
selected_template_grounding_cogvlm = ""
with st.sidebar:
grounding = st.checkbox("Grounding")
if tab == Mode.CogVLM_Chat.value or Mode.CogAgent_Chat and grounding:
selected_template_grounding_cogvlm = st.selectbox("Template For Grounding", template_grounding_cogvlm)
if tab == Mode.CogVLM_Chat or tab == Mode.CogAgent_Chat:
if grounding:
selected_template_grounding_cogvlm = st.selectbox("Template For Grounding", template_grounding_cogvlm)

if tab == Mode.CogAgent_Agent.value:
if tab == Mode.CogAgent_Agent:
with st.sidebar:
selected_template_agent_cogagent = st.selectbox("Template For Agent", templates_agent_cogagent)

Expand Down
Loading

0 comments on commit d420c9b

Please sign in to comment.