Commit

Feat: switch to turn on/off stream chat and bug fixes. (Akegarasu#28)
remiliacn authored Mar 22, 2023
1 parent e4d219d commit 36eb9f0
Showing 3 changed files with 40 additions and 24 deletions.
5 changes: 3 additions & 2 deletions modules/context.py
@@ -39,8 +39,9 @@ def update_last(self, query, output) -> None:
         self.rh[-1] = (query, output)
 
     def refresh_last(self) -> None:
-        query, output = self.rh[-1]
-        self.rh[-1] = (query, parse_codeblock(output))
+        if self.rh:
+            query, output = self.rh[-1]
+            self.rh[-1] = (query, parse_codeblock(output))
 
     def clear(self):
         self.history = []
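
Note: with the guard above, refresh_last becomes a no-op when no round has been recorded yet, instead of raising IndexError on an empty list. A minimal, self-contained sketch of that behavior (the Context below is a stripped-down stand-in, and parse_codeblock is stubbed for illustration rather than being the project's real helper):

from typing import List, Tuple


def parse_codeblock(text: str) -> str:
    # Stub standing in for the project's real parse_codeblock helper.
    return text


class Context:
    def __init__(self):
        self.rh: List[Tuple[str, str]] = []

    def refresh_last(self) -> None:
        # Guarded: does nothing when no round has been recorded yet.
        if self.rh:
            query, output = self.rh[-1]
            self.rh[-1] = (query, parse_codeblock(output))


ctx = Context()
ctx.refresh_last()                          # previously an IndexError, now a safe no-op
ctx.rh.append(("hi", "some model output"))
ctx.refresh_last()                          # rewrites the last output via parse_codeblock
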
40 changes: 26 additions & 14 deletions modules/model.py
@@ -63,7 +63,7 @@ def load_model():
 
 def infer(query,
           history: Optional[List[Tuple]],
-          max_length, top_p, temperature):
+          max_length, top_p, temperature, use_stream_chat: bool):
     if cmd_opts.ui_dev:
         yield "hello", "hello, dev mode!"
         return
@@ -75,17 +75,29 @@ def infer(query,
         history = []
 
     output_pos = 0
-    try:
-        for output, history in model.stream_chat(
-                tokenizer, query=query, history=history,
-                max_length=max_length,
-                top_p=top_p,
-                temperature=temperature
-        ):
-            print(output[output_pos:], end='')
-            output_pos = len(output)
-            yield query, output
-    except Exception as e:
-        print(f"Generation failed: {repr(e)}")
-    print("")
+    if use_stream_chat:
+        try:
+            for output, history in model.stream_chat(
+                    tokenizer, query=query, history=history,
+                    max_length=max_length,
+                    top_p=top_p,
+                    temperature=temperature
+            ):
+                print(output[output_pos:], end='', flush=True)
+                output_pos = len(output)
+                yield query, output
+        except Exception as e:
+            print(f"Generation failed: {repr(e)}")
+    else:
+        output, history = model.chat(
+            tokenizer, query=query, history=history,
+            max_length=max_length,
+            top_p=top_p,
+            temperature=temperature
+        )
+
+        print(output)
+        yield query, output
+
+    print()
     torch_gc()
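
Note: the new use_stream_chat flag selects between two call paths, but infer stays a generator in both modes, yielding many partial results when streaming and a single final one otherwise, so the consuming loop in ui.predict needs no structural change. A minimal sketch of that contract, using a FakeModel whose stream_chat/chat signatures mirror the ChatGLM-style calls in this diff (the stand-in model and its canned replies are illustrative only):

from typing import List, Tuple


class FakeModel:
    def stream_chat(self, tokenizer, query: str, history: List[Tuple], **gen_kwargs):
        # Generator, like ChatGLM's stream_chat: yields progressively longer outputs.
        for partial in ("Hel", "Hello", "Hello!"):
            yield partial, history + [(query, partial)]

    def chat(self, tokenizer, query: str, history: List[Tuple], **gen_kwargs):
        # Blocking, like ChatGLM's chat: returns the finished reply once.
        return "Hello!", history + [(query, "Hello!")]


def infer(model, query, history, use_stream_chat, **gen_kwargs):
    if use_stream_chat:
        for output, history in model.stream_chat(None, query=query, history=history, **gen_kwargs):
            yield query, output
    else:
        output, history = model.chat(None, query=query, history=history, **gen_kwargs)
        yield query, output


for _, output in infer(FakeModel(), "hi", [], use_stream_chat=True):
    print(output)    # "Hel", "Hello", "Hello!" -- one line per partial result
for _, output in infer(FakeModel(), "hi", [], use_stream_chat=False):
    print(output)    # "Hello!" -- the single final result
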
19 changes: 11 additions & 8 deletions modules/ui.py
@@ -11,15 +11,16 @@
 _gradio_template_response_orig = gr.routes.templates.TemplateResponse
 
 
-def predict(query, max_length, top_p, temperature):
+def predict(query, max_length, top_p, temperature, use_stream_chat):
     ctx.limit_round()
     flag = True
     for _, output in infer(
-        query=query,
-        history=ctx.history,
-        max_length=max_length,
-        top_p=top_p,
-        temperature=temperature
+            query=query,
+            history=ctx.history,
+            max_length=max_length,
+            top_p=top_p,
+            temperature=temperature,
+            use_stream_chat=use_stream_chat
     ):
         if flag:
             ctx.append(query, output)
@@ -62,7 +63,8 @@ def create_ui():
                     apply_max_rounds = gr.Button("✔", elem_id="del-btn")
 
             cmd_output = gr.Textbox(label="Command Output")
-
+            with gr.Row():
+                use_stream_chat = gr.Checkbox(label='使用流式输出', value=True)
         with gr.Row():
             with gr.Column(variant="panel"):
                 with gr.Row():
@@ -91,7 +93,8 @@
             input_message,
             max_length,
             top_p,
-            temperature
+            temperature,
+            use_stream_chat
         ], outputs=[
             chatbot,
             input_message
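
Note: the checkbox (label 使用流式输出, "use streaming output", default True) is wired into the click handler simply by appending the component to the Gradio inputs list; Gradio then passes its boolean value as the matching positional argument of predict. A minimal, runnable sketch of that wiring under those assumptions (component and handler names below are illustrative, not necessarily the ones used elsewhere in ui.py):

import gradio as gr


def predict(query, max_length, top_p, temperature, use_stream_chat):
    mode = "streaming" if use_stream_chat else "blocking"
    return f"[{mode}] echo: {query} (max_length={max_length}, top_p={top_p}, temperature={temperature})"


with gr.Blocks() as demo:
    input_message = gr.Textbox(label="Input")
    max_length = gr.Slider(minimum=4, maximum=4096, value=2048, step=4, label="Max Length")
    top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01, label="Top P")
    temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.95, step=0.01, label="Temperature")
    use_stream_chat = gr.Checkbox(label="Use streaming output", value=True)
    output_box = gr.Textbox(label="Output")
    submit = gr.Button("Submit")
    # The order of `inputs` must match predict's positional parameters.
    submit.click(predict,
                 inputs=[input_message, max_length, top_p, temperature, use_stream_chat],
                 outputs=[output_box])

demo.launch()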
