-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
364 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
from dataclasses import dataclass, field | ||
from typing import Dict, Optional | ||
|
||
import paddle | ||
from paddlenlp.trainer import PdArgumentParser, Trainer, TrainingArguments | ||
from paddlenlp.transformers import PretrainedTokenizer | ||
from paddlenlp.transformers.qwen.configuration import QWenConfig | ||
|
||
from paddlemix import QWenLMHeadModel, QWenTokenizer | ||
from paddlemix.utils.log import logger | ||
|
||
IGNORE_TOKEN_ID = -100 | ||
|
||
|
||
@dataclass | ||
class ModelArguments: | ||
model_name_or_path: Optional[str] = field(default="qwen-vl/qwen-vl-chat-7b") | ||
dtype: str = "bfloat16" | ||
|
||
|
||
@dataclass | ||
class DataArguments: | ||
data_path: str = field(default=None, metadata={"help": "Path to the training data."}) | ||
lazy_preprocess: bool = False | ||
|
||
|
||
@dataclass | ||
class PreTrainingArguments(TrainingArguments): | ||
cache_dir: Optional[str] = field(default=None) | ||
optim: str = field(default="adamw") | ||
model_max_length: int = field( | ||
default=8192, | ||
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, | ||
) | ||
use_lora: bool = False | ||
fix_vit: bool = True | ||
|
||
|
||
def preprocess( | ||
sources, tokenizer: PretrainedTokenizer, max_len: int, system_message: str = "You are a helpful assistant." | ||
) -> Dict: | ||
roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"} | ||
im_start = tokenizer.im_start_id | ||
im_end = tokenizer.im_end_id | ||
nl_tokens = tokenizer("\n").input_ids | ||
_system = tokenizer("system").input_ids + nl_tokens | ||
|
||
input_ids, targets = [], [] | ||
for i, source in enumerate(sources): | ||
if roles[source[0]["from"]] != roles["user"]: | ||
source = source[1:] | ||
input_id, target = [], [] | ||
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens | ||
input_id += system | ||
target += [im_start] + [IGNORE_TOKEN_ID] * (len(system) - 3) + [im_end] + nl_tokens | ||
assert len(input_id) == len(target) | ||
for j, sentence in enumerate(source): | ||
role = roles[sentence["from"]] | ||
_input_id = ( | ||
tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens | ||
) | ||
input_id += _input_id | ||
if role == "<|im_start|>user": | ||
_target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id) - 3) + [im_end] + nl_tokens | ||
elif role == "<|im_start|>assistant": | ||
_target = ( | ||
[im_start] | ||
+ [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) | ||
+ _input_id[len(tokenizer(role).input_ids) + 1 : -2] | ||
+ [im_end] | ||
+ nl_tokens | ||
) | ||
else: | ||
raise NotImplementedError | ||
target += _target | ||
assert len(input_id) == len(target) | ||
input_id += [tokenizer.pad_token_id] * (max_len - len(input_id)) | ||
target += [IGNORE_TOKEN_ID] * (max_len - len(target)) | ||
input_ids.append(input_id[:max_len]) | ||
targets.append(target[:max_len]) | ||
input_ids = paddle.to_tensor(data=input_ids, dtype="int32") | ||
targets = paddle.to_tensor(data=targets, dtype="int32") | ||
return dict( | ||
input_ids=input_ids, | ||
labels=targets, | ||
attention_mask=input_ids.not_equal(y=paddle.to_tensor(tokenizer.pad_token_id, dtype="int32")), | ||
) | ||
|
||
|
||
class SupervisedDataset(paddle.io.Dataset): | ||
"""Dataset for supervised fine-tuning.""" | ||
|
||
def __init__(self, raw_data, tokenizer: PretrainedTokenizer, max_len: int): | ||
super(SupervisedDataset, self).__init__() | ||
|
||
sources = [example["conversations"] for example in raw_data] | ||
data_dict = preprocess(sources, tokenizer, max_len) | ||
self.input_ids = data_dict["input_ids"] | ||
self.labels = data_dict["labels"] | ||
self.attention_mask = data_dict["attention_mask"] | ||
|
||
def __len__(self): | ||
return len(self.input_ids) | ||
|
||
def __getitem__(self, i) -> Dict[str, paddle.Tensor]: | ||
return dict(input_ids=self.input_ids[i], labels=self.labels[i], attention_mask=self.attention_mask[i]) | ||
|
||
|
||
class LazySupervisedDataset(paddle.io.Dataset): | ||
"""Dataset for supervised fine-tuning.""" | ||
|
||
def __init__(self, raw_data, tokenizer: PretrainedTokenizer, max_len: int): | ||
super(LazySupervisedDataset, self).__init__() | ||
self.tokenizer = tokenizer | ||
self.max_len = max_len | ||
|
||
self.tokenizer = tokenizer | ||
self.raw_data = raw_data | ||
self.cached_data_dict = {} | ||
|
||
def __len__(self): | ||
return len(self.raw_data) | ||
|
||
def __getitem__(self, i) -> Dict[str, paddle.Tensor]: | ||
if i in self.cached_data_dict: | ||
return self.cached_data_dict[i] | ||
ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len) | ||
ret = dict(input_ids=ret["input_ids"][0], labels=ret["labels"][0], attention_mask=ret["attention_mask"][0]) | ||
self.cached_data_dict[i] = ret | ||
return ret | ||
|
||
|
||
def make_supervised_data_module(tokenizer: PretrainedTokenizer, data_args, max_len) -> Dict: | ||
"""Make dataset and collator for supervised fine-tuning.""" | ||
dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset | ||
|
||
train_json = json.load(open(data_args.data_path, "r")) | ||
train_dataset = dataset_cls(train_json, tokenizer=tokenizer, max_len=max_len) | ||
|
||
if data_args.eval_data_path: | ||
eval_json = json.load(open(data_args.eval_data_path, "r")) | ||
eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=max_len) | ||
else: | ||
eval_dataset = None | ||
|
||
return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) | ||
|
||
|
||
def train(): | ||
global local_rank | ||
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) | ||
model_args, data_args, training_args = parser.parse_args_into_dataclasses() | ||
|
||
local_rank = training_args.local_rank | ||
|
||
if model_args.dtype == "bfloat16" and not paddle.amp.is_bfloat16_supported(): | ||
logger.warning("bfloat16 is not supported on your device,change to float32") | ||
model_args.dtype = "float32" | ||
|
||
config = QWenConfig.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir) | ||
config.use_cache = False | ||
config.dtype = model_args.dtype | ||
|
||
model = QWenLMHeadModel.from_pretrained( | ||
model_args.model_name_or_path, | ||
config=config, | ||
cache_dir=training_args.cache_dir, | ||
) | ||
|
||
if training_args.fix_vit and hasattr(model, "visual"): | ||
model.freeze_vit() | ||
|
||
tokenizer = QWenTokenizer.from_pretrained( | ||
model_args.model_name_or_path, | ||
cache_dir=training_args.cache_dir, | ||
model_max_length=training_args.model_max_length, | ||
padding_side="right", | ||
) | ||
tokenizer.pad_token_id = tokenizer.eod_id | ||
|
||
data_module = make_supervised_data_module( | ||
tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length | ||
) | ||
|
||
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) | ||
|
||
trainer.train() | ||
trainer.save_state() | ||
trainer.save_model() | ||
|
||
|
||
if __name__ == "__main__": | ||
train() |