
Commit

update wizardmath demo
flyinghpluo authored Sep 1, 2023
1 parent dec2f01 commit 5648be6
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion WizardMath/inference/MATH_inference.py
@@ -74,7 +74,7 @@ def test_hendrycks_math(model, data_path, start=0, end=MAX_INT, batch_size=1, te
     print('lenght ====', len(hendrycks_math_ins))
     batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size)
 
-    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
+    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
     sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=2048, stop=stop_tokens)
     print('sampleing =====', sampling_params)
     llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size)
2 changes: 1 addition & 1 deletion WizardMath/inference/gsm8k_inference.py
@@ -87,7 +87,7 @@ def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_para
     print('lenght ====', len(gsm8k_ins))
     batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)
 
-    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
+    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
     sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=1024, stop=stop_tokens)
     print('sampleing =====', sampling_params)
     llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size)
2 changes: 1 addition & 1 deletion demo/wizardmath_demo.py
@@ -33,7 +33,7 @@ def evaluate_vllm(
     prompt = problem_prompt.format(instruction=instruction)
 
     problem_instruction = [prompt]
-    stop_tokens = ['</s>']
+    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response", '</s>']
     sampling_params = SamplingParams(temperature=temperature, top_p=1, max_tokens=max_new_tokens, stop=stop_tokens)
     completions = llm.generate(problem_instruction, sampling_params)
     for output in completions:
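
All three hunks adjust the same thing: the list of stop strings handed to vLLM's SamplingParams. The two inference scripts drop the "Question"/"USER"/"ASSISTANT" markers, and the demo gains the "Instruction"/"Response" markers alongside '</s>'. Below is a minimal sketch of how such a stop list behaves with vLLM; the checkpoint path and prompt are illustrative placeholders, not values taken from this commit.

# Minimal sketch of vLLM generation with a stop list like the one in this
# commit. Checkpoint path and prompt are illustrative placeholders only.
from vllm import LLM, SamplingParams

stop_tokens = ["Instruction:", "Instruction", "Response:", "Response", "</s>"]
sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=1024, stop=stop_tokens)

llm = LLM(model="WizardLM/WizardMath-7B-V1.0", tensor_parallel_size=1)  # placeholder checkpoint
prompts = ["Below is an instruction that describes a task.\n\n"
           "### Instruction:\nWhat is 12 * 7?\n\n### Response:"]

completions = llm.generate(prompts, sampling_params)
for output in completions:
    # Decoding halts as soon as any stop string is produced; by default the
    # stop string itself is excluded from the returned text.
    print(output.outputs[0].text)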
