From 57bc22dabd75129176ee587ba6557d0392e71d7f Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sat, 23 Dec 2023 10:39:04 +0800
Subject: [PATCH] Update grounding in cogagent chat
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 composite_demo/demo_chat_cogagent.py | 11 ++++++-----
 composite_demo/main.py               | 23 +++++++++++------------
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/composite_demo/demo_chat_cogagent.py b/composite_demo/demo_chat_cogagent.py
index 6d5bc8cc..ce49c50f 100644
--- a/composite_demo/demo_chat_cogagent.py
+++ b/composite_demo/demo_chat_cogagent.py
@@ -7,7 +7,7 @@
 from streamlit.delta_generator import DeltaGenerator
 
 from client import get_client
 from utils import images_are_same
-from conversation import Conversation, Role, postprocess_image
+from conversation import Conversation, Role, postprocess_image, postprocess_text
 
 client = get_client()
@@ -30,6 +30,7 @@ def main(
         max_new_tokens: int = 2048,
         grounding: bool = False,
         retry: bool = False,
+        template: str = "",
 ):
     if 'chat_history' not in st.session_state:
         st.session_state.chat_history = []
@@ -72,9 +73,9 @@
 
         user_conversation = Conversation(
             role=Role.USER,
-            content_show=prompt_text.strip(),
-            image=image_input,
-            translate=translate
+            translate=translate,
+            content_show=postprocess_text(template=template, text=prompt_text.strip()),
+            image=image_input
         )
         append_conversation(user_conversation, history)
         placeholder = st.empty()
@@ -85,7 +86,7 @@
             output_text = ''
             for response in client.generate_stream(
                 model_use='agent_chat',
-                grounding=False,
+                grounding=grounding,
                 history=history,
                 do_sample=True,
                 max_new_tokens=max_new_tokens,
diff --git a/composite_demo/main.py b/composite_demo/main.py
index 714725f0..9cbd5a38 100644
--- a/composite_demo/main.py
+++ b/composite_demo/main.py
@@ -43,7 +43,7 @@ st.markdown("
 
 CogAgent & CogVLM Chat Demo
 
 ", unsafe_allow_html=True)
 st.markdown(
-    "更多使用方法请参考文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof \n\n",
+    "更多使用方法请参考文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof \n\n 请根据文档的引导说明来尝试demo,以便理解demo的布局设计 \n",
     unsafe_allow_html=True)
@@ -83,14 +83,12 @@ class Mode(str, Enum):
     horizontal=True,
     label_visibility='hidden',
 )
-grounding = False
-selected_template_grounding_cogvlm = ""
-if tab != Mode.CogAgent_Chat.value:
-    with st.sidebar:
-        grounding = st.checkbox("Grounding")
-        if tab == Mode.CogVLM_Chat.value and grounding:
-            selected_template_grounding_cogvlm = st.selectbox("Template For Grounding", template_grounding_cogvlm)
+selected_template_grounding_cogvlm = ""
+with st.sidebar:
+    grounding = st.checkbox("Grounding")
+    if tab in (Mode.CogVLM_Chat.value, Mode.CogAgent_Chat.value) and grounding:
+        selected_template_grounding_cogvlm = st.selectbox("Template For Grounding", template_grounding_cogvlm)
 
 
 if tab == Mode.CogAgent_Agent.value:
     with st.sidebar:
@@ -101,7 +99,7 @@ class Mode(str, Enum):
 
 match tab:
     case Mode.CogVLM_Chat:
-        st.info("This is a demo using the VQA and Chat type about CogVLM")
+        st.info("This option uses the cogvlm-chat and cogvlm-grounding models.")
         if uploaded_file is not None:
             demo_chat_cogvlm.main(
                 retry=retry,
@@ -118,7 +116,7 @@ class Mode(str, Enum):
             st.error(f'Please upload an image to start')
 
     case Mode.CogAgent_Chat:
-        st.info("This is a demo using the VQA and Chat type about CogAgent")
+        st.info("This option uses the cogagent-chat model.")
         if uploaded_file is not None:
             demo_chat_cogagent.main(
                 retry=retry,
@@ -128,13 +126,14 @@ class Mode(str, Enum):
                 prompt_text=prompt_text,
                 metadata=encode_file_to_base64(uploaded_file),
                 max_new_tokens=max_new_token,
-                grounding=False
+                grounding=grounding,
+                template=selected_template_grounding_cogvlm
             )
         else:
             st.error(f'Please upload an image to start')
 
     case Mode.CogAgent_Agent:
-        st.info("This is a demo using the Agent type about CogAgent")
+        st.info("This option uses the cogagent-chat model with the agent template.")
         if uploaded_file is not None:
             demo_agent_cogagent.main(
                 retry=retry,