Refactor Qwen2 Family #1541

Merged: 3 commits merged into huggingface:main on Dec 2, 2024
Conversation

@Wei-Lin-Intel (Contributor) commented Dec 2, 2024

What does this PR do?

Fixes # (issue)

  1. Remove the Flash Attention V1 code path for the Qwen2 family
    SDPA has supported long sequences since 1.16.0, so this code path is no longer needed (a minimal SDPA sketch is included after the accuracy results below).

  2. Workaround (WA) for the accuracy issue of the Qwen2 family
    Qwen2-family models add a bias in the Q/K/V linear projections, and the Q*K product needs to be kept in FP32 to preserve accuracy. The following workaround keeps that FP32 accuracy:

  query_states = query_states * self.norm_factor
  # upcast the Q*K^T scores to FP32 before softmax
  attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)).float()
  # HPU lazy-mode step marker: triggers execution of the accumulated graph
  htcore.mark_step()
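
For context, here is a minimal sketch of where this upcast sits in an attention forward pass. It is plain PyTorch with illustrative names, not the exact optimum-habana modeling code, and it omits the HPU-specific mark_step:

  import torch

  def qwen2_attn_scores_fp32(query_states, key_states, norm_factor):
      # Scale Q, compute Q*K^T, and upcast the scores to FP32.
      # With BF16 Q/K plus the Qwen2 Q/K/V bias, keeping the scores in BF16
      # loses enough precision to derail generation (see the outputs below).
      query_states = query_states * norm_factor
      attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)).float()
      # Softmax runs in FP32; cast back before the attention @ V matmul.
      attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
      return attn_weights.to(query_states.dtype)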

This is necessary for the Qwen2 family, especially Qwen2-57B-A14B-Instruct. Without the workaround, running:

deepspeed --num_nodes 1 --num_gpus 4 --master_addr 127.0.0.1 \
        --master_port 60008 run_generation.py \
        --model_name_or_path Qwen/Qwen2-57B-A14B-Instruct \
        --trust_remote_code \
        --n_iterations 1 \
        --warmup 1 \
        --bf16 \
        --trim_logits \
        --batch_size 1 \
        --max_input_tokens 128 \
        --max_new_tokens 128 \
        --use_kv_cache \
        --reuse_cache \
        --bucket_size 128 \
        --bucket_internal \
        --use_hpu_graphs \
        --use_flash_attention

the output would be:

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework that provides thenumber one hundred and one dalm thenumber one then\nDeepSpeed is a deep thenumber one thenumber one hundred and thenumber one�\nDeepSpeed is a machine learning framework developed by the MoE (Mixture of Experts thenumber thenumber one hundred]Human thenumber one hundred and one d thenumber one hundred and thenumber one hundred thenumber then thenumber then thenumber one hundred]�\nDeepSpeed is a machine learning framework that is used to optimize and accelerate the thenumber one hundred and thenumber one hundred] training of large language�f thenumber one then thenumber one hundred',)

With the workaround:

deepspeed --num_nodes 1 --num_gpus 4 --master_addr 127.0.0.1 \
        --master_port 60008 run_generation.py \
        --model_name_or_path Qwen/Qwen2-57B-A14B-Instruct \
        --trust_remote_code \
        --n_iterations 1 \
        --warmup 1 \
        --bf16 \
        --trim_logits \
        --batch_size 1 \
        --max_input_tokens 128 \
        --max_new_tokens 128 \
        --use_kv_cache \
        --reuse_cache \
        --bucket_size 128 \
        --bucket_internal \
        --use_hpu_graphs

the output would be:

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework that provides optimization for deep learning models. It is designed to make deep learning more efficient and scalable by utilizing the power of GPUs, distributed computing, and other hardware accelerators. DeepSpeed can be used with popular deep learning frameworks such as PyTorch and TensorFlow.\n\nSome of the key features of DeepSpeed include:\n\n1. Zero Redundancy Optimizer (ZeRO): This feature allows for model parallelism without any redundancy in the optimizer state, which can significantly reduce memory usage.\n2. Pipeline Parallelism: This feature allows for the partitioning of a model across multiple GPUs, enabling faster training times on large models.\n3. Dynamic',)

On the lm_eval task, this workaround improves the Qwen2-7B-Instruct score from 0.3 to 0.5.
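
For reference on item 1 above, the removed Flash Attention V1 path is covered by PyTorch's built-in SDPA. A minimal sketch of the SDPA call follows; the tensor shapes and dtypes are illustrative assumptions, not the exact optimum-habana wiring:

  import torch
  import torch.nn.functional as F

  # Illustrative shapes: (batch, num_heads, seq_len, head_dim)
  q = torch.randn(1, 28, 4096, 128, dtype=torch.bfloat16)
  k = torch.randn(1, 28, 4096, 128, dtype=torch.bfloat16)
  v = torch.randn(1, 28, 4096, 128, dtype=torch.bfloat16)

  # scaled_dot_product_attention handles long sequences directly,
  # which is why a separate Flash Attention V1 path is no longer needed.
  out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)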

  3. Add beam search for Qwen2-MoE

  4. Refactor the general code path for the Qwen2 family

  5. Add dynamic MoE for Qwen2-MoE in inference mode (see the sketch after this list)
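
On item 5, a rough sketch of what dynamic MoE dispatch means at inference time: only experts that actually receive routed tokens are executed, instead of statically running every expert on every token. This is a generic toy illustration in plain PyTorch (DynamicMoE and all names are hypothetical), not the optimum-habana implementation:

  import torch
  import torch.nn as nn

  class DynamicMoE(nn.Module):
      """Toy top-k MoE layer that skips experts with no routed tokens."""
      def __init__(self, hidden, num_experts=8, top_k=2):
          super().__init__()
          self.gate = nn.Linear(hidden, num_experts, bias=False)
          self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))
          self.top_k = top_k

      def forward(self, x):                                # x: (num_tokens, hidden)
          scores = self.gate(x).softmax(dim=-1)            # (tokens, num_experts)
          weights, idx = scores.topk(self.top_k, dim=-1)   # (tokens, top_k)
          out = torch.zeros_like(x)
          for e, expert in enumerate(self.experts):
              mask = idx == e                              # which routing slots picked expert e
              if not mask.any():                           # dynamic: skip idle experts entirely
                  continue
              token_ids = mask.any(dim=-1).nonzero(as_tuple=True)[0]
              w = (weights * mask)[token_ids].sum(dim=-1, keepdim=True)
              out[token_ids] += w * expert(x[token_ids])
          return out

  # Usage: out = DynamicMoE(hidden=64)(torch.randn(10, 64))

The only point of the sketch is the "if not mask.any(): continue" branch: a static MoE path runs every expert regardless of routing, while the dynamic path skips experts that received no tokens in the current batch.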

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@Wei-Lin-Intel (Contributor, Author) commented:

@libinta Please help review it. This is critical for the accuracy issue of the Qwen2 family.


github-actions bot commented Dec 2, 2024

The code quality check failed, please run make style.

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@regisss merged commit 3261493 into huggingface:main on Dec 2, 2024 (4 checks passed).
@Wei-Lin-Intel mentioned this pull request on Dec 4, 2024.
Liangyx2 pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Jan 20, 2025