-
Notifications
You must be signed in to change notification settings - Fork 0
/
webui.py
228 lines (202 loc) · 10.4 KB
/
webui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import copy
import json
from typing import List
import streamlit as st
from lagent.actions import BaseAction, ActionExecutor, IPythonInterpreter
from lagent.agents.internlm2_agent import Internlm2Agent, Internlm2Protocol
from lagent.llms import HFTransformer, LMDeployClient, INTERNLM2_META
from lagent.schema import AgentStatusCode
from utils.config import load_config, check_config_empty, model_require_keys
from action import WeatherQuery, KnowledgeQuery, DeviceAssistant
class StreamlitUI:
    """Streamlit front end for the InternLM2 agent chat.

    The class keeps no instance attributes. Everything it needs is read from
    ``st.session_state``: settings and transcripts under the ``'model'`` entry
    and the agent instance under the ``'chatbot'`` entry.
    """

    @staticmethod
    def _read_bytes(path):
        """Return the binary content of the file at ``path``.

        The file is opened in a ``with`` block so the handle is closed as soon
        as the data has been read instead of being left to the garbage
        collector.
        """
        with open(path, 'rb') as handle:
            return handle.read()

    def clear_state(self):
        """Empty the stored transcripts and, if present, the agent's history."""
        model_state = st.session_state['model']
        model_state['user'] = []
        model_state['assistant'] = []
        if 'chatbot' in st.session_state:
            st.session_state['chatbot']._session_history = []

    def get_actions(self) -> List[BaseAction]:
        """Build the plugin actions handed to the agent.

        Returns:
            A new list holding a ``DeviceAssistant``, a ``WeatherQuery`` and a
            ``KnowledgeQuery``. The ``KnowledgeQuery`` is configured from the
            ``embedding_path``, ``reranker_path`` and ``vector_db`` settings
            stored under ``st.session_state['model']``.
        """
        settings = st.session_state['model']
        return [
            DeviceAssistant(),
            WeatherQuery(),
            KnowledgeQuery(
                embedding_model=settings['embedding_path'],
                reranker_model=settings['reranker_path'],
                db_path=settings['vector_db'],
            ),
        ]

    def initialize_chatbot(self):
        """Create the agent once and store it in ``st.session_state['chatbot']``.

        Nothing happens when a chatbot is already stored. Otherwise the LLM
        backend is chosen by the ``model_type`` setting: ``'远程'`` builds an
        ``LMDeployClient`` that uses ``model_path`` as its URL, any other value
        builds an ``HFTransformer`` that uses ``model_path`` as its path. Both
        backends receive the same sampling settings.
        """
        if 'chatbot' in st.session_state:
            return

        settings = st.session_state['model']
        # Keyword arguments shared by both backends.
        sampling = dict(
            meta_template=INTERNLM2_META,
            max_new_tokens=settings['max_tokens'],
            top_p=settings['top_p'],
            top_k=settings['top_k'],
            temperature=settings['temperature'],
            repetition_penalty=1.0,
            stop_words=['<|im_end|>'],
        )
        if settings['model_type'] == '远程':
            model = LMDeployClient(
                model_name=settings['model_name'],
                url=settings['model_path'],
                **sampling,
            )
        else:
            model = HFTransformer(path=settings['model_path'], **sampling)

        # Only the plugin executor is passed; no interpreter executor is
        # configured for the agent here.
        st.session_state['chatbot'] = Internlm2Agent(
            llm=model,
            plugin_executor=ActionExecutor(actions=self.get_actions()),
            protocol=Internlm2Protocol(
                meta_prompt=settings['prompt_meta'],
                plugin_prompt=settings['prompt_plugin'],
                interpreter_prompt=settings['prompt_da'],
                tool=dict(
                    begin='{start_token}{name}\n',
                    start_token='<|action_start|>',
                    name_map=dict(
                        plugin='<|plugin|>',
                        interpreter='<|interpreter|>',
                    ),
                    belong='assistant',
                    end='<|action_end|>\n',
                ),
            ),
            max_turn=7,
        )

    def render_user(self, prompt: str):
        """Show ``prompt`` as a user chat message."""
        with st.chat_message('user'):
            st.markdown(prompt)

    def render_assistant(self, agent_return):
        """Show one assistant turn: its actions followed by the final response.

        Falsy entries in ``agent_return.actions`` and actions whose type is
        ``'FinishAction'`` are skipped.
        """
        with st.chat_message('assistant'):
            for action in agent_return.actions:
                if action and action.type != 'FinishAction':
                    self.render_action(action)
            st.markdown(agent_return.response)

    def render_plugin_args(self, action):
        """Show a plugin call as a fenced JSON block with its name and arguments."""
        payload = dict(name=action.type, parameters=action.args)
        rendered = json.dumps(payload, indent=4, ensure_ascii=False)
        st.markdown('```json\n' + rendered + '\n```')

    def render_interpreter_args(self, action):
        """Show an interpreter call: the action type, then ``action.args['text']``."""
        st.info(action.type)
        st.markdown(action.args['text'])

    def render_action(self, action):
        """Show an action's thought, its arguments and its results.

        Arguments are rendered by the interpreter renderer for
        ``'IPythonInterpreter'``, not at all for ``'FinishAction'``, and by the
        plugin renderer for every other type. Results are rendered in all cases.
        """
        st.markdown(action.thought)
        if action.type == 'IPythonInterpreter':
            self.render_interpreter_args(action)
        elif action.type != 'FinishAction':
            self.render_plugin_args(action)
        self.render_action_results(action)

    def render_action_results(self, action):
        """Show an action's results (text, images, video, audio) and its error.

        ``action.result`` may be a dict keyed by ``'text'``, ``'image'``,
        ``'video'`` and ``'audio'``, or a list of ``{'type': ..., 'content': ...}``
        items. Media values are file paths whose bytes are read and displayed.
        In the dict form ``'image'`` may be a list of paths or a single path
        string. Any other result shape is ignored. ``action.errmsg`` is shown
        as an error whenever it is truthy.
        """
        result = action.result
        if isinstance(result, dict):
            if 'text' in result:
                st.markdown('```\n' + result['text'] + '\n```')
            if 'image' in result:
                image_paths = result['image']
                # A bare string would otherwise be iterated one character at a
                # time, so treat it as a single path.
                if isinstance(image_paths, str):
                    image_paths = [image_paths]
                for image_path in image_paths:
                    st.image(self._read_bytes(image_path), caption='Generated Image')
            if 'video' in result:
                st.video(self._read_bytes(result['video']))
            if 'audio' in result:
                st.audio(self._read_bytes(result['audio']))
        elif isinstance(result, list):
            for item in result:
                kind = item['type']
                if kind == 'text':
                    st.markdown('```\n' + item['content'] + '\n```')
                elif kind == 'image':
                    st.image(self._read_bytes(item['content']), caption='Generated Image')
                elif kind == 'video':
                    st.video(self._read_bytes(item['content']))
                elif kind == 'audio':
                    st.audio(self._read_bytes(item['content']))
        if action.errmsg:
            st.error(action.errmsg)
