Add OpenFlamingo (#2237)
Co-authored-by: Tony Lee <tonyh.lee@yahoo.com>
Co-authored-by: JosselinSomervilleRoberts <josselin.somerville@gmail.com>
3 people authored Mar 4, 2024
1 parent 5004acf commit 481d12e
Showing 17 changed files with 1,203 additions and 3 deletions.
18 changes: 18 additions & 0 deletions requirements.txt
@@ -30,6 +30,10 @@ certifi==2024.2.2
cffi==1.16.0
cfgv==3.4.0
charset-normalizer==2.1.1
Cython==0.29.32
einops-exts==0.0.4
emoji==2.1.0
et-xmlfile==1.1.0
chex==0.1.7
click==8.1.7
clip-anytorch==2.5.2
@@ -145,6 +149,10 @@ nltk==3.8.1
nodeenv==1.8.0
NudeNet==2.0.9
numba==0.56.4
open-clip-torch==2.24.0
openpyxl==3.0.10
outcome==1.2.0
pathy==0.10.2
numpy==1.23.5
oauthlib==3.2.2
omegaconf==2.3.0
@@ -253,6 +261,15 @@ timm==0.6.13
tokenizers==0.15.2
toml==0.10.2
tomli==2.0.1
trio==0.22.0
trio-websocket==0.9.2
types-Pillow==9.3.0.4
types-pytz==2022.4.0.0
types-redis==4.3.21.1
types-requests==2.28.11.2
types-tabulate==0.9.0.0
types-urllib3==1.26.25
typing==3.7.4.3
toolz==0.12.1
torch~=2.1.2
torch-fidelity==0.3.0
@@ -281,3 +298,4 @@ xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0
zstandard==0.18.0
fairlearn==0.9.0
12 changes: 9 additions & 3 deletions setup.cfg
@@ -148,9 +148,14 @@ models =
crfm-helm[yandex]

vlm =
# For OpenFlamingo
einops~=0.7.0
einops-exts~=0.0.4
open-clip-torch~=2.24.0

# VLM models
crfm-helm[openai]
torch~=2.1.2 # For IDEFICS

# VLM scenarios
crfm-helm[images]
@@ -178,7 +183,7 @@ heim =
crfm-helm[openai]

# For model, kakaobrain/mindall-e
einops~=0.6.0
einops~=0.7.0
omegaconf~=2.3.0
pytorch-lightning~=2.0.5

@@ -259,6 +264,7 @@ exclude =
venv/*
src/helm/clients/image_generation/dalle_mini/*
src/helm/clients/image_generation/mindalle/*
src/helm/clients/vision_language/open_flamingo/*

# Ignore completely:
# E203 - White space before ':', (conflicts with black)
@@ -276,7 +282,7 @@ check_untyped_defs = True
disable_error_code = annotation-unchecked
# TODO: Change disallow_untyped_defs to True
disallow_untyped_defs = False
exclude = dalle_mini|mindalle
exclude = dalle_mini|mindalle|open_flamingo

[tool:pytest]
addopts =
4 changes: 4 additions & 0 deletions src/helm/benchmark/model_metadata_registry.py
@@ -61,13 +61,17 @@
IDEFICS_MODEL_TAG: str = "IDEFICS_MODEL_TAG"
# Llava should use a special prompt format (see `LlavaRunExpander`)
LLAVA_MODEL_TAG: str = "LLAVA_MODEL_TAG"
# OpenFlamingo has a special prompt format (see `OpenFlamingoRunExpander`)
OPEN_FLAMINGO_MODEL_TAG: str = "OPEN_FLAMINGO_MODEL_TAG"
# Some VLMs do not support multiple images in the prompt
LIMITED_FUNCTIONALITY_VLM_TAG: str = "LIMITED_FUNCTIONALITY_VLM_TAG"
FULL_FUNCTIONALITY_VLM_TAG: str = "FULL_FUNCTIONALITY_VLM_TAG"


# Frozen is set to false as the model_deployment_registry.py file
# might populate the deployment_names field.


@dataclass(frozen=False)
class ModelMetadata:
name: str
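The new OPEN_FLAMINGO_MODEL_TAG is consumed purely through a model's tags: downstream code checks for it and applies the matching run expander (see run_spec_factory.py further down in this diff). A minimal sketch of that check, using only names introduced in this file; the helper function itself is hypothetical:

from helm.benchmark.model_metadata_registry import OPEN_FLAMINGO_MODEL_TAG, ModelMetadata

def uses_open_flamingo_prompt(model: ModelMetadata) -> bool:
    # A model opts into the OpenFlamingo prompt format solely via its tags.
    return OPEN_FLAMINGO_MODEL_TAG in model.tags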
@@ -10,6 +10,11 @@ entries: [
# sheetmusic2lilypond
{description: "sheetmusic2lilypond:model=vlm", priority: 1}

# webpages
{description: "image2webpage:subset=css,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=html,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=javascript,model=vlm", priority: 1, groups: ["image2webpage"]}

# chart2csv
# {description: "chart2csv:model=vlm", priority: 1}
]
20 changes: 20 additions & 0 deletions src/helm/benchmark/run_expander.py
@@ -447,6 +447,26 @@ def expand(self, run_spec: RunSpec) -> List[RunSpec]:
]


class OpenFlamingoRunExpander(RunExpander):
"""
Custom prompt for OpenFlamingo following: https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b
"""

name = "open_flamingo"

def expand(self, run_spec: RunSpec) -> List[RunSpec]:
return [
replace(
run_spec,
name=run_spec.name,
adapter_spec=replace(
run_spec.adapter_spec,
input_prefix=f"<|endofchunk|>{run_spec.adapter_spec.input_prefix}",
),
),
]


class FormatPromptRunExpander(RunExpander):
"""Adds a prefix and suffix to the prompt."""

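The expander above rewrites only the adapter spec's input_prefix. A small sketch of its effect, assuming a hypothetical run spec whose input_prefix is "Question: " (all other RunSpec fields pass through unchanged):

# Hypothetical values, before expansion:
#   run_spec.adapter_spec.input_prefix == "Question: "
expanded = OpenFlamingoRunExpander().expand(run_spec)[0]
# After expansion, each in-context example is preceded by the model's
# <|endofchunk|> separator, per the OpenFlamingo model card linked above:
#   expanded.adapter_spec.input_prefix == "<|endofchunk|>Question: "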
6 changes: 6 additions & 0 deletions src/helm/benchmark/run_spec_factory.py
@@ -17,6 +17,7 @@
GOOGLE_PALM_2_MODEL_TAG,
IDEFICS_INSTRUCT_MODEL_TAG,
LLAVA_MODEL_TAG,
OPEN_FLAMINGO_MODEL_TAG,
NLG_PREFIX_TAG,
NO_NEWLINES_TAG,
OPENAI_CHATGPT_MODEL_TAG,
@@ -33,6 +34,7 @@
IDEFICSInstructRunExpander,
IncreaseTemperatureRunExpander,
LlavaRunExpander,
OpenFlamingoRunExpander,
OpenAIRunExpander,
MistralRunExpander,
StopRunExpander,
@@ -147,6 +149,10 @@ def alter_run_spec(run_spec: RunSpec) -> RunSpec:
if LLAVA_MODEL_TAG in model.tags:
run_spec = singleton(LlavaRunExpander().expand(run_spec))

# OpenFlamingo
if OPEN_FLAMINGO_MODEL_TAG in model.tags:
run_spec = singleton(OpenFlamingoRunExpander().expand(run_spec))

# For multiple choice
if BUGGY_TEMP_0_TAG in model.tags and run_spec.adapter_spec.temperature == 0:
increase_temperature_expander = IncreaseTemperatureRunExpander(value=1e-4)
2 changes: 2 additions & 0 deletions src/helm/clients/vision_language/open_flamingo/__init__.py
@@ -0,0 +1,2 @@
from .src.flamingo import Flamingo
from .src.factory import create_model_and_transforms
Empty file.
147 changes: 147 additions & 0 deletions src/helm/clients/vision_language/open_flamingo/src/factory.py
@@ -0,0 +1,147 @@
"""
Source: https://github.com/mlfoundations/open_flamingo
"""

from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer

from helm.common.general import handle_module_not_found_error
from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance


def create_model_and_transforms(
clip_vision_encoder_path: str,
clip_vision_encoder_pretrained: str,
lang_encoder_path: str,
tokenizer_path: str,
cross_attn_every_n_layers: int = 1,
use_local_files: bool = False,
decoder_layers_attr_name: str = None,
freeze_lm_embeddings: bool = False,
cache_dir: Optional[str] = None,
**flamingo_kwargs,
):
"""
Initialize a Flamingo model from a pretrained vision encoder and language encoder.
Appends special tokens to the tokenizer and freezes backbones.
Args:
clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
lang_encoder_path (str): path to pretrained language encoder
tokenizer_path (str): path to pretrained tokenizer
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
use_local_files (bool, optional): whether to use local files. Defaults to False.
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
Returns:
Flamingo: Flamingo model from pretrained vision and language encoders
Image processor: Pipeline to preprocess input images
Tokenizer: A tokenizer for the language model
"""
try:
import open_clip
except ModuleNotFoundError as e:
handle_module_not_found_error(e, ["vlm"])

vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
clip_vision_encoder_path,
pretrained=clip_vision_encoder_pretrained,
cache_dir=cache_dir,
)
# set the vision encoder to output the visual features
vision_encoder.visual.output_tokens = True

text_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
local_files_only=use_local_files,
trust_remote_code=True,
cache_dir=cache_dir,
)
# add Flamingo special tokens to the tokenizer
text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
if text_tokenizer.pad_token is None:
# Issue: GPT models don't have a pad token, which we use to
# modify labels for the loss.
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})

lang_encoder = AutoModelForCausalLM.from_pretrained(
lang_encoder_path,
local_files_only=use_local_files,
trust_remote_code=True,
cache_dir=cache_dir,
)

# hacks for MPT-1B, which doesn't have a get_input_embeddings method
if "mpt-1b-redpajama-200b" in lang_encoder_path:

class EmbeddingFnMixin:
def get_input_embeddings(self):
return self.transformer.wte

def set_input_embeddings(self, new_embeddings):
self.transformer.wte = new_embeddings

extend_instance(lang_encoder, EmbeddingFnMixin)

# convert LM to FlamingoLM
extend_instance(lang_encoder, FlamingoLMMixin)

if decoder_layers_attr_name is None:
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
lang_encoder.resize_token_embeddings(len(text_tokenizer))

model = Flamingo(
vision_encoder,
lang_encoder,
text_tokenizer.encode("<|endofchunk|>")[-1],
text_tokenizer.encode("<image>")[-1],
vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
cross_attn_every_n_layers=cross_attn_every_n_layers,
**flamingo_kwargs,
)

# Freeze all parameters
model.requires_grad_(False)
assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

# Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
model.perceiver.requires_grad_(True)
model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
if not freeze_lm_embeddings:
model.lang_encoder.get_input_embeddings().requires_grad_(True)
# TODO: investigate also training the output embeddings when untied

print(
f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
)

return model, image_processor, text_tokenizer


def _infer_decoder_layers_attr_name(model):
for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
if k.lower() in model.__class__.__name__.lower():
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

raise ValueError(
"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. "
"Please supply this string manually."
)


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
"opt": "model.decoder.layers",
"gptj": "transformer.h",
"gpt-j": "transformer.h",
"pythia": "gpt_neox.layers",
"llama": "model.layers",
"gptneoxforcausallm": "gpt_neox.layers",
"mpt": "transformer.blocks",
"mosaicgpt": "transformer.blocks",
}
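For context, the upstream OpenFlamingo project loads its 9B (CLIP ViT-L/14 + MPT-7B) checkpoint through this factory roughly as below. The argument values come from the openflamingo/OpenFlamingo-9B-vitl-mpt7b model card rather than from this diff, so treat them as an assumption; HELM's OpenFlamingo client defines the configuration it actually uses.

from helm.clients.vision_language.open_flamingo import create_model_and_transforms

# Values per the OpenFlamingo-9B-vitl-mpt7b model card (assumed, not taken from this diff).
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-7b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-7b-redpajama-200b",
    cross_attn_every_n_layers=4,
)
# The Flamingo checkpoint weights are fetched separately (e.g. from the
# Hugging Face Hub) and loaded with model.load_state_dict(...); that step
# belongs to the calling client and is not part of this file.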
