# cogvlm_model.py (forked from THUDM/CogVLM)
from sat.model.official.llama_model import LLaMAModel
import json
import torch
from sat.model.base_model import BaseMixin
import torch.nn as nn
from models.mixin import LlamaVisionExpertFCMixin, LlamaVisionExpertAttnMixin
from sat.resources.urls import MODEL_URLS
MODEL_URLS["cogvlm-base-224"] = "r2://cogvlm-base-224.zip"
MODEL_URLS["cogvlm-base-490"] = "r2://cogvlm-base-224.zip"
MODEL_URLS["cogvlm-chat"] = "r2://cogvlm-chat.zip"
MODEL_URLS["cogvlm-grounding-base"] = "r2://cogvlm-grounding-base.zip"
MODEL_URLS["cogvlm-grounding-generalist"] = "r2://cogvlm-grounding-generalist.zip"
class GLU(nn.Module):
    def __init__(self, args, in_features):
        super().__init__()
        self.linear_proj = nn.Linear(in_features, args.hidden_size, bias=False)
        self.norm1 = nn.LayerNorm(args.hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False)
        self.gate_proj = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False)
        self.dense_4h_to_h = nn.Linear(args.inner_hidden_size, args.hidden_size, bias=False)

    def forward(self, x):
        x = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
        x = self.dense_4h_to_h(x)
        return x

from models.eva_clip_model import EVA2CLIPModel
import argparse
from copy import deepcopy
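
# Build the argparse.Namespace used to construct the ViT: copy the distributed,
# dtype, and device settings (and, for training, activation-checkpointing
# options) from the language-model args, then merge in the ViT-specific fields
# passed via `b` (i.e. args.eva_args).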
def override_dist_dtype_device_args(args, b={}):
    if args.mode == 'inference':
        minimal_args = argparse.Namespace(
            world_size=args.world_size,
            rank=args.rank,
            local_rank=args.local_rank,
            skip_init=args.skip_init,
            use_gpu_initialization=args.use_gpu_initialization,
            deepspeed=args.deepspeed,
            bf16=args.bf16,
            fp16=args.fp16,
            mode=args.mode,
            device=args.device
        )
    else:
        minimal_args = argparse.Namespace(
            world_size=args.world_size,
            rank=args.rank,
            local_rank=args.local_rank,
            skip_init=args.skip_init,
            use_gpu_initialization=args.use_gpu_initialization,
            deepspeed=args.deepspeed,
            bf16=args.bf16,
            fp16=args.fp16,
            mode=args.mode,
            checkpoint_activations=args.checkpoint_activations if not hasattr(args, 'vit_checkpoint_activations') else args.vit_checkpoint_activations,
            checkpoint_num_layers=args.checkpoint_num_layers,
            device=args.device,
            hidden_dropout=0.,
            attention_dropout=0.,
        )
    if hasattr(args, 'model_parallel_size'):
        b['model_parallel_size'] = args.model_parallel_size
    return argparse.Namespace(**deepcopy(b), **vars(minimal_args))
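
# Vision mixin: wraps the EVA2-CLIP ViT and, during prefill, replaces the
# positions marked by `image_embed_mask` in the word embeddings with projected
# image features bracketed by learned <boi>/<eoi> embeddings.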
class ImageMixin(BaseMixin):
    def __init__(self, args):
        super().__init__()
        vit_args = override_dist_dtype_device_args(args, args.eva_args)
        self.vit_model = EVA2CLIPModel(EVA2CLIPModel.get_args(**vars(vit_args)))
        self.in_features = 1792
        self.linear_proj = GLU(args, self.in_features)
        self.image_length = args.image_length
        self.boi = nn.Parameter(torch.zeros(1, 1, args.hidden_size))
        self.eoi = nn.Parameter(torch.zeros(1, 1, args.hidden_size))

    def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
        vision_inputs = {}
        for k in kw_args:
            if k.startswith('vision_') and k != 'vision_expert_mask':
                vision_inputs[k[7:]] = kw_args[k]
        # Plain text embedding during single-token decoding steps or when no
        # vision inputs are provided.
        if input_ids.shape[1] == 1 or not vision_inputs:
            return self.transformer.word_embeddings(input_ids)
        image_emb = self.vit_model(**vision_inputs)[0]
        image_emb = self.linear_proj(image_emb)
        image_embed_mask = kw_args['image_embed_mask']
        word_embedding = self.transformer.word_embeddings(input_ids).clone()
        word_embedding[image_embed_mask.bool()] = torch.cat([self.boi.repeat(len(image_emb), 1, 1), image_emb, self.eoi.repeat(len(image_emb), 1, 1)], dim=1).reshape(-1, image_emb.shape[-1])
        return word_embedding.contiguous()
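
# CogVLM: a LLaMA backbone with the EVA2-CLIP image mixin above, and with the
# default MLP and rotary-attention mixins swapped for vision-expert variants
# that apply separate expert weights to image and text positions.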
class CogVLMModel(LLaMAModel):
    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs)
        self.image_length = args.image_length
        self.add_mixin("eva", ImageMixin(args))
        self.del_mixin("mlp")
        self.add_mixin("mlp", LlamaVisionExpertFCMixin(args.hidden_size, args.inner_hidden_size, args.num_layers, 32))
        self.del_mixin("rotary")
        self.add_mixin("rotary", LlamaVisionExpertAttnMixin(args.hidden_size, args.num_attention_heads, args.num_layers, 32))

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('CogVLM', 'CogVLM Configurations')
        group.add_argument('--image_length', type=int, default=256)
        group.add_argument('--eva_args', type=json.loads, default={})
        return super().add_model_specific_args(parser)

    def forward(self, input_ids, vision_expert_mask, image_embed_mask, **kwargs):
        # Only the prefill pass (sequence length > 1) needs the vision masks;
        # single-token decoding steps fall back to the plain LLaMA forward.
        if input_ids.shape[1] > 1:
            return super().forward(input_ids=input_ids, vision_expert_mask=vision_expert_mask, image_embed_mask=image_embed_mask, **kwargs)
        return super().forward(input_ids=input_ids, **kwargs)
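
# Training-time finetuning wrapper: identical to CogVLMModel except that it
# keeps a reference to `args` and declares the finetuning CLI flags. LoRA and
# P-Tuning mixins are attached externally after the checkpoint is loaded, as
# explained in the note inside the class.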
class FineTuneTrainCogVLMModel(CogVLMModel):
    def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kw_args)
        self.args = args
        # If you want to use model parallelism with an mp_size=1 checkpoint and
        # also use LoRA, you have to add_mixin after loading the model
        # checkpoint (see the sketch below this class).

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('CogVLM-finetune', 'CogVLM finetune Configurations')
        group.add_argument('--pre_seq_len', type=int, default=8)
        group.add_argument('--lora_rank', type=int, default=10)
        group.add_argument('--use_ptuning', action="store_true")
        group.add_argument('--use_lora', action="store_true")
        group.add_argument('--use_qlora', action="store_true")
        group.add_argument('--layer_range', nargs='+', type=int, default=None)
        return super().add_model_specific_args(parser)
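
# A minimal sketch of the "add LoRA after loading" pattern described above. It
# assumes SAT's `from_pretrained` classmethod (which returns the model together
# with its saved args); the checkpoint name and hyper-parameters below are
# placeholders, not values mandated by this file.
#
#   model, model_args = FineTuneTrainCogVLMModel.from_pretrained(
#       "cogvlm-base-224", args)          # load the mp_size=1 checkpoint first
#   model.add_mixin(                      # ...then attach LoRA on top of it
#       "lora",
#       LoraMixin(model_args.num_layers, args.lora_rank, layer_range=args.layer_range),
#       reinit=True)
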
from sat.model.finetune import PTuningV2Mixin
from sat.model.finetune.lora2 import LoraMixin
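
# Inference/evaluation wrapper: rebuilds the finetuned model and attaches the
# P-Tuning v2, LoRA, or QLoRA mixins at construction time according to the
# flags it was finetuned with.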
class FineTuneTestCogVLMModel(CogVLMModel):
    def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kw_args)
        if args.use_ptuning:
            self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len))
        if args.use_lora:
            self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True)
            self.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True)
        elif args.use_qlora:
            self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True)
        self.args = args

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('CogVLM-finetune', 'CogVLM finetune Configurations')
        group.add_argument('--pre_seq_len', type=int, default=8)
        group.add_argument('--lora_rank', type=int, default=10)
        group.add_argument('--use_ptuning', action="store_true")
        group.add_argument('--use_lora', action="store_true")
        group.add_argument('--use_qlora', action="store_true")
        group.add_argument('--layer_range', nargs='+', type=int, default=None)
        return super().add_model_specific_args(parser)