From 4fd620a37c1b606dd9a0b25e9d0eb15304657676 Mon Sep 17 00:00:00 2001 From: shibing624 Date: Thu, 2 Nov 2023 11:54:41 +0800 Subject: [PATCH] update infer demo. --- examples/gpt/inference_demo.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/examples/gpt/inference_demo.py b/examples/gpt/inference_demo.py index df53e5a2..e559a8a0 100644 --- a/examples/gpt/inference_demo.py +++ b/examples/gpt/inference_demo.py @@ -22,8 +22,6 @@ def main(): parser.add_argument('--prompt_template_name', default="vicuna", type=str, help="Prompt template name, eg: alpaca, vicuna, baichuan-chat, chatglm2 etc.") parser.add_argument('--interactive', action='store_true', help="run in the instruction mode") - parser.add_argument('--single_round', action='store_true', - help="Whether to generate single round dialogue, default is multi-round dialogue") parser.add_argument('--data_file', default=None, type=str, help="A file that contains instructions (one instruction per line)") parser.add_argument('--predictions_file', default='./predictions_result.jsonl', type=str) @@ -49,19 +47,26 @@ def main(): for example in examples[:10]: print(example) if args.interactive: - print(f"Start inference with interactive mode. enable multi round: {not args.single_round}") + print(f"Start inference with interactive mode.") history = [] while True: raw_input_text = input("Input:") - if len(raw_input_text.strip()) == 0: + if raw_input_text.strip() == 'exit': break - if args.single_round: - response = model.predict([raw_input_text], prompt_template_name=args.prompt_template_name)[0] - else: - response, history = model.chat( - raw_input_text, history=history, prompt_template_name=args.prompt_template_name) - print("Response: ", response) - print("\n") + position = 0 + print("Response:", end='', flush=True) + try: + for response in model.chat( + raw_input_text, + history=history, + prompt_template_name=args.prompt_template_name, + stream=True + ): + print(response[position:], end='', flush=True) + position = len(response) + except KeyboardInterrupt: + pass + print() else: print("Start inference.") results = []