WizardMath: Empowering Mathematical Reasoning for Large Language Models via Reinforced Evol-Instruct (RLEIF)
🤗 HF Repo • 🐦 Twitter • 📃 [WizardLM] • 📃 [WizardCoder] • 📃 [WizardMath]
👋 Join our Discord
To develop WizardMath, we first adapt the Evol-Instruct and reinforcement learning methods specifically to math tasks such as GSM8k and MATH, tailoring the prompts to the domain of math-related instructions. We then fine-tune LLaMA 2 on the newly created instruction-following math training set.
- 🔥 Our WizardMath-70B-V1.0 model slightly outperforms some closed-source LLMs on GSM8k, including ChatGPT 3.5, Claude Instant 1, and PaLM 2 540B.
- 🔥 Our WizardMath-70B-V1.0 model achieves 81.6 pass@1 on the GSM8k benchmark, 24.8 points higher than the SOTA open-source LLM (LLaMA-2-70B, 56.8).
- 🔥 Our WizardMath-70B-V1.0 model achieves 22.7 pass@1 on the MATH benchmark, 9.2 points higher than the SOTA open-source LLM (LLaMA-2-70B, 13.5).
Model | Checkpoint | Paper | GSM8k | MATH | Online Demo | License |
---|---|---|---|---|---|---|
WizardMath-70B-V1.0 | 🤗 HF Link | 📃 [WizardMath] | 81.6 | 22.7 | Demo (English only) | Llama 2 |
WizardMath-13B-V1.0 | 🤗 HF Link | 📃 [WizardMath] | 63.9 | 14.0 | Demo (English only) | Llama 2 |
WizardMath-7B-V1.0 | 🤗 HF Link | 📃 [WizardMath] | 54.9 | 10.7 | Demo (English only) | Llama 2 |
❗Regarding common concerns about the dataset:
Recently, there have been clear changes in the open-source policy and regulations governing our organization's code, data, and models. Despite this, we have worked hard to release the model weights first, but the code and data require stricter auditing and are under review by our legal team. Our researchers are not authorized to release them publicly without approval. Thank you for your understanding.
🔥 The following figure shows that WizardMath-70B-V1.0 ranks fifth on the GSM8k benchmark, surpassing Claude Instant 1 (81.6 vs. 80.9), ChatGPT (81.6 vs. 80.8), and PaLM 2 540B (81.6 vs. 80.7). Notably, our model is substantially smaller than these models.
❗❗❗Note: This performance is 100% reproducible! If you cannot reproduce it, please follow the steps in Evaluation.
❗❗❗Note: The score of ChatGPT reported by Model Selection is 80.8%.
❗❗❗Note: If you want to build a WizardMath demo, please use the following system prompts:
Default version:
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:"
CoT version: (❗For simple math questions, we do NOT recommend using the CoT prompt.)
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
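The two system prompts above can be filled in with plain string formatting. Below is a minimal sketch; the helper name `build_prompt` and its `cot` flag are illustrative choices of ours, not part of the WizardMath codebase.

```python
# Templates copied verbatim from the system prompts above.
DEFAULT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)
# The CoT version only appends a trigger phrase after "### Response:".
COT_TEMPLATE = DEFAULT_TEMPLATE + " Let's think step by step."


def build_prompt(instruction: str, cot: bool = False) -> str:
    """Fill a WizardMath prompt template (hypothetical helper).

    Use cot=True for multi-step problems; for simple math questions
    the default (non-CoT) prompt is recommended.
    """
    template = COT_TEMPLATE if cot else DEFAULT_TEMPLATE
    return template.format(instruction=instruction)


if __name__ == "__main__":
    print(build_prompt("What is 7 * 8?"))
    print(build_prompt("A robe takes 2 bolts of blue fiber and half that much white fiber. "
                       "How many bolts in total does it take?", cot=True))
```

Because the instruction is substituted through `str.format`, any braces inside the question text are inserted literally and are not re-interpreted as placeholders.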
The following table shows that WizardMath substantially outperforms open-source models on the GSM8k and MATH benchmarks: WizardMath-70B-V1.0 scores highest on both, and each WizardMath model far exceeds the LLaMA 2 base model of the same size.
❗ If you are confused by the different scores of our 7B, 13B, and 70B models (54.9, 63.9, and 81.6), please check the Notes.
Model | GSM8k Pass@1 | MATH Pass@1 |
---|---|---|
MPT-7B | 6.8 | 3.0 |
Falcon-7B | 6.8 | 2.3 |
LLaMA-1-7B | 11.0 | 2.9 |
LLaMA-2-7B | 14.6 | 2.5 |
MPT-30B | 15.2 | 3.1 |
LLaMA-1-13B | 17.8 | 3.9 |
GPT-Neo-2.7B | 19.5 | -- |
Falcon-40B | 19.6 | 2.5 |
Baichuan-chat-13B | 23.9 | -- |
Vicuna-v1.3-13B | 27.6 | -- |
LLaMA-2-13B | 28.7 | 3.9 |
InternLM-7B | 31.2 | -- |
ChatGLM-2-6B | 32.4 | -- |
GPT-J-6B | 34.9 | -- |
LLaMA-1-33B | 35.6 | 3.9 |
LLaMA-2-34B | 42.2 | 6.24 |
RFT-7B | 50.3 | -- |
LLaMA-1-65B | 50.9 | 10.6 |
Qwen-7B | 51.6 | -- |
WizardMath-7B-v1.0 | 54.9 | 10.7 |
LLaMA-2-70B | 56.8 | 13.5 |
WizardMath-13B-v1.0 | 63.9 | 14.0 |
WizardMath-70B-v1.0 | 81.6 | 22.7 |
❗ Note: The table above gives a comprehensive comparison of WizardMath with other models on the GSM8k and MATH benchmarks. To ensure fair and consistent evaluation, we report the scores of all models under greedy decoding with CoT prompting.
In the SFT stage, we train WizardMath with the code `WizardMath/train/train_wizardmath.py` from Llama-X, which uses the open-source friendly .
We perform supervised fine-tuning of WizardMath-13B with the following hyperparameters:
Hyperparameter | LLaMA 2 13B |
---|---|
Batch size | 128 |
Learning rate | 2e-5 |
Epochs | 3 |
Max length | 2048 |
LR scheduler | cosine |
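In the training command, the `cosine` scheduler is combined with a short linear warmup (`--warmup_steps 10`). As an illustration, here is a minimal sketch of a linear-warmup-then-cosine-decay schedule written in the style of Hugging Face Transformers' `get_cosine_schedule_with_warmup` (one half-cosine cycle decaying to zero). It assumes the schedule is the one selected by `--lr_scheduler_type` (a scheduler defined in the DeepSpeed config could override it), and the total step count is a hypothetical input, since it depends on dataset size and GPU count.

```python
import math


def cosine_lr_with_warmup(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear warmup to base_lr, then half-cosine decay to 0."""
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr over the warmup steps.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # 0.5 * (1 + cos(pi * progress)) goes from 1 down to 0; clamp at 0 past the end.
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    # Hypothetical run of 1,000 optimizer steps with the table's 2e-5 peak learning rate.
    for s in (0, 5, 10, 505, 1000):
        print(s, cosine_lr_with_warmup(s, 2e-5, warmup_steps=10, total_steps=1000))
```

With these inputs the rate rises to 2e-5 at step 10, is halved at the midpoint of the decay phase (step 505), and reaches 0 at step 1000.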
To reproduce our fine-tuning of WizardMath, please follow these steps:
- Following the instructions of Llama-X, install the environment, download the training code, and deploy it. (Note: `deepspeed==0.10.0` and `transformers==4.31.0`.)
- Replace `train.py` with `train_wizardmath.py` from our repo (`WizardMath/train/train_wizardmath.py`).
- Log in to Hugging Face: `huggingface-cli login`
- Execute the following training command:
```bash
deepspeed train_wizardmath.py \
    --model_name_or_path "/your/path/to/llama-2-13b" \
    --data_path "/your/path/to/math_instruction_data.json" \
    --output_dir "/your/path/to/save_ckpt" \
    --num_train_epochs 3 \
    --model_max_length 2048 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50 \
    --save_total_limit 50 \
    --learning_rate 2e-5 \
    --warmup_steps 10 \
    --logging_steps 2 \
    --lr_scheduler_type "cosine" \
    --report_to "tensorboard" \
    --gradient_checkpointing True \
    --deepspeed config/deepspeed_config.json \
    --fp16 True
```
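The hyperparameter table lists a batch size of 128, while the command sets `--per_device_train_batch_size 16` with `--gradient_accumulation_steps 1`. Assuming standard data-parallel training, the global (effective) batch size is the product of the per-device batch size, the accumulation steps, and the number of GPUs, so the command reaches 128 only on 8 GPUs. The GPU count is our inference; it is not stated above. The helpers below are a hypothetical sketch for choosing accumulation steps on a different number of GPUs.

```python
def global_batch_size(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Samples consumed per optimizer step under data parallelism."""
    return per_device * grad_accum * num_gpus


def grad_accum_for(target: int, per_device: int, num_gpus: int) -> int:
    """Accumulation steps needed to reach `target`; raises if it does not divide evenly."""
    per_step = per_device * num_gpus
    if target % per_step != 0:
        raise ValueError(f"{target} is not a multiple of {per_step}")
    return target // per_step


if __name__ == "__main__":
    print(global_batch_size(16, 1, 8))  # 128: the command above on 8 GPUs
    print(grad_accum_for(128, 16, 4))   # 2: same global batch on 4 GPUs
```

Keeping the global batch size fixed while changing the GPU count keeps the number of optimizer steps per epoch (and hence the effect of `--warmup_steps` and the cosine schedule) comparable.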
Recently, there have been clear changes in the open-source policy and regulations governing our organization's code, data, and models. The data and code require stricter auditing and are under review by our legal team. Our researchers are not authorized to release them publicly without approval. Thank you for your understanding.
We provide the decoding scripts for WizardMath, which read an input file, generate a response for each sample, and finally calculate the score.
Note: We use vLLM for inference because it speeds up generation. For questions about installing vLLM, please refer to the official vLLM GitHub repository.
```bash
conda create -n wizardmath python=3.8 -y
conda activate wizardmath
pip install vllm
pip install jsonlines
pip install Fraction
pip install openai
cd WizardMath
```
The inference prompt for our WizardMath is:
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
- The format of `gsm8k_test.jsonl` should be:
```
{"question": "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", "answer": "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\u2019s market.\n#### 18"}
{"question": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?", "answer": "It takes 2/2=<<2/2=1>>1 bolt of white fiber\nSo the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric\n#### 3"}
{"question": "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?", "answer": "The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\n#### 70000"}
{"question": "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?", "answer": "He sprints 3*3=<<3*3=9>>9 times\nSo he runs 9*60=<<9*60=540>>540 meters\n#### 540"}
.....
```
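In each `gsm8k_test.jsonl` record, the reference solution ends with a line of the form `#### <answer>`, and intermediate calculations are annotated as `<<expr=result>>`. A minimal sketch of loading the file and extracting the gold numeric answer follows; the function names are hypothetical, and the official inference script may parse the file differently.

```python
import json
from typing import Iterator, Tuple


def gold_answer(answer_field: str) -> str:
    """Return the text after the final '####' marker, with thousands separators removed."""
    return answer_field.split("####")[-1].strip().replace(",", "")


def load_gsm8k(path: str) -> Iterator[Tuple[str, str]]:
    """Yield (question, gold_answer) pairs from a JSONL file (one JSON object per line)."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            yield record["question"], gold_answer(record["answer"])


if __name__ == "__main__":
    sample = "He sprints 3*3=<<3*3=9>>9 times\nSo he runs 9*60=<<9*60=540>>540 meters\n#### 540"
    print(gold_answer(sample))  # 540
```

Note that `json.loads` decodes escapes such as `\u2019` into the actual character, so questions are passed to the model as normal text.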
- Run the following script to generate the answers:
```bash
python inference/gsm8k_inference.py --data_file data/gsm8k_test.jsonl --model "/your/path/to/load_ckpt" --batch_size 60 --tensor_parallel_size 1
```
You can set `tensor_parallel_size` to the number of GPUs to use. You can also run on a slice of the dataset by setting `start` and `end`.
- The format of `MATH_test.jsonl` should be:
```
{"idx": "hendrycks_math_1", "instruction": "Find the matrix $\\mathbf{M}$ such that\n\\[\\mathbf{M} \\begin{pmatrix} 1 & -2 \\\\ 1 & 4 \\end{pmatrix} = \\begin{pmatrix} 6 & 0 \\\\ 0 & 6 \\end{pmatrix}.\\]", "output": "The inverse of $\\begin{pmatrix} 1 & -2 \\\\ 1 & 4 \\end{pmatrix}$ is\n\\[\\frac{1}{(1)(4) - (-2)(1)} \\begin{pmatrix} 4 & 2 \\\\ -1 & 1 \\end{pmatrix} = \\frac{1}{6} \\begin{pmatrix} 4 & 2 \\\\ -1 & 1 \\end{pmatrix}.\\]So, multiplying by this inverse on the right, we get\n\\[\\mathbf{M} = \\begin{pmatrix} 6 & 0 \\\\ 0 & 6 \\end{pmatrix} \\cdot \\frac{1}{6} \\begin{pmatrix} 4 & 2 \\\\ -1 & 1 \\end{pmatrix} = \\boxed{\\begin{pmatrix} 4 & 2 \\\\ -1 & 1 \\end{pmatrix}}.\\]", "input": "", "type": "Precalculus"}
{"idx": "hendrycks_math_2", "instruction": "Compute $\\arccos (-1).$ Express your answer in radians.", "output": "Since $\\cos \\pi = -1,$ $\\arccos (-1) = \\boxed{\\pi}.$", "input": "", "type": "Precalculus"}
.....
```
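In `MATH_test.jsonl`, the final answer in each `output` is wrapped in `\boxed{...}`, and the boxed content can itself contain nested braces (as in the matrix example). A regular expression cannot match arbitrarily nested braces reliably, so this minimal sketch scans forward with a depth counter instead. The function name is hypothetical, it does not handle escaped braces such as `\{`, and the official evaluation may normalize answers further.

```python
from typing import Optional


def last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in `text`, or None if absent or unbalanced."""
    marker = "\\boxed{"
    start = text.rfind(marker)  # last occurrence wins
    if start == -1:
        return None
    i = start + len(marker)
    depth = 1  # we are already inside the opening brace of \boxed{
    for j in range(i, len(text)):
        ch = text[j]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[i:j]
    return None  # braces never closed


if __name__ == "__main__":
    out = "Since $\\cos \\pi = -1,$ $\\arccos (-1) = \\boxed{\\pi}.$"
    print(last_boxed(out))  # \pi
```

Taking the last `\boxed{...}` mirrors the convention in the MATH solutions, where the boxed expression is the final answer.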
- Run the following script to generate the answers:
```bash
python inference/MATH_inference.py --data_file data/MATH_test.jsonl --model "/your/path/to/load_ckpt" --batch_size 50 --tensor_parallel_size 1
```
You can set `tensor_parallel_size` to the number of GPUs to use. You can also run on a slice of the dataset by setting `start` and `end`.
```bibtex
@article{luo2023wizardmath,
  title={WizardMath: Empowering Mathematical Reasoning for Large Language Models via Reinforced Evol-Instruct},
  author={Luo, Haipeng and Sun, Qingfeng and Xu, Can and Zhao, Pu and Lou, Jianguang and Tao, Chongyang and Geng, Xiubo and Lin, Qingwei and Chen, Shifeng and Zhang, Dongmei},
  journal={arXiv preprint arXiv:2308.09583},
  year={2023}
}
```
The WizardMath models follow the same license as LLaMA 2. Content produced by any version of WizardMath is affected by uncontrollable variables such as randomness, so this project cannot guarantee the accuracy of the output. This project accepts no legal liability for the content of model outputs and assumes no responsibility for any losses incurred through the use of associated resources and output results.