Skip to content

Commit

Permalink
Merge pull request nlpxucan#195 from flyinghpluo/main
Browse files Browse the repository at this point in the history
update wizardmath demo
  • Loading branch information
nlpxucan authored Sep 1, 2023
2 parents 132d815 + 24cdd2e commit 01f17bf
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 5 deletions.
2 changes: 1 addition & 1 deletion WizardMath/inference/MATH_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_hendrycks_math(model, data_path, start=0, end=MAX_INT, batch_size=1, te
print('lenght ====', len(hendrycks_math_ins))
batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size)

stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=2048, stop=stop_tokens)
print('sampleing =====', sampling_params)
llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size)
Expand Down
2 changes: 1 addition & 1 deletion WizardMath/inference/gsm8k_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_para
print('lenght ====', len(gsm8k_ins))
batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)

stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=1024, stop=stop_tokens)
print('sampleing =====', sampling_params)
llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size)
Expand Down
17 changes: 17 additions & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,20 @@ CUDA_VISIBLE_DEVICES=0 python wizardmath_demo.py \
--base_model "xxx/path/to/wizardmath_7b_model" \
--n_gpus 1
```

## WizardLM Inference Demo

We provide the inference demo script for **WizardLM-Family**.

1. According to the instructions of [Llama-X](https://github.com/AetherCortex/Llama-X), install the environment.
2. Install these packages:
```bash
pip install transformers==4.31.0
pip install vllm==0.1.4
```
3. Enjoy your demo:
```bash
CUDA_VISIBLE_DEVICES=0 python wizardLM_demo.py \
--base_model "xxx/path/to/wizardLM_7b_model" \
--n_gpus 1
```
53 changes: 53 additions & 0 deletions demo/wizardLM_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import gradio as gr
import argparse
import os
import json
from vllm import LLM, SamplingParams


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_model", type=str) # model path
parser.add_argument("--n_gpus", type=int, default=1) # n_gpu
return parser.parse_args()

def predict(message, history, system_prompt, temperature, max_tokens):
instruction = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
for human, assistant in history:
instruction += 'USER: '+ human + ' ASSISTANT: '+ assistant + '</s>'
instruction += 'USER: '+ message + ' ASSISTANT:'
problem = [instruction]
stop_tokens = ["USER:", "USER", "ASSISTANT:", "ASSISTANT"]
sampling_params = SamplingParams(temperature=temperature, top_p=1, max_tokens=max_tokens, stop=stop_tokens)
completions = llm.generate(problem, sampling_params)
for output in completions:
prompt = output.prompt
print('==========================question=============================')
print(prompt)
generated_text = output.outputs[0].text
print('===========================answer=============================')
print(generated_text)
for idx in range(len(generated_text)):
yield generated_text[:idx+1]


if __name__ == "__main__":
args = parse_args()
llm = LLM(model=args.base_model, tensor_parallel_size=args.n_gpus)
gr.ChatInterface(
predict,
title="LLM playground - WizardLM-13B-V1.2",
description="This is a LLM playground for WizardLM-13B-V1.2, github: https://github.com/nlpxucan/WizardLM, huggingface: https://huggingface.co/WizardLM",
theme="soft",
chatbot=gr.Chatbot(height=1400, label="Chat History",),
textbox=gr.Textbox(placeholder="input", container=False, scale=7),
retry_btn=None,
undo_btn="Delete Previous",
clear_btn="Clear",
additional_inputs=[
gr.Textbox("A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", label="System Prompt"),
gr.Slider(0, 1, 0.9, label="Temperature"),
gr.Slider(100, 2048, 1024, label="Max Tokens"),
],
additional_inputs_accordion_name="Parameters",
).queue().launch(share=False, server_port=7870)
4 changes: 2 additions & 2 deletions demo/wizardcoder_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def evaluate_vllm(
)
],
title="WizardCoder",
description="Empowering Code Large Language Models with Evol-Instruct"
description="Empowering Code Large Language Models with Evol-Instruct, github: https://github.com/nlpxucan/WizardLM, huggingface: https://huggingface.co/WizardLM"
).queue().launch(share=True, server_port=port)

if __name__ == "__main__":
fire.Fire(main)
fire.Fire(main)
2 changes: 1 addition & 1 deletion demo/wizardmath_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def evaluate_vllm(
)
],
title=title,
description="Empowering Mathematical Reasoning for Large Language Models via Reinforced Evol-Instruct"
description="Empowering Mathematical Reasoning for Large Language Models via Reinforced Evol-Instruct, github: https://github.com/nlpxucan/WizardLM, huggingface: https://huggingface.co/WizardLM"
).queue().launch(share=False, server_port=port)

if __name__ == "__main__":
Expand Down

0 comments on commit 01f17bf

Please sign in to comment.