feat: continuous batching with vLLM #349

Merged · 5 commits · Sep 14, 2023
Changes from 1 commit
chore: make sure to call generation within no grad
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
aarnphm committed Sep 13, 2023
commit aa2c440607d0ade69071ba3fdbb5a4fb4800edac
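The commit title refers to running generation under PyTorch's no-grad machinery. As a minimal sketch of the pattern (the gpt2 checkpoint below is only an illustration, not a model this PR touches), the whole tokenize/generate call is wrapped in `torch.inference_mode()` so no autograd state is recorded:

```python
import torch
import transformers

# Illustrative checkpoint; any causal LM works the same way.
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')

# inference_mode() disables gradient tracking entirely, which lowers memory
# use and speeds up generation compared to running under the default mode.
with torch.inference_mode():
    input_ids = tokenizer('Hello, world', return_tensors='pt').input_ids
    output_ids = model.generate(input_ids, max_new_tokens=8, pad_token_id=tokenizer.eos_token_id)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```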
4 changes: 2 additions & 2 deletions openllm-core/src/openllm_core/_typing_compat.py
@@ -94,7 +94,7 @@ class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], EmbeddingsOutput]
generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]]
generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]]
generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Iterator[t.Any]]

class LLMRunner(bentoml.Runner, t.Generic[M, T]):
__doc__: str
@@ -111,7 +111,7 @@ class LLMRunner(bentoml.Runner, t.Generic[M, T]):
embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], t.Sequence[EmbeddingsOutput]]
generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]]
generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]]
generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]]
generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Iterator[t.Any]]

def __init__(self,
runnable_class: type[LLMRunnable[M, T]],
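The signature change above loosens `generate_iterator` from `t.Generator[str, None, str]` to `t.Iterator[t.Any]`, which matches the streaming implementation later in this diff that yields dicts rather than bare strings. A small hedged sketch of the distinction (the function below is illustrative, not the real runner method):

```python
from __future__ import annotations
import typing as t

def generate_iterator(prompt: str) -> t.Iterator[t.Any]:
    # The real implementation yields dicts carrying text, usage counters and
    # a finish_reason, so Iterator[t.Any] is the honest annotation.
    for i, word in enumerate(prompt.split()):
        yield {'text': word, 'usage': {'completion_tokens': i + 1}, 'finish_reason': None}

for chunk in generate_iterator('a toy prompt'):
    print(chunk['text'])
```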
@@ -40,7 +40,6 @@ class StarCoderConfig(openllm_core.LLMConfig):
'url': 'https://github.com/bigcode-project/starcoder',
'architecture': 'GPTBigCodeForCausalLM',
'requirements': ['bitsandbytes'],
'workers_per_resource': 0.5,
'default_id': 'bigcode/starcoder',
'model_ids': ['bigcode/starcoder', 'bigcode/starcoderbase']
}
164 changes: 80 additions & 84 deletions openllm-python/src/openllm/_llm.py
@@ -640,11 +640,9 @@ def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
quantization_config,
_quantize,
model_id,
args, {
**model_kwds, **normalized_model_kwds
}, {
**tokenizer_kwds, **normalized_tokenizer_kwds
},
args,
{**model_kwds, **normalized_model_kwds},
{**tokenizer_kwds, **normalized_tokenizer_kwds},
_tag,
_adapters_mapping,
_model_version,
@@ -751,7 +749,7 @@ def model(self) -> M:
# If OOM, then it is probably you don't have enough VRAM to run this model.
if self.__llm_backend__ == 'pt' and is_torch_available():
loaded_in_kbit = getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False)
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit:
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit and not isinstance(model, transformers.Pipeline):
try:
model = model.to('cuda')
except Exception as err:
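The extra `isinstance(model, transformers.Pipeline)` check above keeps the manual `.to('cuda')` call away from pipeline objects: a `transformers.Pipeline` wraps its model and is given its device at construction time rather than being moved afterwards. A hedged sketch of the two cases (checkpoint name is illustrative):

```python
import torch
import transformers

if torch.cuda.is_available():
    # A pipeline receives its device up front and manages placement itself.
    pipe = transformers.pipeline('text-generation', model='gpt2', device=0)
    # A bare model is moved explicitly, which is what the guarded branch does.
    model = transformers.AutoModelForCausalLM.from_pretrained('gpt2').to('cuda')
```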
@@ -971,92 +969,90 @@ def generate_iterator(self,
from ._generation import prepare_logits_processor

len_prompt = len(prompt)
config = self.config.model_construct_env(**attrs)
if stop_token_ids is None: stop_token_ids = []
stop_token_ids.append(self.tokenizer.eos_token_id)

logits_processor = prepare_logits_processor(self.config)
logits_processor = prepare_logits_processor(config)

input_ids = self.tokenizer(prompt).input_ids
with torch.inference_mode():
input_ids = self.tokenizer(prompt).input_ids

if context_length is None: context_length = get_context_length(self.model.config)
max_src_len = context_length - self.config['max_new_tokens'] - 1
if context_length is None: context_length = get_context_length(self.model.config)
max_src_len = context_length - config['max_new_tokens'] - 1

input_ids = input_ids[-max_src_len:]
output_ids = list(input_ids)
input_echo_len = len(input_ids)
input_ids = input_ids[-max_src_len:]
output_ids = list(input_ids)
input_echo_len = len(input_ids)

past_key_values = out = token = None
for i in range(self.config['max_new_tokens']):
torch.cuda.synchronize()
if i == 0: # prefill
out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True)
past_key_values = out = None
finish_reason = None
for i in range(config['max_new_tokens']):
torch.cuda.synchronize()
if i == 0: # prefill
out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True)
else: # decoding
out = self.model(torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values
else: # decoding
out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values

if logits_processor:
if self.config['repetition_penalty'] > 1.0:
tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=self.device)
torch.cuda.synchronize()

if logits_processor:
if config['repetition_penalty'] > 1.0:
tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=self.device)
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
else:
last_token_logits = logits[0, -1, :]
last_token_logits = logits[0, -1, :]

# Switch to CPU by avoiding some bugs in mps backend.
if self.device.type == 'mps': last_token_logits = last_token_logits.float().to('cpu')
# Switch to CPU by avoiding some bugs in mps backend.
if self.device.type == 'mps': last_token_logits = last_token_logits.float().to('cpu')

if self.config['temperature'] < 1e-5 or self.config['top_p'] < 1e-8:
token = int(torch.argmax(last_token_logits)) # greedy
else:
probs = torch.softmax(last_token_logits, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
torch.cuda.synchronize()

if token in stop_token_ids: stopped = True
else: stopped = False

# Yield the output tokens
if i % stream_interval == 0 or i == self.config['max_new_tokens'] - 1 or stopped:
if echo:
tmp_output_ids = output_ids
rfind_start = len_prompt
if config['temperature'] < 1e-5 or config['top_p'] < 1e-8:
token = int(torch.argmax(last_token_logits)) # greedy
else:
tmp_output_ids = output_ids[input_echo_len:]
rfind_start = 0
output = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)

partially_stopped = False
if stop:
if isinstance(stop, str):
pos = output.rfind(stop, rfind_start)
if pos != -1: output, stopped = output[:pos], True
else: partially_stopped = is_partial_stop(output, stop)
elif isinstance(stop, t.Iterable):
for each_stop in stop:
pos = output.rfind(each_stop, rfind_start)
if pos != -1:
output, stopped = output[:pos], True
break
else:
partially_stopped = is_partial_stop(output, each_stop)
if partially_stopped: break
else: raise ValueError('Invalid stop field type.')

# Prevent yielding partial stop sequence
if not partially_stopped:
yield {'text': output, 'usage': {'prompt_tokens': input_echo_len, 'completion_tokens': i, 'total_tokens': input_echo_len + i}, 'finish_reason': None}
if stopped: break

# Finish stream event, which contains finish reason
if i == self.config['max_new_tokens'] - 1: finish_reason = 'length'
elif stopped: finish_reason = 'stop'
else: finish_reason = None
yield {'text': output, 'usage': {'prompt_tokens': input_echo_len, 'completion_tokens': i, 'total_tokens': input_echo_len + i}, 'finish_reason': finish_reason}
probs = torch.softmax(last_token_logits, dim=-1)
indices = torch.multinomial(probs, num_samples=2)
token = int(indices.tolist()[0])
output_ids.append(token)

stopped = token in stop_token_ids

# Yield the output tokens
if i % stream_interval == 0 or i == config['max_new_tokens'] - 1 or stopped:
if echo:
tmp_output_ids = output_ids
rfind_start = len_prompt
else:
tmp_output_ids = output_ids[input_echo_len:]
rfind_start = 0
output = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)

partially_stopped = False
if stop:
if isinstance(stop, str):
pos = output.rfind(stop, rfind_start)
if pos != -1: output, stopped = output[:pos], True
else: partially_stopped = is_partial_stop(output, stop)
elif isinstance(stop, t.Iterable):
for each_stop in stop:
pos = output.rfind(each_stop, rfind_start)
if pos != -1:
output, stopped = output[:pos], True
break
else:
partially_stopped = is_partial_stop(output, each_stop)
if partially_stopped: break
else: raise ValueError('Invalid stop field type.')

# Prevent yielding partial stop sequence
if not partially_stopped:
yield {'text': output, 'usage': {'prompt_tokens': input_echo_len, 'completion_tokens': i, 'total_tokens': input_echo_len + i}, 'finish_reason': None}
if stopped: break
else: finish_reason = 'length' # finish stream events
if stopped: finish_reason = 'stop'
yield {'text': output, 'usage': {'prompt_tokens': input_echo_len, 'completion_tokens': i, 'total_tokens': input_echo_len + i}, 'finish_reason': finish_reason}
# Clean
del past_key_values, out
gc.collect()
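Taken together, the rewritten loop above moves tokenization, prefill and per-token decoding under a single `torch.inference_mode()` block and streams dicts as it goes. A condensed, hedged sketch of that prefill/decode flow, assuming a generic Hugging Face causal LM (the checkpoint, greedy sampling, and `max_new_tokens` default below are placeholders, not the config handling in `_llm.py`):

```python
import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')

def stream(prompt: str, max_new_tokens: int = 16):
    with torch.inference_mode():
        input_ids = tokenizer(prompt).input_ids
        output_ids, past_key_values, token = list(input_ids), None, None
        for i in range(max_new_tokens):
            if i == 0:  # prefill: run the whole prompt once to build the KV cache
                out = model(torch.as_tensor([input_ids]), use_cache=True)
            else:  # decode: feed only the newest token plus the cached keys/values
                out = model(input_ids=torch.as_tensor([[token]]), use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            # Greedy pick keeps the sketch deterministic; the real loop samples
            # (or argmaxes when temperature/top_p are effectively zero).
            token = int(torch.argmax(out.logits[0, -1, :]))
            output_ids.append(token)
            stopped = token == tokenizer.eos_token_id
            text = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True)
            yield {'text': text, 'finish_reason': 'stop' if stopped else None}
            if stopped:
                break

for chunk in stream('The quick brown fox'):
    print(chunk['text'])
```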
@@ -1178,30 +1174,30 @@ def set_adapter(__self: _Runnable, adapter_name: str) -> None:
if adapter_name != 'default': self.model.set_adapter(adapter_name)
logger.info('Successfully apply LoRA layer %s', adapter_name)

@bentoml.Runnable.method(**method_signature(embeddings_sig)) # type: ignore
@bentoml.Runnable.method(**method_signature(embeddings_sig))
def embeddings(__self: _Runnable, prompt: str | list[str]) -> t.Sequence[EmbeddingsOutput]:
return [self.embeddings([prompt] if isinstance(prompt, str) else prompt)]

@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
@bentoml.Runnable.method(**method_signature(generate_sig))
def __call__(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
return self.generate(prompt, **attrs)

@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
@bentoml.Runnable.method(**method_signature(generate_sig))
def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
if __self.backend == 'vllm': attrs.setdefault('request_id', openllm_core.utils.gen_random_uuid())
return self.generate(prompt, **attrs)

@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
@bentoml.Runnable.method(**method_signature(generate_sig))
def generate_one(__self: _Runnable, prompt: str, stop: list[str], **attrs: t.Any) -> t.Sequence[dict[t.Literal['generated_text'], str]]:
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
return self.generate_one(prompt, stop, **attrs)

@bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore
@bentoml.Runnable.method(**method_signature(generate_iterator_sig))
def generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.Generator[str, None, str]:
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
2 changes: 2 additions & 0 deletions openllm-python/src/openllm/cli/entrypoint.py
@@ -868,6 +868,8 @@ def query_command(
res = client.query(prompt, return_response='raw', **{**client.configuration, **_memoized})
if output == 'pretty':
response = client.config.postprocess_generate(prompt, res['responses'])
if client.backend == 'vllm': response = response['outputs'][0]['text']
else: response = response['text']
termui.echo('\n\n==Responses==\n', fg='white')
termui.echo(response, fg=generated_fg)
elif output == 'json':
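The two added lines above exist because a vLLM-backed response nests its generations under an `outputs` list, while the PyTorch path returns the text at the top level. A hedged helper that mirrors that branch (the dict shapes are inferred from the diff, not from a documented schema):

```python
from __future__ import annotations
import typing as t

def extract_text(response: dict[str, t.Any], backend: str) -> str:
    if backend == 'vllm':
        # vLLM-style payload: {'outputs': [{'text': ...}, ...]}
        return response['outputs'][0]['text']
    # PyTorch-style payload: {'text': ...}
    return response['text']

print(extract_text({'outputs': [{'text': 'hello'}]}, 'vllm'))
print(extract_text({'text': 'hello'}, 'pt'))
```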
16 changes: 0 additions & 16 deletions openllm-python/src/openllm/models/gpt_neox/modeling_gpt_neox.py
@@ -14,19 +14,3 @@ class GPTNeoX(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTNe
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
import torch
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}

def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM:
import transformers
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs)
if self.config.use_half_precision: model.half()
return model

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
import torch
with torch.inference_mode():
return self.tokenizer.batch_decode(
self.model.generate(self.tokenizer(prompt, return_tensors='pt').to(self.device).input_ids,
do_sample=True,
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
pad_token_id=self.tokenizer.eos_token_id,
stopping_criteria=openllm.StoppingCriteriaList([openllm.StopOnTokens()])))
@@ -1,5 +1,4 @@
from __future__ import annotations
import logging
import typing as t

import bentoml
@@ -31,16 +30,3 @@ def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t
return bentoml.transformers.save_model(self.tag, model, custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self))
finally:
torch.cuda.empty_cache()

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
import torch
with torch.inference_mode():
# eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder
# NOTE: support fine-tuning starcoder
result_tensor = self.model.generate(self.tokenizer.encode(prompt, return_tensors='pt').to(self.device),
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id,
generation_config=self.config.model_construct_env(**attrs).to_generation_config())
# TODO: We will probably want to return the tokenizer here so that we can manually process this
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)