Use Together Python library for Snowflake Arctic Instruct (#2599)
yifanmai authored May 6, 2024
1 parent d9e3864 commit c21cf3d
Showing 5 changed files with 100 additions and 95 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
@@ -138,6 +138,9 @@ openai =
google =
    google-cloud-aiplatform~=1.44

+together =
+    together~=1.1
+
tsinghua =
    icetk~=0.0.4

@@ -152,6 +155,7 @@ models =
    crfm-helm[google]
    crfm-helm[mistral]
    crfm-helm[openai]
+    crfm-helm[together]
    crfm-helm[tsinghua]
    crfm-helm[yandex]
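With these extras in place, the new dependency can be installed either directly or through the aggregate models extra, e.g. pip install "crfm-helm[together]" or pip install "crfm-helm[models]".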

90 changes: 0 additions & 90 deletions src/helm/benchmark/test_model_deployment_definition.py

This file was deleted.

95 changes: 93 additions & 2 deletions src/helm/clients/together_client.py
@@ -1,12 +1,20 @@
from copy import deepcopy
-from typing import List, Dict, Any, Optional, Union
+from itertools import zip_longest
+from typing import List, Dict, Any, Optional, TypedDict, Union

import requests
from retrying import retry

from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
-from .client import CachingClient, truncate_sequence, cleanup_str
+from helm.clients.client import CachingClient, truncate_sequence, cleanup_str
+
+try:
+    from together import Together
+    from together.types import ChatCompletionResponse
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["together"])


class _RewriteRequestTags:
@@ -272,3 +280,86 @@ def do_it_sync() -> Dict[Any, Any]:
            completions=completions,
            embedding=[],
        )


class TogetherRawChatRequest(TypedDict):
    messages: List[Dict[str, str]]
    model: str
    max_tokens: int
    stop: List[str]
    temperature: float
    top_p: float
    top_k: int
    logprobs: int
    echo: bool
    n: int


def convert_to_raw_chat_request(request: Request) -> TogetherRawChatRequest:
    if request.messages:
        messages = request.messages
    else:
        messages = [{"role": "user", "content": request.prompt}]
    return {
        "messages": messages,
        "model": request.model,
        "max_tokens": request.max_tokens,
        "stop": request.stop_sequences,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "top_k": request.top_k_per_token,
        "logprobs": min(request.top_k_per_token, 1),
        "echo": request.echo_prompt,
        "n": request.num_completions,
    }
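As an illustration (not part of this commit), a minimal sketch of how the conversion behaves; it assumes HELM's Request dataclass supplies defaults for any fields not set explicitly:

from helm.common.request import Request

request = Request(
    model="snowflake/snowflake-arctic-instruct",
    prompt="What is HELM?",
    max_tokens=64,
)
raw_request = convert_to_raw_chat_request(request)
# request.messages is unset, so the prompt is wrapped as a single user message:
# raw_request["messages"] == [{"role": "user", "content": "What is HELM?"}]
# The remaining keys mirror Request fields, e.g. raw_request["n"] == request.num_completions.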


class TogetherChatClient(CachingClient):
    """Client that uses the Python Together library for chat models."""

    def __init__(self, cache_config: CacheConfig, api_key: str, together_model: Optional[str] = None):
        super().__init__(cache_config=cache_config)
        self._client = Together(api_key=api_key)

    def make_request(self, request: Request) -> RequestResult:
        raw_request = convert_to_raw_chat_request(request)
        cache_key = CachingClient.make_cache_key(raw_request, request)

        def do_it() -> Dict[Any, Any]:
            response = self._client.chat.completions.create(**raw_request)
            return response.model_dump(mode="json")

        try:
            raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
            response = ChatCompletionResponse.model_validate(raw_response)
        except Exception as error:
            return RequestResult(
                success=False,
                cached=False,
                error=str(error),
                completions=[],
                embedding=[],
            )

        generated_outputs: List[GeneratedOutput] = []
        for choice in response.choices:
            # NOTE: Together always returns None for choice.finish_reason
            # NOTE: Together does not return logprobs for the whole generated output, only for individual tokens
            tokens: List[Token] = []
            if choice.logprobs:
                for token_text, token_logprob in zip_longest(
                    choice.logprobs.tokens or [], choice.logprobs.token_logprobs or []
                ):
                    if token_text is None:
                        break
                    tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
            assert choice.message.role == "assistant"
            generated_outputs.append(GeneratedOutput(text=choice.message.content, logprob=0.0, tokens=tokens))
        return RequestResult(
            success=True,
            cached=cached,
            request_time=raw_response["request_time"],
            request_datetime=raw_response["request_datetime"],
            completions=generated_outputs,
            embedding=[],
        )
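Also for illustration (not part of this commit), a hypothetical end-to-end call through the new client; BlackHoleCacheConfig (a no-op cache) and the TOGETHER_API_KEY environment variable are assumptions, not something this diff establishes:

import os

from helm.common.cache import BlackHoleCacheConfig  # assumed no-op cache; any CacheConfig would do
from helm.common.request import Request

client = TogetherChatClient(
    cache_config=BlackHoleCacheConfig(),
    api_key=os.environ["TOGETHER_API_KEY"],  # assumed to hold a valid Together API key
)
result = client.make_request(Request(model="snowflake/snowflake-arctic-instruct", prompt="Say hello."))
if result.success:
    print(result.completions[0].text)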
2 changes: 1 addition & 1 deletion src/helm/common/images_utils.py
@@ -51,7 +51,7 @@ def copy_image(src: str, dest: str, width: Optional[int] = None, height: Optiona
    if (width is not None and height is not None) or is_url(src):
        image = open_image(src)
        if width is not None and height is not None:
-            image = image.resize((width, height), Image.LANCZOS)
+            image = image.resize((width, height), Image.Resampling.LANCZOS)
        image.save(dest)
    else:
        shutil.copy(src, dest)
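Background for this one-line change: Pillow 9.1 moved the resampling filters into the Image.Resampling enum and deprecated the bare module-level constants (later retained as aliases), so the enum spelling keeps copy_image working on current Pillow releases. A standalone sketch, with "input.png" as a placeholder path:

from PIL import Image

image = Image.open("input.png")
image = image.resize((256, 256), Image.Resampling.LANCZOS)  # enum form, Pillow >= 9.1
image.save("output.png")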
4 changes: 2 additions & 2 deletions src/helm/config/model_deployments.yaml
@@ -1759,9 +1759,9 @@ model_deployments:
  - name: together/snowflake-arctic-instruct
    model_name: snowflake/snowflake-arctic-instruct
    tokenizer_name: snowflake/snowflake-arctic-instruct
-    max_sequence_length: 4095
+    max_sequence_length: 4000 # Lower than 4096 because of chat tokens
    client_spec:
-      class_name: "helm.clients.together_client.TogetherClient"
+      class_name: "helm.clients.together_client.TogetherChatClient"

  ## Stanford
  - name: together/alpaca-7b
