Commit 0bffa79

feat: Added Stream Chat; Added fp32, fp16 option for cpu with low mem (Akegarasu#14)

* Added fp32, fp16 option for cpu with low mem

* Added Stream Chat
haofanurusai authored Mar 19, 2023
1 parent 3123feb commit 0bffa79
Showing 6 changed files with 42 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -47,6 +47,6 @@ python webui.py

 `--share`: use gradio to share

-`--precision`: fp16, int4, int8
+`--precision`: fp32 (CPU only), fp16, int4 (CUDA GPU only), int8 (CUDA GPU only)

 `--cpu`: use cpu
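For reference, a usage sketch of the new options (flag semantics as defined in modules/options.py below):

python webui.py --cpu --precision fp32    # full precision on CPU
python webui.py --cpu --precision fp16    # lower-memory CPU mode (maps to bfloat16 internally)
python webui.py --precision int4          # quantized inference, CUDA GPU only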
6 changes: 6 additions & 0 deletions modules/context.py
@@ -34,6 +34,12 @@ def append(self, query, output) -> str:
         self.rh.append((query, ok))
         return ok

+    def refresh_last(self, query, output) -> str:
+        ok = parse_codeblock(output)
+        self.history[-1] = (query, output)
+        self.rh[-1] = (query, ok)
+        return ok
+
     def clear(self):
         self.history = []
         self.rh = []
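refresh_last complements append: it rewrites the most recent (query, output) pair in place, so a streamed reply can keep updating the last chat round instead of appending a new round on every chunk (see the predict change in modules/ui.py below).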
28 changes: 23 additions & 5 deletions modules/model.py
@@ -10,14 +10,23 @@
 def prepare_model():
     global model
     if cmd_opts.cpu:
-        model = model.float()
+        if cmd_opts.precision == "fp32":
+            model = model.float()
+        elif cmd_opts.precision == "fp16":
+            model = model.bfloat16()
+        else:
+            print("--precision ERROR: INT modes are only for CUDA GPUs.")
+            exit(1)
     else:
         if cmd_opts.precision == "fp16":
             model = model.half().cuda()
         elif cmd_opts.precision == "int4":
             model = model.half().quantize(4).cuda()
         elif cmd_opts.precision == "int8":
             model = model.half().quantize(8).cuda()
+        elif cmd_opts.precision == "fp32":
+            print("--precision ERROR: fp32 mode is only for CPU. Are you really ready to spend such a large amount of VRAM? XD")
+            exit(1)

     model = model.eval()

@@ -46,12 +55,21 @@ def infer(query,

     if history is None:
         history = []
-    output, history = model.chat(
+
+    output_pos = 0
+
+    for output, history in model.stream_chat(
         tokenizer, query=query, history=history,
         max_length=max_length,
         top_p=top_p,
         temperature=temperature
-    )
-    print(output)
+    ):
+        try:
+            print(output[output_pos:], end='')
+        except:
+            pass
+        output_pos = len(output)
+        yield query, output
+
+    print()
     torch_gc()
-    return query, output
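A note on the loop above: ChatGLM's stream_chat yields the full response generated so far on each iteration, so output[output_pos:] prints only the newly generated text, and the bare except presumably guards against consoles that cannot encode part of the streamed characters. Since infer now yields, callers iterate over it; a minimal consumption sketch (hypothetical parameter values, keyword arguments as used in modules/ui.py):

for query, output in infer(query="Hello", history=None, max_length=2048, top_p=0.7, temperature=0.95):
    pass  # output grows with every chunk; after the loop it holds the full reply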
2 changes: 1 addition & 1 deletion modules/options.py
@@ -4,7 +4,7 @@

parser.add_argument("--port", type=int, default="17860")
parser.add_argument("--model-path", type=str, default="THUDM/chatglm-6b")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["fp16", "int4", "int8"], default="fp16")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["fp32", "fp16", "int4", "int8"], default="fp16")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--cpu", action='store_true', help="use cpu")
parser.add_argument("--share", action='store_true', help="use gradio share")
15 changes: 10 additions & 5 deletions modules/ui.py
@@ -13,16 +13,21 @@

 def predict(query, max_length, top_p, temperature):
     ctx.limit_round()
-    _, output = infer(
+    flag = True
+    for _, output in infer(
         query=query,
         history=ctx.history,
         max_length=max_length,
         top_p=top_p,
         temperature=temperature
-    )
-    ctx.append(query, output)
-    # for clear input textbox
-    return ctx.history, ""
+    ):
+        if flag:
+            ctx.append(query, output)
+            flag = False
+        else:
+            ctx.refresh_last(query, output)
+        # empty string clears the input textbox
+        yield ctx.history, ""


 def clear_history():
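Because predict now yields instead of returning, Gradio treats it as a streaming event handler and re-renders the Chatbot after every yield; the flag distinguishes the first chunk (which appends a new round) from later chunks (which refresh it in place). Generator handlers only run when the app's queue is enabled, which the webui.py change below provides.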
2 changes: 1 addition & 1 deletion webui.py
@@ -22,7 +22,7 @@ def init():

 def main():
     ui = create_ui()
-    ui.launch(
+    ui.queue(concurrency_count=5, max_size=20).launch(
         server_name="0.0.0.0" if cmd_opts.listen else None,
         server_port=cmd_opts.port,
         share=cmd_opts.share
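A note on the queue parameters: in Gradio 3.x, queue() enables the event queue that generator handlers require; concurrency_count=5 lets up to five events run in parallel, and max_size=20 caps the number of waiting requests.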