Merge pull request #1 from Akegarasu/main
2023.4.30 Update From Akegarasu
YunCheng66 authored Apr 29, 2023
2 parents 1885984 + b3dd848, commit 9357b04
Showing 8 changed files with 107 additions and 46 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@

A webui for ChatGLM made by THUDM. [chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)

![image](https://user-images.githubusercontent.com/36563862/226165277-acc5ba44-8041-4c30-a6aa-2746c87f8475.png)
![image](https://user-images.githubusercontent.com/36563862/226985330-48e3b7f8-8c03-4778-af39-fd9b3a993d19.png)

## Features

modules/context.py (13 changes: 10 additions & 3 deletions)
@@ -39,13 +39,20 @@ def update_last(self, query, output) -> None:
self.rh[-1] = (query, output)

def refresh_last(self) -> None:
query, output = self.rh[-1]
self.rh[-1] = (query, parse_codeblock(output))
if self.rh:
query, output = self.rh[-1]
self.rh[-1] = (query, parse_codeblock(output))

def clear(self):
def clear(self) -> None:
self.history = []
self.rh = []

def revoke(self) -> List[Tuple[str, str]]:
if self.history and self.rh:
self.history.pop()
self.rh.pop()
return self.rh

def limit_round(self):
hl = len(self.history)
if hl == 0:
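The context.py changes above make `refresh_last()` a no-op when the render history is empty, add a return annotation to `clear()`, and introduce `revoke()`, which pops the last round from both `history` and `rh` and returns the updated render history. A minimal usage sketch, assuming `Context.append()` records a round in both lists (as the `modules/ui.py` handlers below rely on):

```python
from modules.context import Context

ctx = Context()
ctx.append("question", "answer")  # one completed round
ctx.revoke()                      # drops that round from both history and rh
assert ctx.rh == []
ctx.revoke()                      # safe on an empty context: the new guard makes it a no-op
```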
modules/model.py (59 changes: 45 additions & 14 deletions)
@@ -1,5 +1,7 @@
from typing import Optional, List, Tuple

from torch.cuda import get_device_properties

from modules.device import torch_gc
from modules.options import cmd_opts

@@ -17,6 +19,23 @@ def prepare_model():
else:
model = model.float()
else:
if cmd_opts.precision is None:
total_vram_in_gb = get_device_properties(0).total_memory / 1e9
print(f'GPU memory: {total_vram_in_gb:.2f} GB')

if total_vram_in_gb > 30:
cmd_opts.precision = 'fp32'
elif total_vram_in_gb > 13:
cmd_opts.precision = 'fp16'
elif total_vram_in_gb > 10:
cmd_opts.precision = 'int8'
else:
cmd_opts.precision = 'int4'

print(f'Choosing precision {cmd_opts.precision} according to your VRAM.'
f' If you want to decide precision yourself,'
f' please add argument --precision when launching the application.')

if cmd_opts.precision == "fp16":
model = model.half().cuda()
elif cmd_opts.precision == "int4":
@@ -44,7 +63,7 @@ def load_model():

def infer(query,
history: Optional[List[Tuple]],
max_length, top_p, temperature):
max_length, top_p, temperature, use_stream_chat: bool):
if cmd_opts.ui_dev:
yield "hello", "hello, dev mode!"
return
@@ -56,17 +75,29 @@ def infer(query,
history = []

output_pos = 0
try:
for output, history in model.stream_chat(
tokenizer, query=query, history=history,
max_length=max_length,
top_p=top_p,
temperature=temperature
):
print(output[output_pos:], end='')
output_pos = len(output)
yield query, output
except Exception as e:
print(f"Generation failed: {repr(e)}")
print("")
if use_stream_chat:
try:
for output, history in model.stream_chat(
tokenizer, query=query, history=history,
max_length=max_length,
top_p=top_p,
temperature=temperature
):
print(output[output_pos:], end='', flush=True)
output_pos = len(output)
yield query, output
except Exception as e:
print(f"Generation failed: {repr(e)}")
else:
output, history = model.chat(
tokenizer, query=query, history=history,
max_length=max_length,
top_p=top_p,
temperature=temperature
)

print(output)
yield query, output

print()
torch_gc()
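The block added to `prepare_model()` runs only when no `--precision` argument was given (its `fp16` default is removed in `modules/options.py` below): it reads the total VRAM of CUDA device 0 and picks fp32, fp16, int8, or int4, while an explicit `--precision` flag still bypasses the auto-selection. A standalone sketch of the same thresholds, where `pick_precision` is a hypothetical helper name rather than code from the repository:

```python
from torch.cuda import get_device_properties


def pick_precision(device: int = 0) -> str:
    # Same thresholds as prepare_model(): >30 GB -> fp32, >13 GB -> fp16,
    # >10 GB -> int8, otherwise int4.
    total_vram_in_gb = get_device_properties(device).total_memory / 1e9
    if total_vram_in_gb > 30:
        return "fp32"
    elif total_vram_in_gb > 13:
        return "fp16"
    elif total_vram_in_gb > 10:
        return "int8"
    return "int4"
```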
modules/options.py (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@

parser.add_argument("--port", type=int, default="17860")
parser.add_argument("--model-path", type=str, default="THUDM/chatglm-6b")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["fp32", "fp16", "int4", "int8"], default="fp16")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["fp32", "fp16", "int4", "int8"])
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--cpu", action='store_true', help="use cpu")
parser.add_argument("--share", action='store_true', help="use gradio share")
modules/ui.py (55 changes: 30 additions & 25 deletions)
@@ -3,23 +3,24 @@
import gradio as gr

from modules import options
from modules.context import ctx
from modules.context import Context
from modules.model import infer

css = "style.css"
script_path = "scripts"
_gradio_template_response_orig = gr.routes.templates.TemplateResponse


def predict(query, max_length, top_p, temperature):
def predict(ctx, query, max_length, top_p, temperature, use_stream_chat):
ctx.limit_round()
flag = True
for _, output in infer(
query=query,
history=ctx.history,
max_length=max_length,
top_p=top_p,
temperature=temperature
temperature=temperature,
use_stream_chat=use_stream_chat
):
if flag:
ctx.append(query, output)
@@ -31,12 +32,12 @@ def predict(query, max_length, top_p, temperature):
yield ctx.rh, ""


def clear_history():
def clear_history(ctx):
ctx.clear()
return gr.update(value=[])


def apply_max_round_click(max_round):
def apply_max_round_click(ctx, max_round):
ctx.max_rounds = max_round
return f"Applied: max round {ctx.max_rounds}"

@@ -45,7 +46,8 @@ def create_ui():
reload_javascript()

with gr.Blocks(css=css, analytics_enabled=False) as chat_interface:
prompt = "输入你的内容..."
_ctx = Context()
state = gr.State(_ctx)
with gr.Row():
with gr.Column(scale=3):
gr.Markdown("""<h2><center>ChatGLM WebUI</center></h2>""")
@@ -58,13 +60,19 @@
temperature = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Temperature', value=0.95)

with gr.Row():
max_rounds = gr.Slider(minimum=1, maximum=100, step=1, label="最大对话轮数(调小可以显著改善爆显存,但是会丢失上下文)", value=20)
max_rounds = gr.Slider(minimum=1, maximum=100, step=1, label="最大对话轮数", value=20)
apply_max_rounds = gr.Button("✔", elem_id="del-btn")

cmd_output = gr.Textbox(label="Command Output")
with gr.Row():
use_stream_chat = gr.Checkbox(label='使用流式输出', value=True)
with gr.Row():
with gr.Column(variant="panel"):
with gr.Row():
clear = gr.Button("清空对话(上下文)")
clear_history_btn = gr.Button("清空对话")

with gr.Row():
sync_his_btn = gr.Button("同步对话")

with gr.Row():
save_his_btn = gr.Button("保存对话")
@@ -73,40 +81,37 @@
with gr.Row():
save_md_btn = gr.Button("保存为 MarkDown")

with gr.Row():
cmd_output = gr.Textbox(label="Command Output")

with gr.Column(scale=7):
chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800)
with gr.Row():
input_message = gr.Textbox(placeholder=prompt, show_label=False, lines=2, elem_id="chat-input")
input_message = gr.Textbox(placeholder="输入你的内容...(按 Ctrl+Enter 发送)", show_label=False, lines=4, elem_id="chat-input").style(container=False)
clear_input = gr.Button("🗑️", elem_id="del-btn")

with gr.Row():
submit = gr.Button("发送", elem_id="c_generate")

with gr.Row():
revoke_btn = gr.Button("撤回")

submit.click(predict, inputs=[
state,
input_message,
max_length,
top_p,
temperature
temperature,
use_stream_chat
], outputs=[
chatbot,
input_message
])

clear.click(clear_history, outputs=[chatbot])
revoke_btn.click(lambda ctx: ctx.revoke(), inputs=[state], outputs=[chatbot])
clear_history_btn.click(clear_history, inputs=[state], outputs=[chatbot])
clear_input.click(lambda x: "", inputs=[input_message], outputs=[input_message])

save_his_btn.click(ctx.save_history, outputs=[cmd_output])
save_md_btn.click(ctx.save_as_md, outputs=[cmd_output])
load_his_btn.upload(ctx.load_history, inputs=[
load_his_btn,
], outputs=[
chatbot
])

apply_max_rounds.click(apply_max_round_click, inputs=[max_rounds], outputs=[cmd_output])
save_his_btn.click(lambda ctx: ctx.save_history(), inputs=[state], outputs=[cmd_output])
save_md_btn.click(lambda ctx: ctx.save_as_md(), inputs=[state], outputs=[cmd_output])
load_his_btn.upload(lambda ctx, f: ctx.load_history(f), inputs=[state, load_his_btn], outputs=[chatbot])
sync_his_btn.click(lambda ctx: ctx.rh, inputs=[state], outputs=[chatbot])
apply_max_rounds.click(apply_max_round_click, inputs=[state, max_rounds], outputs=[cmd_output])

with gr.Blocks(css=css, analytics_enabled=False) as settings_interface:
with gr.Row():
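The ui.py rewrite drops the module-level `ctx` import in favor of a per-session `Context` held in `gr.State`, so every handler (`predict`, `clear_history`, `apply_max_round_click`, and the new revoke/sync/save lambdas) now receives the session's context as its first input. A generic sketch of that per-session pattern, not code from this repository (assumes Gradio 3.x, which the `.style(...)` calls above imply):

```python
import gradio as gr

with gr.Blocks() as demo:
    state = gr.State([])  # one value per browser session

    msg = gr.Textbox()
    out = gr.Textbox()

    def respond(history, text):
        history = history + [text]  # update only this session's copy
        return history, " | ".join(history)

    msg.submit(respond, inputs=[state, msg], outputs=[state, out])
    # demo.launch() would serve this
```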
style.css (13 changes: 13 additions & 0 deletions)
@@ -7,4 +7,17 @@

.message {
width: inherit !important;
padding-left: 20px !important;
}

.app {
max-width: 100% !important;
}

.math.inline > span {
display: none;
}

.math.inline > svg {
display: inline;
}
webui-start.cmd (5 changes: 3 additions & 2 deletions)
@@ -7,8 +7,8 @@ python -m venv .venv
call .venv\Scripts\activate.bat

echo "Install dependencies"
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install --upgrade -r requirements.txt
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 -f https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html -i https://mirrors.bfsu.edu.cn/pypi/web/simple
pip install --upgrade -r requirements.txt -i https://mirrors.bfsu.edu.cn/pypi/web/simple
goto :run

:start
@@ -19,3 +19,4 @@ goto :run
:run
echo "Start WebUI"
python webui.py
pause
webui.py (4 changes: 4 additions & 0 deletions)
@@ -1,12 +1,16 @@
import os
import time
import sysconfig
from modules import options

from modules.model import load_model

from modules.options import cmd_opts
from modules.ui import create_ui

# patch PATH for cpm_kernels libcudart lookup
os.environ['PATH'] = os.environ.get("PATH", "") + os.pathsep + os.path.join(sysconfig.get_paths()["purelib"], "torch\lib")


def ensure_output_dirs():
folders = ["outputs/save", "outputs/markdown"]
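The two lines added to webui.py append the installed torch `lib` directory to `PATH`; per the in-code comment, this is so cpm_kernels can find `libcudart` (the backslash in `"torch\lib"` suggests the fix targets Windows). A small sketch of the directory that ends up on `PATH`; the exact site-packages location is machine-specific:

```python
import os
import sysconfig

# Directory the patch appends to PATH, typically ...\site-packages\torch\lib on Windows.
torch_lib_dir = os.path.join(sysconfig.get_paths()["purelib"], "torch", "lib")
print(torch_lib_dir)
```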
