Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
Akegarasu committed Mar 19, 2023
1 parent 0bffa79 commit 250687d
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions modules/model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,11 +12,10 @@ def prepare_model():
if cmd_opts.cpu:
if cmd_opts.precision == "fp32":
model = model.float()
elif cmd_opts.precision == "fp16":
elif cmd_opts.precision == "bf16":
model = model.bfloat16()
else:
model = model.bfloat16()
else:
print("--precision ERROR: INT modes are only for CUDA GPUs.")
exit(1)
else:
if cmd_opts.precision == "fp16":
model = model.half().cuda()
Expand All @@ -25,8 +24,7 @@ def prepare_model():
elif cmd_opts.precision == "int8":
model = model.half().quantize(8).cuda()
elif cmd_opts.precision == "fp32":
print("--precision ERROR: fp32 mode is only for CPU. Are you really ready to have such a large amount of vmem XD")
exit(1)
model = model.float()

model = model.eval()

Expand Down Expand Up @@ -59,17 +57,17 @@ def infer(query,
output_pos = 0

for output, history in model.stream_chat(
tokenizer, query=query, history=history,
max_length=max_length,
top_p=top_p,
temperature=temperature
tokenizer, query=query, history=history,
max_length=max_length,
top_p=top_p,
temperature=temperature
):
try:
print(output[output_pos:],end='')
except:
pass
print(output[output_pos:], end='')
except Exception as e:
print(f"Generation failed: {repr(e)}")
output_pos = len(output)
yield query, output

print()
print("")
torch_gc()

0 comments on commit 250687d

Please sign in to comment.