Update FlashAttention & Fix load_from_name #58

Merged · 5 commits · Feb 21, 2023
1,147 changes: 574 additions & 573 deletions README.md

Large diffs are not rendered by default.

1,155 changes: 578 additions & 577 deletions README_En.md

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion cn_clip/clip/model.py
@@ -10,7 +10,10 @@
 import torch.nn.functional as F
 from torch import nn
 from torch.utils.checkpoint import checkpoint
-from flash_attn.flash_attention import FlashMHA
+
+import importlib.util
+if importlib.util.find_spec('flash_attn'):
+    FlashMHA = importlib.import_module('flash_attn.flash_attention').FlashMHA
 
 from cn_clip.clip import _tokenizer
 from cn_clip.clip.configuration_bert import BertConfig
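To illustrate the guarded import above: `FlashMHA` is only bound when the flash-attn package is installed, so any code that uses it must check for availability first. Below is a minimal, self-contained sketch of that pattern; the `build_attention` helper and the fallback to `nn.MultiheadAttention` are illustrative assumptions, not code from this PR.

```python
# Sketch of the optional-import pattern introduced in this PR (assumptions noted below).
import importlib.util

from torch import nn

FlashMHA = None
if importlib.util.find_spec('flash_attn'):
    FlashMHA = importlib.import_module('flash_attn.flash_attention').FlashMHA


def build_attention(embed_dim: int, num_heads: int, use_flash_attention: bool) -> nn.Module:
    """Return a FlashAttention MHA block when available, else a standard one.

    The fallback chosen here is an assumption for illustration; Chinese-CLIP's
    own modules decide between their existing attention code and FlashMHA.
    """
    if use_flash_attention and FlashMHA is not None:
        # flash_attn's legacy module exposed FlashMHA(embed_dim, num_heads, ...).
        return FlashMHA(embed_dim, num_heads)
    return nn.MultiheadAttention(embed_dim, num_heads)
```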
5 changes: 4 additions & 1 deletion cn_clip/clip/modeling_bert.py
@@ -27,7 +27,10 @@
 import torch
 from torch import nn
 from torch.utils.checkpoint import checkpoint
-from flash_attn.flash_attention import FlashMHA
+
+import importlib.util
+if importlib.util.find_spec('flash_attn'):
+    FlashMHA = importlib.import_module('flash_attn.flash_attention').FlashMHA
 
 from .configuration_bert import BertConfig
 
9 changes: 6 additions & 3 deletions cn_clip/clip/utils.py
@@ -82,24 +82,27 @@ def available_models() -> List[str]:


 def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
-                   download_root: str = None):
+                   download_root: str = None, vision_model_name: str = None, text_model_name: str = None, input_resolution: int = None):
     if name in _MODELS:
         model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+        model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
     elif os.path.isfile(name):
+        assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
         model_path = name
+        model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
     else:
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

     with open(model_path, 'rb') as opened_file:
         # loading saved checkpoint
         checkpoint = torch.load(opened_file, map_location="cpu")

-    model = create_model(_MODEL_INFO[name]['struct'], checkpoint)
+    model = create_model(model_name, checkpoint)
     if str(device) == "cpu":
         model.float()
     else:
         model.to(device)
-    return model, image_transform(_MODEL_INFO[name]['input_resolution'])
+    return model, image_transform(model_input_resolution)


 def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
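With this change, `load_from_name` accepts a local checkpoint path in addition to the built-in model names. A hedged usage sketch follows; the checkpoint path is a placeholder, and the architecture names and resolution correspond to the ViT-B-16 configuration and should be adjusted to whatever checkpoint is actually loaded.

```python
import torch
from cn_clip.clip import load_from_name

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Built-in pretrained model, resolved by name as before.
model, preprocess = load_from_name("ViT-B-16", device=device)

# 2) Local finetuned checkpoint: the new keyword arguments tell
#    load_from_name which architecture and input resolution to build.
#    The path below is a placeholder, not a file shipped with this PR.
model, preprocess = load_from_name(
    "checkpoints/epoch_latest.pt",
    device=device,
    vision_model_name="ViT-B-16",
    text_model_name="RoBERTa-wwm-ext-base-chinese",
    input_resolution=224,
)
```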
5 changes: 5 additions & 0 deletions cn_clip/training/main.py
@@ -5,6 +5,7 @@
 import json
 import time
 from time import gmtime, strftime
+import importlib.util

 import torch
 from torch import optim
@@ -111,6 +112,10 @@ def main():
         model.set_grad_checkpointing()
         logging.info("Grad-checkpointing activated.")

+    if args.use_flash_attention:
+        assert importlib.util.find_spec("flash_attn"), "flash_attn is not installed."
+        logging.info("Using FlashAttention.")
+
     if args.use_bn_sync:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
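The check above relies on an `args.use_flash_attention` flag supplied by the training argument parser. A minimal sketch of how such a flag can be declared and validated with argparse is shown below; the actual definition lives in Chinese-CLIP's params module and may differ in help text and defaults.

```python
import argparse
import importlib.util

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-flash-attention",
    action="store_true",
    default=False,
    help="Enable FlashAttention during training (requires the flash-attn package).",
)

# argparse maps the dashed flag to args.use_flash_attention.
args = parser.parse_args(["--use-flash-attention"])

if args.use_flash_attention:
    # Same guard as in main(): fail early if flash_attn cannot be imported.
    assert importlib.util.find_spec("flash_attn"), "flash_attn is not installed."
```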
140 changes: 70 additions & 70 deletions flash_attention_En.md
@@ -1,70 +1,70 @@
[**中文说明**](flash_attention.md) | [**English**](flash_attention_En.md)

# Accelerate Chinese-CLIP with FlashAttention

Chinese-CLIP now supports accelerating training with [FlashAttention](https://github.com/HazyResearch/flash-attention).

## Environmental Preparation

+ Nvidia GPUs **with Turing or Ampere architecture** (such as A100, RTX 3090, T4, and RTX 2080). Please refer to [this document](https://en.wikipedia.org/wiki/CUDA#GPUs_supported) for the GPUs covered by each Nvidia architecture.
+ CUDA 11, NVCC
+ **FlashAttention**: Install FlashAttention by running `pip install flash-attn`. Please refer to the [FlashAttention project repository](https://github.com/HazyResearch/flash-attention) for details.

## Use it in Chinese-CLIP!

Applying FlashAttention to Chinese-CLIP finetuning is simple: just add `--use-flash-attention` to the finetuning shell script. We provide the sample script `run_scripts/muge_finetune_vit-b-16_rbt-base_flashattn.sh`.


## Training Speed and Memory Usage Comparison

Enabling FlashAttention significantly speeds up Chinese-CLIP finetuning and reduces memory usage without affecting precision. Our experiments were conducted on a machine with 8 A100 GPUs (80GB memory each).

We compare the batch time and memory usage of FP16 finetuning for models of each scale. The improvement in training speed and the reduction in memory usage are more pronounced for larger models.

<table border="1" width="120%">
<tr align="center">
<th></th><th colspan="4">Batch Time</th>
</tr>
<tr align="center">
<th>Unit: s/it</th><th>Batch size</th><th>w/o FlashAttention</th><th>w/ FlashAttention</th><th>Speedup</th>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>1200*8</td><td>1.710</td><td>1.680</td><td>1.02×</td>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>400*8</td><td>1.477</td><td>0.960</td><td>1.54×</td>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>128*8</td><td>1.293</td><td>0.785</td><td>1.65×</td>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>ViT-L/14@336px</sub></td><td>40*8</td><td>1.397</td><td>0.587</td><td>2.38×</td>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>ViT-H/14</sub></td><td>64*8</td><td>1.265</td><td>0.845</td><td>1.50×</td>
</tr>
</table>
<br>

<table border="1" width="120%">
<tr align="center">
<th></th><th colspan="3">Memory</th>
</tr>
<tr align="center">
<th>Unit: GB</th><th>Batch size</th><th>w/o FlashAttention</th><th>w/ FlashAttention</th>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>RN50</sub></td><td>1200*8</td><td>79</td><td>75</td>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>400*8</td><td>80</td><td>56</td>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>128*8</td><td>77</td><td>50</td>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>ViT-L/14@336px</sub></td><td>40*8</td><td>78</td><td>37</td>
</tr>
<tr align="center">
<td width="120%">CN-CLIP<sub>ViT-H/14</sub></td><td>64*8</td><td>76</td><td>57</td>
</tr>
</table>
<br>