Merge pull request #15 from neulab/jean-llavaone-molmo-llamavision
Adding new models: llava-onevision, molmo, llama-3.2-vision
xiangyue9607 authored Oct 11, 2024
2 parents 7ac093e + 6cb04d7 commit b2c18ff
Showing 4 changed files with 1,154 additions and 0 deletions.
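Once registered, these backends should be selectable by name from the lmms-eval launcher (for example, --model llama_vision --model_args pretrained=meta-llama/Llama-3.2-11B-Vision-Instruct), assuming lmms-eval's usual --model/--model_args command-line interface.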
3 changes: 3 additions & 0 deletions lmms_eval/models/__init__.py
@@ -45,6 +45,9 @@
"cambrian1": "Cambrian1",
"mblip": "MBlip",
"paligemma": "PaliGemma",
"llava_onevision": "Llava_OneVision",
"molmo": "Molmo",
"llama_vision": "LlamaVision",
}

for model_name, model_class in AVAILABLE_MODELS.items():
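The import loop that consumes this mapping is collapsed in the diff above, so the sketch below only illustrates the apparent convention the three new entries rely on: each key names a module under lmms_eval/models/ and each value names the class that module exports (the llama_vision.py file added below defines LlamaVision). NEW_MODELS holds just the three added entries, and resolve_model_class is a name made up here for illustration.

import importlib

NEW_MODELS = {
    "llava_onevision": "Llava_OneVision",
    "molmo": "Molmo",
    "llama_vision": "LlamaVision",
}

def resolve_model_class(model_name: str, registry: dict = NEW_MODELS):
    # Look up the class name registered for this key, import the matching
    # module under lmms_eval.models, and return the class object.
    class_name = registry[model_name]
    module = importlib.import_module(f"lmms_eval.models.{model_name}")
    return getattr(module, class_name)

# resolve_model_class("llama_vision") would return the LlamaVision class defined below.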
218 changes: 218 additions & 0 deletions lmms_eval/models/llama_vision.py
@@ -0,0 +1,218 @@
import warnings

warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")

from transformers import MllamaForConditionalGeneration, AutoProcessor, AutoTokenizer
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
import torch
from typing import List, Optional, Union, Tuple
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from tqdm import tqdm
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState

from loguru import logger as eval_logger
from datetime import timedelta

@register_model("llama_vision")
class LlamaVision(lmms):
"""
LlamaVision Model
"""
def __init__(
self,
pretrained: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
device: Optional[str] = "cuda",
device_map="cuda:0",
max_new_tokens: int = 256,
batch_size: Optional[Union[int, str]] = 1,
**kwargs,
) -> None:
super().__init__()
# Do not use kwargs for now
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
self.accelerator = accelerator
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
elif accelerator.num_processes == 1 and device_map == "auto":
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"

self._tokenizer = AutoTokenizer.from_pretrained(pretrained)
self._model = MllamaForConditionalGeneration.from_pretrained(pretrained, torch_dtype=torch.bfloat16, device_map=self.device_map)
self.model.eval()
self.model.tie_weights()
self._config = self.model.config
self.processor = AutoProcessor.from_pretrained(pretrained)
self.max_new_tokens = max_new_tokens
self.batch_size_per_gpu = int(batch_size)
assert self.batch_size_per_gpu == 1, "Batch size must be 1 for LlamaVision model"
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP, FSDP, and DeepSpeed (ZeRO stage 0) are supported."
# To use DistributedType.DEEPSPEED, you have to run `accelerate config` before using the model.
# You also have to select ZeRO stage 0 (equivalent to DDP) so that preparing the model works;
# setting other parameters in the kwargs to make the default ZeRO stage 2 work did not succeed.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")

if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == "auto":
eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
self._rank = 0
self._world_size = 1
else:
eval_logger.info(f"Using single device: {self._device}")
self.model.to(self._device)
self._rank = 0
self._world_size = 1

@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config

@property
def tokenizer(self):
return self._tokenizer

@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model

@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id

@property
def batch_size(self):
return self.batch_size_per_gpu

@property
def device(self):
return self._device

@property
def rank(self):
return self._rank

@property
def world_size(self):
return self._world_size

def flatten(self, input, only_get_first=False):
new_list = []
for i in input:
for j in i:
new_list.append(j)
if only_get_first:
break
return new_list

def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
""" """
add_special_tokens = False if add_special_tokens is None else add_special_tokens
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding

def generate_until(self, requests: List[Instance]) -> List[str]:
res = []

def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
return -len(toks), x[0]

re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")

for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
visuals = self.flatten(visuals)
gen_kwargs = all_gen_kwargs[0]
input_text = contexts[0]
image_list = [{"type": "image"} for _ in range(len(visuals))]
image_list.append({"type": "text", "text": input_text})
message = [
{"role": "user", "content": image_list}
]

prompts = self.processor.apply_chat_template(message, add_generation_prompt=True)
model_inputs = self.processor(
images=visuals,
text=prompts,
add_special_tokens=False,
return_tensors="pt"
)
model_inputs = model_inputs.to(self._model.device)

# preconfigure gen_kwargs with defaults
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = self.max_new_tokens
if "max_length" in gen_kwargs and "max_new_tokens" in gen_kwargs:
gen_kwargs.pop("max_length")
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
if "until" in gen_kwargs:
gen_kwargs.pop("until")
if "do_sample" not in gen_kwargs:
gen_kwargs["do_sample"] = False

generation_output = self._model.generate(**model_inputs, **gen_kwargs)
generated_tokens = generation_output[0, model_inputs['input_ids'].size(1):]
response = self.processor.decode(generated_tokens, skip_special_tokens=True)
assert isinstance(response, str)
res.append(response)
pbar.update(1)

# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)
pbar.close()
return res

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
raise NotImplementedError("Loglikelihood is not implemented for LlamaVision.")
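For reference, the prompting flow in generate_until above can be exercised outside the harness with a minimal sketch like the one below. The image path and the prompt text are placeholders; everything else mirrors the calls used in the class (apply_chat_template with add_generation_prompt=True, the processor call with add_special_tokens=False, and greedy decoding with only the newly generated tokens decoded).

import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open("example.jpg")  # placeholder image; any PIL-readable file works
messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(images=[image], text=prompt, add_special_tokens=False, return_tensors="pt").to(model.device)

# Greedy decoding, matching the gen_kwargs defaults set in generate_until.
output = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(processor.decode(output[0, inputs["input_ids"].size(1):], skip_special_tokens=True))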