From 57bc22dabd75129176ee587ba6557d0392e71d7f Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sat, 23 Dec 2023 10:39:04 +0800
Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0cogagent=20chat=E4=B8=AD?=
=?UTF-8?q?=E7=9A=84grounding?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
composite_demo/demo_chat_cogagent.py | 11 ++++++-----
composite_demo/main.py | 23 +++++++++++------------
2 files changed, 17 insertions(+), 17 deletions(-)
diff --git a/composite_demo/demo_chat_cogagent.py b/composite_demo/demo_chat_cogagent.py
index 6d5bc8cc..ce49c50f 100644
--- a/composite_demo/demo_chat_cogagent.py
+++ b/composite_demo/demo_chat_cogagent.py
@@ -7,7 +7,7 @@
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from utils import images_are_same
-from conversation import Conversation, Role, postprocess_image
+from conversation import Conversation, Role, postprocess_image, postprocess_text
client = get_client()
@@ -30,6 +30,7 @@ def main(
max_new_tokens: int = 2048,
grounding: bool = False,
retry: bool = False,
+ template: str = "",
):
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
@@ -72,9 +73,9 @@ def main(
user_conversation = Conversation(
role=Role.USER,
- content_show=prompt_text.strip(),
- image=image_input,
- translate=translate
+ translate=translate,
+ content_show=postprocess_text(template=template, text=prompt_text.strip()),
+ image=image_input
)
append_conversation(user_conversation, history)
placeholder = st.empty()
@@ -85,7 +86,7 @@ def main(
output_text = ''
for response in client.generate_stream(
model_use='agent_chat',
- grounding=False,
+ grounding=grounding,
history=history,
do_sample=True,
max_new_tokens=max_new_tokens,
diff --git a/composite_demo/main.py b/composite_demo/main.py
index 714725f0..9cbd5a38 100644
--- a/composite_demo/main.py
+++ b/composite_demo/main.py
@@ -43,7 +43,7 @@
st.markdown("
CogAgent & CogVLM Chat Demo
", unsafe_allow_html=True)
st.markdown(
- "更多使用方法请参考文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof \n\n",
+ "更多使用方法请参考文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof \n\n 请根据文档的引导说明来尝试demo,以便理解demo的布局设计 \n",
unsafe_allow_html=True)
@@ -83,14 +83,12 @@ class Mode(str, Enum):
horizontal=True,
label_visibility='hidden',
)
-grounding = False
-selected_template_grounding_cogvlm = ""
-if tab != Mode.CogAgent_Chat.value:
- with st.sidebar:
- grounding = st.checkbox("Grounding")
- if tab == Mode.CogVLM_Chat.value and grounding:
- selected_template_grounding_cogvlm = st.selectbox("Template For Grounding", template_grounding_cogvlm)
+selected_template_grounding_cogvlm = ""
+with st.sidebar:
+ grounding = st.checkbox("Grounding")
+    if tab in (Mode.CogVLM_Chat.value, Mode.CogAgent_Chat.value) and grounding:
+ selected_template_grounding_cogvlm = st.selectbox("Template For Grounding", template_grounding_cogvlm)
if tab == Mode.CogAgent_Agent.value:
with st.sidebar:
@@ -101,7 +99,7 @@ class Mode(str, Enum):
match tab:
case Mode.CogVLM_Chat:
- st.info("This is a demo using the VQA and Chat type about CogVLM")
+ st.info("This option uses cogvlm-chat and cogvlm-grounding model.")
if uploaded_file is not None:
demo_chat_cogvlm.main(
retry=retry,
@@ -118,7 +116,7 @@ class Mode(str, Enum):
st.error(f'Please upload an image to start')
case Mode.CogAgent_Chat:
- st.info("This is a demo using the VQA and Chat type about CogAgent")
+ st.info("This option uses cogagent-chat model.")
if uploaded_file is not None:
demo_chat_cogagent.main(
retry=retry,
@@ -128,13 +126,14 @@ class Mode(str, Enum):
prompt_text=prompt_text,
metadata=encode_file_to_base64(uploaded_file),
max_new_tokens=max_new_token,
- grounding=False
+ grounding=grounding,
+ template=selected_template_grounding_cogvlm
)
else:
st.error(f'Please upload an image to start')
case Mode.CogAgent_Agent:
- st.info("This is a demo using the Agent type about CogAgent")
+ st.info("This option uses cogagent-chat model with agent template.")
if uploaded_file is not None:
demo_agent_cogagent.main(
retry=retry,