-
Notifications
You must be signed in to change notification settings - Fork 267
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Tony Lee <tonyh.lee@yahoo.com> Co-authored-by: JosselinSomervilleRoberts <josselin.somerville@gmail.com>
- Loading branch information
1 parent
5004acf
commit 481d12e
Showing
17 changed files
with
1,203 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .src.flamingo import Flamingo | ||
from .src.factory import create_model_and_transforms |
Empty file.
147 changes: 147 additions & 0 deletions
147
src/helm/clients/vision_language/open_flamingo/src/factory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
""" | ||
Source: https://github.com/mlfoundations/open_flamingo | ||
""" | ||
|
||
from typing import Optional | ||
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from helm.common.general import handle_module_not_found_error | ||
from .flamingo import Flamingo | ||
from .flamingo_lm import FlamingoLMMixin | ||
from .utils import extend_instance | ||
|
||
|
||
def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: Optional[str] = None,
    freeze_lm_embeddings: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    """
    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
    Appends special tokens to the tokenizer and freezes backbones.

    Args:
        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
        lang_encoder_path (str): path to pretrained language encoder
        tokenizer_path (str): path to pretrained tokenizer
        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer.
            Defaults to 1.
        use_local_files (bool, optional): whether to use local files. Defaults to False.
        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. When None (the
            default), it is inferred from the language model's class name.
        freeze_lm_embeddings (bool, optional): when True, the LM input embeddings are left frozen;
            otherwise they are made trainable. Defaults to False.
        cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
        **flamingo_kwargs: extra keyword arguments forwarded unchanged to the ``Flamingo`` constructor.

    Returns:
        Flamingo: Flamingo model from pretrained vision and language encoders
        Image processor: Pipeline to preprocess input images
        Tokenizer: A tokenizer for the language model

    Raises:
        ValueError: if ``decoder_layers_attr_name`` is None and cannot be inferred from the
            language model's class name.
    """
    # open_clip is an optional dependency, so it is imported lazily here.
    # NOTE(review): this relies on handle_module_not_found_error raising; otherwise
    # `open_clip` would be unbound below -- confirm against helm.common.general.
    try:
        import open_clip
    except ModuleNotFoundError as e:
        handle_module_not_found_error(e, ["vlm"])

    # The second return value is discarded; only the model and the third value
    # (used as the image processor) are kept.
    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
        clip_vision_encoder_path,
        pretrained=clip_vision_encoder_pretrained,
        cache_dir=cache_dir,
    )
    # set the vision encoder to output the visual features
    vision_encoder.visual.output_tokens = True

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    # add Flamingo special tokens to the tokenizer
    text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
    if "mpt-1b-redpajama-200b" in lang_encoder_path:

        class EmbeddingFnMixin:
            def get_input_embeddings(self):
                return self.transformer.wte

            def set_input_embeddings(self, new_embeddings):
                self.transformer.wte = new_embeddings

        extend_instance(lang_encoder, EmbeddingFnMixin)

    # convert LM to FlamingoLM
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    # The tokenizer grew by the special tokens added above, so the embedding
    # table is resized to the new vocabulary length.
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    model = Flamingo(
        vision_encoder,
        lang_encoder,
        # [-1] takes the last id produced by encode() for each special token.
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<image>")[-1],
        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze all parameters
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
    model.perceiver.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
        # TODO: investigate also training the output embeddings when untied

    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Flamingo model initialized with {num_trainable} trainable parameters")

    return model, image_processor, text_tokenizer
|
||
|
||
def _infer_decoder_layers_attr_name(model):
    """
    Infer the attribute path of the decoder's transformer-block list for ``model``.

    Each key of ``__KNOWN_DECODER_LAYERS_ATTR_NAMES`` is tested, in the mapping's
    iteration order, as a case-insensitive substring of the model's class name.
    The value for the first matching key is returned.

    Raises:
        ValueError: if no known key occurs in the class name.
    """
    class_name = model.__class__.__name__.lower()
    for known_key, attr_path in __KNOWN_DECODER_LAYERS_ATTR_NAMES.items():
        if known_key.lower() in class_name:
            return attr_path

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. "
        "Please supply this string manually."
    )
|
||
|
||
# Lookup table used by _infer_decoder_layers_attr_name: each key is matched as a
# case-insensitive substring of the language model's class name, and the value
# is the dotted attribute path to the decoder's list of transformer blocks.
# Entries are tried in the order written here and the first match wins, so
# keep that in mind when adding or reordering keys.
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}
Oops, something went wrong.