# demo_agent_cogagent.py (Streamlit agent demo, forked from THUDM/CogVLM)
import base64
import re
from io import BytesIO

import streamlit as st
from PIL import Image
from streamlit.delta_generator import DeltaGenerator

from client import get_client
from conversation import postprocess_text, Conversation, Role, postprocess_image
from utils import images_are_same

client = get_client()


def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder: DeltaGenerator | None = None,
) -> None:
    """Append a conversation turn to the history and render it."""
    history.append(conversation)
    conversation.show(placeholder)


def main(
    top_p: float = 0.8,
    temperature: float = 0.95,
    prompt_text: str = "",
    metadata: str = "",
    top_k: int = 2,
    max_new_tokens: int = 2048,
    grounding: bool = False,
    retry: bool = False,
    template: str = "",
):
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # An empty prompt without a retry request clears the conversation.
    if prompt_text == "" and not retry:
        print("\n== Clean ==\n")
        st.session_state.chat_history = []
        return

    history: list[Conversation] = st.session_state.chat_history

    # Re-render the existing conversation on every Streamlit rerun.
    for conversation in history:
        conversation.show()
    if retry:
        print("\n== Retry ==\n")
        # Find the most recent user turn, restore its prompt text, and drop
        # that turn and everything after it so the answer can be regenerated.
        last_user_conversation_idx = None
        for idx, conversation in enumerate(history):
            if conversation.role == Role.USER:
                last_user_conversation_idx = idx
        if last_user_conversation_idx is not None:
            prompt_text = history[last_user_conversation_idx].content_show
            del history[last_user_conversation_idx:]

    if prompt_text:
        # `metadata` carries the input image as a base64-encoded string.
        image = Image.open(BytesIO(base64.b64decode(metadata))).convert('RGB') if metadata else None
        if image is not None:
            image.thumbnail((1120, 1120))
        image_input = image
        if history and image:
            # If the image is identical to the last user image, do not send it
            # again; the model already has it in the conversation context.
            last_user_image = next(
                (conv.image for conv in reversed(history) if conv.role == Role.USER and conv.image),
                None,
            )
            if last_user_image and images_are_same(image, last_user_image):
                image_input = None
            # Not necessary to clear history
            # else:
            #     # A new picture means a new conversation.
            #     st.session_state.chat_history = []
            #     history = []

        # Set up the user turn; translate when the prompt contains CJK characters.
        translate = bool(re.search('[\u4e00-\u9fff]', prompt_text))

        user_conversation = Conversation(
            role=Role.USER,
            translate=translate,
            content_show=postprocess_text(template=template, text=prompt_text.strip()),
            image=image_input,
        )
        append_conversation(user_conversation, history)

        # Stream the answer token by token into an assistant chat bubble.
        placeholder = st.empty()
        assistant_conversation = placeholder.chat_message(name="assistant", avatar="assistant")
        assistant_conversation = assistant_conversation.empty()

        output_text = ''
        for response in client.generate_stream(
            model_use='agent_chat',
            grounding=grounding,
            history=history,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        ):
            output_text += response.token.text
            assistant_conversation.markdown(output_text.strip() + '▌')

        # Final answer with the post-processed image.
        print("\n==Output:==\n", output_text)
        content_output, image_output = postprocess_image(output_text, image)
        assistant_conversation = Conversation(
            role=Role.ASSISTANT,
            content=content_output,
            image=image_output,
            translate=translate,
        )
        append_conversation(
            conversation=assistant_conversation,
            history=history,
            placeholder=placeholder.chat_message(name="assistant", avatar="assistant"),
        )
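

# A minimal usage sketch, not part of the upstream file: in the CogVLM
# composite demo, main() is driven by a launcher script that collects the
# sampling parameters and prompt from Streamlit widgets. The widget labels
# and the grounding toggle below are illustrative assumptions, not the
# repository's actual entrypoint.
#
# if __name__ == '__main__':
#     main(
#         prompt_text=st.chat_input("Chat with CogAgent") or "",
#         metadata="",  # base64-encoded screenshot from an image uploader
#         grounding=st.sidebar.checkbox("Grounding", value=False),
#         retry=False,
#     )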