Merge pull request OFA-Sys#55 from DtYXs/master
Support FlashAttention & Add COCO-CN script
yangapku authored Feb 20, 2023
2 parents 2c586c5 + 8b701a6 commit 006a981
Showing 16 changed files with 461 additions and 36 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -16,6 +16,7 @@
<br><br>

# News
* 2023.2.16 Added [FlashAttention](https://github.com/HazyResearch/flash-attention) support to speed up training and reduce GPU memory usage. See [flash_attention.md](flash_attention.md) for details.
* 2023.1.15 Added support for deploying [ONNX](https://onnx.ai/) and [TensorRT](https://developer.nvidia.com/tensorrt) models (pretrained TensorRT models are provided) to speed up feature inference and meet deployment needs. See [deployment.md](deployment.md) for details.
* 2022.12.12 Implemented the [FLIP](https://arxiv.org/abs/2212.00794) training strategy, which can be [activated](#FLIP) during finetuning (thanks to [@zwkkk](https://github.com/zwkkk) for the [contribution](https://github.com/OFA-Sys/Chinese-CLIP/pull/26) ❤️)
* 2022.12.3 Released the Chinese version of the [ELEVATER](https://eval.ai/web/challenges/challenge-page/1832) image classification datasets. See the [dataset documentation](https://github.com/OFA-Sys/Chinese-CLIP/blob/master/zeroshot_dataset.md) for details.
@@ -309,7 +310,7 @@ ${DATAPATH}
└── test
```

To lower the barrier to entry, we also provide zipped MUGE ([download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip)) and Flickr30K-CN ([download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/Flickr30k-CN.zip)) archives preprocessed according to the steps above; simply download, extract, and place them under `${DATAPATH}/datasets/`.
To lower the barrier to entry, we also provide zipped MUGE ([download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip)) and Flickr30K-CN ([download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/Flickr30k-CN.zip)) archives preprocessed according to the steps above; simply download, extract, and place them under `${DATAPATH}/datasets/`. If you need the [COCO-CN](https://github.com/li-xirong/coco-cn) data, please apply to the original authors for permission first and then contact us by email.

### Model finetuning

4 changes: 3 additions & 1 deletion README_En.md
@@ -16,6 +16,8 @@ This is the Chinese version of CLIP. We use a large-scale Chinese image-text pai
<br><br>

# News
* 2023.2.16 Support [FlashAttention](https://github.com/HazyResearch/flash-attention) to improve training speed and reduce memory usage. See [flash_attention_En.md](flash_attention_En.md) for more information.
* 2023.1.15 Support the conversion of PyTorch models into [ONNX](https://onnx.ai/) or [TensorRT](https://developer.nvidia.com/tensorrt) formats (and provide pretrained TensorRT models) to improve inference speed and meet deployment requirements. See [deployment_En.md](deployment_En.md) for more information.
* 2022.12.12 Implement [FLIP](https://arxiv.org/abs/2212.00794) strategy, which can be [activated](#FLIP) during finetuning (Thanks [@zwkkk](https://github.com/zwkkk) for [the PR](https://github.com/OFA-Sys/Chinese-CLIP/pull/26) ❤️)
* 2022.12.3 The datasets of the Chinese version of the [ELEVATER](https://eval.ai/web/challenges/challenge-page/1832) benchmark are publicly available. See [Notes for datasets](zeroshot_dataset_en.md) for more information.
* 2022.12.1 The Chinese-CLIP model & representation generation API have been officially merged into the Hugging Face transformers 🤗 codebase.
@@ -309,7 +311,7 @@ ${DATAPATH}
└── test
```

For easier use, we have provided the preprocessed MUGE ([download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip)) and Flickr30K-CN ([download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/Flickr30k-CN.zip)) datasets in zip format. To use them, just download and unzip them under `${DATAPATH}/datasets/`.
For easier use, we have provided the preprocessed MUGE ([download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip)) and Flickr30K-CN ([download link](https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/Flickr30k-CN.zip)) datasets in zip format. To use them, just download and unzip them under `${DATAPATH}/datasets/`. If you need the [COCO-CN](https://github.com/li-xirong/coco-cn) dataset, please contact us by email once you have obtained permission from the original authors.

### Finetuning

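The dataset note above amounts to a download-and-extract step. Below is a minimal Python sketch of that step for the MUGE archive, assuming `DATAPATH` points at the data root you chose for the layout shown earlier; the URL comes from the README, and the Flickr30K-CN archive works the same way.

```python
import os
import urllib.request
import zipfile

# Minimal sketch of "download and unzip under ${DATAPATH}/datasets/".
# DATAPATH is an assumption: point it at your own data root.
DATAPATH = os.environ.get("DATAPATH", "./data")
datasets_dir = os.path.join(DATAPATH, "datasets")
os.makedirs(datasets_dir, exist_ok=True)

url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip"
zip_path = os.path.join(datasets_dir, "MUGE.zip")
urllib.request.urlretrieve(url, zip_path)   # the archive is large; this can take a while

with zipfile.ZipFile(zip_path) as zf:
    zf.extractall(datasets_dir)             # expected to yield ${DATAPATH}/datasets/MUGE/ per the layout above
```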
2 changes: 1 addition & 1 deletion cn_clip/clip/__init__.py
@@ -1,5 +1,5 @@
from .bert_tokenizer import FullTokenizer

_tokenizer = FullTokenizer()
from .model import convert_state_dict
from .utils import load_from_name, available_models, tokenize, image_transform, load

6 changes: 4 additions & 2 deletions cn_clip/clip/configuration_bert.py
@@ -66,7 +66,8 @@ def __init__(self,
initializer_range=0.02,
layer_norm_eps=1e-12,
output_attentions=False,
output_hidden_states=False
output_hidden_states=False,
use_flash_attention=False
):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
@@ -81,4 +82,5 @@ def __init__(self,
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.output_hidden_states = output_hidden_states
self.use_flash_attention = use_flash_attention
89 changes: 80 additions & 9 deletions cn_clip/clip/model.py
@@ -10,6 +10,7 @@
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from flash_attn.flash_attention import FlashMHA

from cn_clip.clip import _tokenizer
from cn_clip.clip.configuration_bert import BertConfig
@@ -179,10 +180,10 @@ def forward(self, x: torch.Tensor):


class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_flash_attention: bool = False):
super().__init__()

self.attn = nn.MultiheadAttention(d_model, n_head)
self.attn = nn.MultiheadAttention(d_model, n_head) if not use_flash_attention else FlashMHA(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
@@ -191,10 +192,15 @@ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.use_flash_attention = use_flash_attention

def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
if self.use_flash_attention:
# Batch first is needed for FlashAttention. See https://github.com/HazyResearch/flash-attention/issues/84 for more information.
return self.attn(x.transpose(1, 0))[0].transpose(1, 0)
else:
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
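The transpose pair in the FlashAttention branch above only bridges layout conventions: the CLIP transformer keeps activations sequence-first `(L, N, D)`, while `FlashMHA` expects batch-first `(N, L, D)`. Here is a small shape sketch of that round-trip, using `nn.MultiheadAttention(batch_first=True)` as a CPU stand-in for `FlashMHA` (which is a fused self-attention module requiring a CUDA flash-attn build and half precision); the sizes are made up for illustration.

```python
import torch
from torch import nn

# Mirror of ResidualAttentionBlock.attention() in the FlashAttention branch:
# transpose to batch-first, run attention, transpose back.
seq_len, batch, d_model, n_head = 50, 4, 64, 8
x = torch.randn(seq_len, batch, d_model)                          # (L, N, D), sequence-first

attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)   # stand-in for FlashMHA(d_model, n_head)
xb = x.transpose(1, 0)                                            # (N, L, D), batch-first
out = attn(xb, xb, xb, need_weights=False)[0]                     # FlashMHA takes a single input instead
out = out.transpose(1, 0)                                          # back to (L, N, D)
print(out.shape)                                                   # torch.Size([50, 4, 64])
```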
@@ -203,12 +209,12 @@ def forward(self, x: torch.Tensor):


class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_flash_attention: bool = False):
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = False
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_flash_attention) for _ in range(layers)])

def forward(self, x: torch.Tensor):
if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -219,7 +225,7 @@ def forward(self, x: torch.Tensor):


class VisualTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, use_flash_attention: bool = False):
super().__init__()
self.input_resolution = input_resolution
self.grid_size = (self.input_resolution // patch_size, self.input_resolution // patch_size)
@@ -231,7 +237,7 @@ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: i
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)

self.transformer = Transformer(width, layers, heads)
self.transformer = Transformer(width, layers, heads, use_flash_attention=use_flash_attention)

self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
@@ -301,6 +307,7 @@ def __init__(self,
tokenizer = _tokenizer,
# vision head width, added this param for ViT-H
vision_head_width: int = 64,
use_flash_attention: bool = False,
):
super().__init__()

@@ -321,7 +328,8 @@ def __init__(self,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
output_dim=embed_dim,
use_flash_attention=use_flash_attention
)

self.bert_config = BertConfig(
@@ -337,6 +345,7 @@ def __init__(self,
type_vocab_size=text_type_vocab_size,
initializer_range=text_initializer_range,
layer_norm_eps=1e-12,
use_flash_attention=use_flash_attention
)
self.bert = BertModel(self.bert_config)

@@ -453,7 +462,7 @@ def _convert_weights_to_fp16(l):
model.apply(_convert_weights_to_fp16)


def restore_model(model, clip_state_dict: dict, bert_state_dict: dict):
def restore_model(model, clip_state_dict: dict, bert_state_dict: dict, use_flash_attention: bool):
merged_state_dict = {}

# use clip_state_dict to initialize the image encoder & logit scale
@@ -468,12 +477,74 @@ def restore_model(model, clip_state_dict: dict, bert_state_dict: dict):
if k.startswith("bert") and "bert.pooler" not in k:
merged_state_dict[k] = v

# adapt the checkpoint keys to the FlashAttention module naming if needed
if use_flash_attention:
merged_state_dict = convert_state_dict(merged_state_dict)

convert_weights(model)
resize_pos_embed(merged_state_dict, model)
model.load_state_dict(merged_state_dict, strict=False)
return model.eval()


def convert_state_dict(state_dict):
"""Adapt to Flash Attention"""
if not state_dict:
return state_dict

prefix = 'module.' if list(state_dict.keys())[0].startswith('module') else ''

if f'{prefix}visual.transformer.resblocks.0.attn.in_proj_weight' in state_dict:
for k in list(state_dict.keys()):
if 'attn.in_proj_weight' in k:
state_dict[k.replace('attn.in_proj_weight', 'attn.Wqkv.weight')] = state_dict.pop(k)
elif 'attn.in_proj_bias' in k:
state_dict[k.replace('attn.in_proj_bias', 'attn.Wqkv.bias')] = state_dict.pop(k)
elif f'{prefix}visual.transformer.resblocks.0.attn.Wqkv.weight' in state_dict:
for k in list(state_dict.keys()):
if 'attn.Wqkv.weight' in k:
state_dict[k.replace('attn.Wqkv.weight', 'attn.in_proj_weight')] = state_dict.pop(k)
elif 'attn.Wqkv.bias' in k:
state_dict[k.replace('attn.Wqkv.bias', 'attn.in_proj_bias')] = state_dict.pop(k)

if f'{prefix}bert.encoder.layer.0.attention.self.query.weight' in state_dict:
i = 0
while f'{prefix}bert.encoder.layer.{i}.attention.self.query.weight' in state_dict:
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight'] = torch.cat(
(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.query.weight'),
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.key.weight'),
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.value.weight'))
)
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias'] = torch.cat(
(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.query.bias'),
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.key.bias'),
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.value.bias'))
)
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight'] = \
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.output.dense.weight')
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.bias'] = \
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.output.dense.bias')
i += 1
elif f'{prefix}bert.encoder.layer.0.attention.self.Wqkv.weight' in state_dict:
i = 0
while f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight' in state_dict:
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.query.weight'], \
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.key.weight'], \
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.value.weight'] = \
torch.chunk(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight'), chunks=3)
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.query.bias'], \
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.key.bias'], \
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.self.value.bias'] = \
torch.chunk(state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias'), chunks=3)
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.output.dense.weight'] = \
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight')
state_dict[f'{prefix}bert.encoder.layer.{i}.attention.output.dense.bias'] = \
state_dict.pop(f'{prefix}bert.encoder.layer.{i}.attention.self.out_proj.bias')
i += 1

return state_dict


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1, prefix=""):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get(prefix + 'visual.positional_embedding', None)
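`convert_state_dict` above is bidirectional: it detects which attention naming a checkpoint uses and renames the keys to the other convention, so weights can move between the stock `nn.MultiheadAttention`/`BertSelfAttention` modules and their FlashAttention counterparts. Below is a tiny sketch of the ViT branch on a toy checkpoint (importing `cn_clip.clip.model` requires the `flash-attn` package, since `FlashMHA` is imported at module level; the tensor sizes are made up).

```python
import torch
from cn_clip.clip.model import convert_state_dict

# Toy checkpoint holding a single ViT attention block in the stock
# nn.MultiheadAttention naming ("in_proj_weight" / "in_proj_bias").
width = 8
sd = {
    "visual.transformer.resblocks.0.attn.in_proj_weight": torch.randn(3 * width, width),
    "visual.transformer.resblocks.0.attn.in_proj_bias": torch.randn(3 * width),
}

flash_sd = convert_state_dict(sd)   # renamed in place for the FlashMHA modules
print(sorted(flash_sd))
# ['visual.transformer.resblocks.0.attn.Wqkv.bias',
#  'visual.transformer.resblocks.0.attn.Wqkv.weight']

# Calling it again maps the keys back to the "in_proj_*" naming.
print(sorted(convert_state_dict(flash_sd)))
```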
28 changes: 25 additions & 3 deletions cn_clip/clip/modeling_bert.py
@@ -27,6 +27,7 @@
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from flash_attn.flash_attention import FlashMHA

from .configuration_bert import BertConfig

@@ -165,16 +166,25 @@ def forward(self, hidden_states, input_tensor):
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
self.self = BertSelfAttention(config) if not config.use_flash_attention else FlashMHA(config.hidden_size, config.num_attention_heads)
self.output = BertSelfOutput(config) if not config.use_flash_attention else BertSelfOutputForFlashAttention(config)
self.pruned_heads = set()
self.config = config

def forward(self, input_tensor, attention_mask=None, head_mask=None):
self_outputs = self.self(input_tensor, attention_mask, head_mask)
if not self.config.use_flash_attention:
self_outputs = self.self(input_tensor, attention_mask, head_mask)
else:
key_padding_mask = self.get_key_padding_mask(attention_mask)
self_outputs = self.self(input_tensor, key_padding_mask=key_padding_mask)
attention_output = self.output(self_outputs[0], input_tensor)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs

def get_key_padding_mask(self, attention_mask):
# key_padding_mask: bool tensor of shape (batch, seqlen); True marks valid (non-padding) positions
return attention_mask.squeeze(1).squeeze(1) == 0


class BertIntermediate(nn.Module):
def __init__(self, config):
@@ -205,6 +215,18 @@ def forward(self, hidden_states, input_tensor):
return hidden_states


class BertSelfOutputForFlashAttention(nn.Module):  # BertSelfOutput without the dense layer, since FlashMHA already applies the output projection
def __init__(self, config):
super(BertSelfOutputForFlashAttention, self).__init__()
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states


class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
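`get_key_padding_mask` above undoes BERT's "extended" attention mask: `BertModel` broadcasts the `(batch, seqlen)` 0/1 mask to `(batch, 1, 1, seqlen)` with `0.0` at valid positions and a large negative value at padding, while `FlashMHA` expects a boolean `(batch, seqlen)` mask that is `True` at the tokens to attend to. A standalone sketch of that conversion, assuming the usual BERT extended-mask convention:

```python
import torch

# 1 = real token, 0 = padding, for a batch of two sequences.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.float)

# BertModel-style "extended" mask: (batch, 1, 1, seqlen), 0.0 keep / -10000.0 mask.
extended_mask = (1.0 - attention_mask[:, None, None, :]) * -10000.0

# What get_key_padding_mask does: squeeze back to (batch, seqlen) booleans.
key_padding_mask = extended_mask.squeeze(1).squeeze(1) == 0
print(key_padding_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```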
6 changes: 2 additions & 4 deletions cn_clip/clip/utils.py
@@ -103,14 +103,14 @@ def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch


def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
bert_path=None):
bert_path=None, use_flash_attention=False):
"""Load CLIP and BERT model weights
"""

bert_state_dict = torch.load(bert_path, map_location="cpu") if bert_path else None
clip_state_dict = torch.load(clip_path, map_location="cpu") if clip_path else None

restore_model(model, clip_state_dict, bert_state_dict).to(device)
restore_model(model, clip_state_dict, bert_state_dict, use_flash_attention).to(device)

if str(device) == "cpu":
model.float()
@@ -189,5 +189,3 @@ def create_model(model_name, checkpoint=None):
sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
model.load_state_dict(sd)
return model