def init_ui():
    """Render the chat page: replay the transcript and handle a new prompt.

    Steps, in order:

    1. Ensure a ``StreamlitUI`` is stored in ``st.session_state['ui']`` and ask
       it to initialise the chatbot.
    2. Re-render every stored (user prompt, assistant return) pair from
       ``st.session_state['model']``.
    3. If the chat input yields a prompt, render it, stream the agent's reply
       and update the page as each intermediate state arrives. When the agent
       reports ``END`` the turn is appended to the session history and the
       assistant transcript.
    """
    if 'ui' not in st.session_state:
        st.session_state['ui'] = StreamlitUI()
    ui = st.session_state['ui']
    ui.initialize_chatbot()
    model_state = st.session_state['model']

    for prompt, agent_return in zip(model_state['user'], model_state['assistant']):
        ui.render_user(prompt)
        ui.render_assistant(agent_return)

    user_input = st.chat_input('请输入问题或控制指令...')
    if not user_input:
        return

    with st.container():
        ui.render_user(user_input)
    model_state['user'].append(user_input)
    if isinstance(user_input, str):
        user_input = [dict(role='user', content=user_input)]
    model_state['last_status'] = AgentStatusCode.SESSION_READY
    # Start each turn with an empty buffer so the END branch can neither hit a
    # missing key nor reuse the text streamed during an earlier turn.
    model_state['temp'] = ''

    chatbot = st.session_state['chatbot']
    for agent_return in chatbot.stream_chat(model_state['session_history'] + user_input):
        state = agent_return.state
        if state == AgentStatusCode.PLUGIN_RETURN:
            with st.container():
                ui.render_plugin_args(agent_return.actions[-1])
                ui.render_action_results(agent_return.actions[-1])
        elif state == AgentStatusCode.CODE_RETURN:
            with st.container():
                ui.render_action_results(agent_return.actions[-1])
        elif state in (AgentStatusCode.STREAM_ING, AgentStatusCode.CODING):
            with st.container():
                # A change of state starts a new streamed segment: clear the
                # buffer and create a fresh placeholder to overwrite in place.
                if state != model_state['last_status']:
                    model_state['temp'] = ''
                    model_state['placeholder'] = st.empty()
                response = agent_return.response
                if isinstance(response, dict):
                    header = f"\n\n {response['name']}: \n\n"
                    action_input = response['parameters']
                    if response['name'] == 'IPythonInterpreter':
                        action_input = action_input['command']
                    # Parameters that are not already text (e.g. a dict) cannot
                    # be concatenated to the header; serialise them first.
                    if not isinstance(action_input, str):
                        action_input = json.dumps(action_input, ensure_ascii=False)
                    response = header + action_input
                model_state['temp'] = response
                model_state['placeholder'].markdown(model_state['temp'])
        elif state == AgentStatusCode.END:
            model_state['session_history'] += user_input + agent_return.inner_steps
            # Store an independent copy whose response is the last streamed text.
            finished = copy.deepcopy(agent_return)
            finished.response = model_state['temp']
            model_state['assistant'].append(finished)
        model_state['last_status'] = state
# Configure the page and draw its title.
st.set_page_config(page_title='智慧大棚中心 - 农业助手', layout='wide')
st.title('🤠 智慧大棚中心')

session = st.session_state

# load_config() runs whenever either the 'model' or the 'room' entry is
# missing from the session state (presumably it populates them -- confirm in
# utils.config).
if not ('model' in session and 'room' in session):
    load_config()

# The model check comes first; the room check is only evaluated once the model
# configuration is present and not reported empty.
model_ready = 'model' in session and not check_config_empty(
    session['model'], model_require_keys)

if not model_ready:
    with st.container():
        st.warning('请先配置书生·浦语模型参数')
else:
    room_ready = (
        'room' in session
        and 'room' in session['room']
        and len(session['room']['room']) != 0
    )
    if room_ready:
        init_ui()
    else:
        with st.container():
            st.warning('请先配置大棚参数')