Skip to content

Commit

Permalink
update infer demo.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Nov 2, 2023
1 parent 39d2a24 commit 4fd620a
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions examples/gpt/inference_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ def main():
parser.add_argument('--prompt_template_name', default="vicuna", type=str,
help="Prompt template name, eg: alpaca, vicuna, baichuan-chat, chatglm2 etc.")
parser.add_argument('--interactive', action='store_true', help="run in the instruction mode")
parser.add_argument('--single_round', action='store_true',
help="Whether to generate single round dialogue, default is multi-round dialogue")
parser.add_argument('--data_file', default=None, type=str,
help="A file that contains instructions (one instruction per line)")
parser.add_argument('--predictions_file', default='./predictions_result.jsonl', type=str)
Expand All @@ -49,19 +47,26 @@ def main():
for example in examples[:10]:
print(example)
if args.interactive:
print(f"Start inference with interactive mode. enable multi round: {not args.single_round}")
print(f"Start inference with interactive mode.")
history = []
while True:
raw_input_text = input("Input:")
if len(raw_input_text.strip()) == 0:
if raw_input_text.strip() == 'exit':
break
if args.single_round:
response = model.predict([raw_input_text], prompt_template_name=args.prompt_template_name)[0]
else:
response, history = model.chat(
raw_input_text, history=history, prompt_template_name=args.prompt_template_name)
print("Response: ", response)
print("\n")
position = 0
print("Response:", end='', flush=True)
try:
for response in model.chat(
raw_input_text,
history=history,
prompt_template_name=args.prompt_template_name,
stream=True
):
print(response[position:], end='', flush=True)
position = len(response)
except KeyboardInterrupt:
pass
print()
else:
print("Start inference.")
results = []
Expand Down

0 comments on commit 4fd620a

Please sign in to comment.