add Qwen-vl sft (#308)
LokeZhou authored Nov 22, 2023
1 parent c3a1d42 commit f854c4a
Showing 2 changed files with 364 additions and 0 deletions.
156 changes: 156 additions & 0 deletions paddlemix/examples/qwen_vl/README.md
@@ -39,3 +39,159 @@ python paddlemix/examples/qwen_vl/run_predict.py \
```bash
python paddlemix/examples/qwen_vl/chat_demo.py
```

## 2.3 Stage-3 Fine-tuning
We provide the `finetune.py` script for stage-3 fine-tuning of the model.
### 2.3.1 Data Preparation
Put your samples into a list and save it as a JSON file, as in the example below, or see [sft_examples](https://bj.bcebos.com/v1/paddlenlp/models/community/qwen-vl/sft_examples.json).
```json
[
{
"id": "identity_0",
"conversations": [
{
"from": "user",
"value": "你好"
},
{
"from": "assistant",
"value": "我是Qwen-VL,一个支持视觉输入的大模型。"
}
]
},
{
"id": "identity_1",
"conversations": [
{
"from": "user",
"value": "Picture 1: <img>https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg\n图中的巴士是什么颜色的?"
},
{
"from": "assistant",
"value": "红色的。"
},
{
"from": "user",
"value": "框出图中的巴士的位置"
},
{
"from": "assistant",
"value": "<ref>巴士</ref><box>(178,279),(806,884)</box>"
}
]
},
{
"id": "identity_2",
"conversations": [
{
"from": "user",
"value": "Picture 1: <img>Chongqing.jpeg</img>\nPicture 2: <img>Beijing.jpeg</img>\n图中都是哪"
},
{
"from": "assistant",
"value": "第一张图片是重庆的城市天际线,第二张图片是北京的天际线。"
}
]
}
]
```

Content with image input can be written as `Picture id: <img>img_path</img>\n{your prompt}`, where `id` is the index of the image within the conversation. `img_path` can be a local image file or a web URL.

Detection boxes in a conversation can be written as `<box>(x1,y1),(x2,y2)</box>`, where `(x1, y1)` and `(x2, y2)` are the top-left and bottom-right coordinates, normalized to the range `[0, 1000)`. The text caption a box refers to can be marked with `<ref>text_caption</ref>`.
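
For illustration, a minimal sketch of the coordinate normalization (the helper name and the example image size are assumptions for this sketch, not part of the training script):

```python
def to_qwen_box(x1, y1, x2, y2, width, height):
    """Normalize pixel coordinates to Qwen-VL's [0, 1000) box format."""
    nx1, ny1 = int(x1 / width * 1000), int(y1 / height * 1000)
    nx2, ny2 = int(x2 / width * 1000), int(y2 / height * 1000)
    return f"<box>({nx1},{ny1}),({nx2},{ny2})</box>"

# A box at pixels (228, 201)-(1032, 637) on a hypothetical 1280x720 image:
print(to_qwen_box(228, 201, 1032, 637, 1280, 720))  # <box>(178,279),(806,884)</box>
```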

### 2.3.2 Training
Training uses the `paddlemix/examples/qwen_vl/finetune.py` script. **Before training, double-check the dataset paths; if URLs are used, make sure the network is reachable.** Training on A100 GPUs is recommended.

Example training command and arguments:
```bash
MODEL_NAME="qwen-vl/qwen-vl-chat-7b"
MASTER='127.0.0.1:8080'
DATA="train.json"
python -m paddle.distributed.launch --master ${MASTER} --nnodes 1 --nproc_per_node 8 \
paddlemix/examples/qwen_vl/finetune.py \
--model_name_or_path ${MODEL_NAME} \
--data_path ${DATA} \
--dtype 'bfloat16' \
--fix_vit True \
--output_dir output_qwen_vl \
--num_train_epochs 5 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--save_steps 1000 \
--save_strategy "steps" \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 2048 \
--lazy_preprocess True \
--sharding "stage2" \
--tensor_parallel_degree 1 \
--sharding_parallel_degree 8 \
--pipeline_parallel_degree 1
```


```
# Argument descriptions
--model_name_or_path           # model to load; defaults to 'qwen-vl/qwen-vl-chat-7b'
--data_path                    # path to the training-data JSON file
--dtype                        # data type; defaults to 'bfloat16'
--fix_vit                      # whether to freeze the visual ViT parameters during training; defaults to True
--output_dir                   # directory in which to save the model
--num_train_epochs             # number of training epochs
--per_device_train_batch_size  # per-device training batch size
--gradient_accumulation_steps  # number of steps over which to accumulate gradients before a backward/update pass;
                               # defaults to 16, i.e. parameters are updated once every 16 steps
--save_strategy                # checkpoint-saving strategy. Options:
                               #   "no": never save during training
                               #   "epoch": save after every epoch
                               #   "steps": save every `save_steps` steps
--save_steps                   # save a checkpoint every N steps
--save_total_limit             # maximum number of checkpoints to keep
--learning_rate                # learning rate
--adam_beta2                   # beta2 parameter of the optimizer
--warmup_ratio                 # fraction of steps used for learning-rate warmup
--weight_decay                 # weight decay
--lr_scheduler_type            # learning-rate decay schedule; cosine or linear
--logging_steps                # logging interval in steps
--report_to                    # logging integration; 'none' disables it, 'visualdl' logs to VisualDL
--model_max_length             # maximum sequence length (2048 in the example above)
--lazy_preprocess              # lazy data loading
--tensor_parallel_degree       # model-parallel degree; set to N to shard the model across N GPUs. Optional.
--sharding_parallel_degree     # degree of the sharding memory-optimization strategy; see [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054). Optional.
--sharding                     # sharding stage; stage1 and stage2 are currently supported. Optional.
--pipeline_parallel_degree     # pipeline-parallel degree; see the [PaddleNLP LLM toolchain](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/README.md). Optional.
```

> Note: if no sharding strategy is needed, the `tensor_parallel_degree`, `sharding_parallel_degree`, `sharding`, and `pipeline_parallel_degree` arguments can be omitted.
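
For a quick single-GPU run, a minimal sketch of the command (the same script and arguments as above with the parallelism flags omitted, per the note; this exact configuration is an assumption, not separately documented):

```bash
python paddlemix/examples/qwen_vl/finetune.py \
    --model_name_or_path qwen-vl/qwen-vl-chat-7b \
    --data_path train.json \
    --dtype 'bfloat16' \
    --fix_vit True \
    --output_dir output_qwen_vl \
    --num_train_epochs 5 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --learning_rate 1e-5 \
    --model_max_length 2048 \
    --lazy_preprocess True
```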
208 changes: 208 additions & 0 deletions paddlemix/examples/qwen_vl/finetune.py
@@ -0,0 +1,208 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from dataclasses import dataclass, field
from typing import Dict, Optional

import paddle
from paddlenlp.trainer import PdArgumentParser, Trainer, TrainingArguments
from paddlenlp.transformers import PretrainedTokenizer
from paddlenlp.transformers.qwen.configuration import QWenConfig

from paddlemix import QWenLMHeadModel, QWenTokenizer
from paddlemix.utils.log import logger

IGNORE_TOKEN_ID = -100


@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="qwen-vl/qwen-vl-chat-7b")
dtype: str = "bfloat16"


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})
    lazy_preprocess: bool = False


@dataclass
class PreTrainingArguments(TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw")
model_max_length: int = field(
default=8192,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
use_lora: bool = False
fix_vit: bool = True


def preprocess(
sources, tokenizer: PretrainedTokenizer, max_len: int, system_message: str = "You are a helpful assistant."
) -> Dict:
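    # Render each conversation in Qwen's ChatML format:
    #   <|im_start|>{role}\n{content}<|im_end|>\n
    # Only assistant replies contribute to the loss; other content positions
    # are masked with IGNORE_TOKEN_ID.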
roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
im_start = tokenizer.im_start_id
im_end = tokenizer.im_end_id
nl_tokens = tokenizer("\n").input_ids
_system = tokenizer("system").input_ids + nl_tokens

input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["user"]:
source = source[1:]
input_id, target = [], []
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
input_id += system
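        # Mask the system prompt content in the target; the -3 offsets the
        # <|im_start|>, <|im_end|>, and newline tokens, which are kept as labels.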
target += [im_start] + [IGNORE_TOKEN_ID] * (len(system) - 3) + [im_end] + nl_tokens
assert len(input_id) == len(target)
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
_input_id = (
tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
)
input_id += _input_id
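            # For user turns only the ChatML delimiters survive as labels; for
            # assistant turns the reply content and <|im_end|> are supervised
            # while the role header is masked.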
if role == "<|im_start|>user":
_target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id) - 3) + [im_end] + nl_tokens
elif role == "<|im_start|>assistant":
_target = (
[im_start]
+ [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids)
+ _input_id[len(tokenizer(role).input_ids) + 1 : -2]
+ [im_end]
+ nl_tokens
)
else:
raise NotImplementedError
target += _target
assert len(input_id) == len(target)
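        # Right-pad to max_len; padded target positions are ignored in the loss.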
input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))
target += [IGNORE_TOKEN_ID] * (max_len - len(target))
input_ids.append(input_id[:max_len])
targets.append(target[:max_len])
input_ids = paddle.to_tensor(data=input_ids, dtype="int32")
targets = paddle.to_tensor(data=targets, dtype="int32")
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.not_equal(y=paddle.to_tensor(tokenizer.pad_token_id, dtype="int32")),
)


class SupervisedDataset(paddle.io.Dataset):
"""Dataset for supervised fine-tuning."""

def __init__(self, raw_data, tokenizer: PretrainedTokenizer, max_len: int):
super(SupervisedDataset, self).__init__()

sources = [example["conversations"] for example in raw_data]
data_dict = preprocess(sources, tokenizer, max_len)
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
self.attention_mask = data_dict["attention_mask"]

def __len__(self):
return len(self.input_ids)

def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i], attention_mask=self.attention_mask[i])


class LazySupervisedDataset(paddle.io.Dataset):
"""Dataset for supervised fine-tuning."""

def __init__(self, raw_data, tokenizer: PretrainedTokenizer, max_len: int):
super(LazySupervisedDataset, self).__init__()
self.tokenizer = tokenizer
self.max_len = max_len

self.raw_data = raw_data
self.cached_data_dict = {}

def __len__(self):
return len(self.raw_data)

def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
if i in self.cached_data_dict:
return self.cached_data_dict[i]
ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len)
ret = dict(input_ids=ret["input_ids"][0], labels=ret["labels"][0], attention_mask=ret["attention_mask"][0])
self.cached_data_dict[i] = ret
return ret


def make_supervised_data_module(tokenizer: PretrainedTokenizer, data_args, max_len) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset

    with open(data_args.data_path, "r") as f:
        train_json = json.load(f)
    train_dataset = dataset_cls(train_json, tokenizer=tokenizer, max_len=max_len)

    if data_args.eval_data_path:
        with open(data_args.eval_data_path, "r") as f:
            eval_json = json.load(f)
        eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=max_len)
    else:
        eval_dataset = None

return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)


def train():
global local_rank
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

local_rank = training_args.local_rank

if model_args.dtype == "bfloat16" and not paddle.amp.is_bfloat16_supported():
        logger.warning("bfloat16 is not supported on your device, falling back to float32")
model_args.dtype = "float32"

config = QWenConfig.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir)
config.use_cache = False
config.dtype = model_args.dtype

model = QWenLMHeadModel.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
)

if training_args.fix_vit and hasattr(model, "visual"):
model.freeze_vit()

tokenizer = QWenTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
)
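    # Qwen's tokenizer defines no dedicated pad token, so the end-of-document
    # token is reused for padding.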
tokenizer.pad_token_id = tokenizer.eod_id

data_module = make_supervised_data_module(
tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length
)

trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

trainer.train()
trainer.save_state()
trainer.save_model()


if __name__ == "__main__":
train()
