Skip to content

Commit

Permalink
restart
Browse files Browse the repository at this point in the history
  • Loading branch information
Akegarasu committed Mar 19, 2023
1 parent 250687d commit d0c38b7
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 29 deletions.
27 changes: 13 additions & 14 deletions modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def infer(query,
history: Optional[List[Tuple]],
max_length, top_p, temperature):
if cmd_opts.ui_dev:
return "hello", "hello, dev mode!"
yield "hello", "hello, dev mode!"
return

if not model:
raise "Model not loaded"
Expand All @@ -55,19 +56,17 @@ def infer(query,
history = []

output_pos = 0

for output, history in model.stream_chat(
tokenizer, query=query, history=history,
max_length=max_length,
top_p=top_p,
temperature=temperature
):
try:
try:
for output, history in model.stream_chat(
tokenizer, query=query, history=history,
max_length=max_length,
top_p=top_p,
temperature=temperature
):
print(output[output_pos:], end='')
except Exception as e:
print(f"Generation failed: {repr(e)}")
output_pos = len(output)
yield query, output

output_pos = len(output)
yield query, output
except Exception as e:
print(f"Generation failed: {repr(e)}")
print("")
torch_gc()
1 change: 1 addition & 0 deletions modules/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
parser.add_argument("--ui-dev", action='store_true', help="ui develop mode", default=None)

cmd_opts = parser.parse_args()
need_restart = False
28 changes: 19 additions & 9 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import gradio as gr

from modules import options
from modules.context import ctx
from modules.device import torch_gc
from modules.model import infer

css = "style.css"
Expand All @@ -15,11 +15,11 @@ def predict(query, max_length, top_p, temperature):
ctx.limit_round()
flag = True
for _, output in infer(
query=query,
history=ctx.history,
max_length=max_length,
top_p=top_p,
temperature=temperature
query=query,
history=ctx.history,
max_length=max_length,
top_p=top_p,
temperature=temperature
):
if flag:
ctx.append(query, output)
Expand Down Expand Up @@ -107,9 +107,19 @@ def create_ui():

apply_max_rounds.click(apply_max_round_click, inputs=[max_rounds], outputs=[cmd_output])

interfaces = [
(chat_interface, "Chat", "chat"),
]
with gr.Blocks(css=css, analytics_enabled=False) as settings_interface:
with gr.Row():
reload_ui = gr.Button("Reload UI")

def restart_ui():
    # Flag a pending restart.  The server loop in webui.py
    # (wait_on_server) polls options.need_restart once per second and
    # tears the UI down when it sees this set.
    options.need_restart = True

reload_ui.click(restart_ui)

interfaces = [
(chat_interface, "Chat", "chat"),
(settings_interface, "Settings", "settings")
]

with gr.Blocks(css=css, analytics_enabled=False, title="ChatGLM") as demo:
with gr.Tabs(elem_id="tabs") as tabs:
Expand Down
31 changes: 25 additions & 6 deletions webui.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
import time
from modules import options

from modules.model import load_model

from modules.options import cmd_opts
from modules.ui import create_ui

Expand All @@ -20,13 +24,28 @@ def init():
load_model()


def wait_on_server(ui=None):
    """Block until a UI restart is requested, then shut the UI down.

    Polls ``options.need_restart`` once per second; the flag is set by the
    "Reload UI" button (``restart_ui`` in modules/ui.py).  When it fires,
    the flag is cleared, *ui* is closed, and the function returns so the
    caller can rebuild and relaunch the interface.

    :param ui: the launched gradio Blocks app to close, or None.
    """
    while True:
        time.sleep(1)
        if options.need_restart:
            options.need_restart = False
            time.sleep(0.5)
            # Guard against the default None argument — the original
            # called ui.close() unconditionally and would raise
            # AttributeError when no UI was passed.
            if ui is not None:
                ui.close()
            time.sleep(0.5)
            return


def main():
    """Launch the web UI and relaunch it whenever a restart is requested.

    Each iteration builds a fresh interface, launches it without blocking
    the main thread, then parks in wait_on_server() until the Settings
    tab's "Reload UI" button flags a restart.
    """
    # NOTE(review): this span contained both the pre-commit launch code
    # (ui.queue(...).launch(...)) and the new restart loop interleaved by
    # the diff; only the new loop is kept here.
    while True:
        ui = create_ui()
        # prevent_thread_lock makes launch() return immediately so this
        # thread is free to poll for restart requests below.
        ui.launch(
            server_name="0.0.0.0" if cmd_opts.listen else None,
            server_port=cmd_opts.port,
            share=cmd_opts.share,
            prevent_thread_lock=True,
        )
        wait_on_server(ui)
        print('Restarting UI...')


if __name__ == "__main__":
Expand Down

0 comments on commit d0c38b7

Please sign in to comment.