Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Allow explicitly setting the temperature for API model #121

Merged
merged 2 commits into from
Jul 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions opencompass/models/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]
OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'


@MODELS.register_module()
Expand Down Expand Up @@ -40,28 +41,32 @@ class OpenAI(BaseAPIModel):
wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'.
temperature (float, optional): What sampling temperature to use.
If not None, will override the temperature in the `generate()`
call. Defaults to None.
"""

is_api: bool = True

def __init__(
self,
path: str,
max_seq_len: int = 2048,
query_per_second: int = 1,
retry: int = 2,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = None,
openai_api_base: str = 'https://api.openai.com/v1/chat/completions'
): # noqa
def __init__(self,
path: str,
max_seq_len: int = 2048,
query_per_second: int = 1,
retry: int = 2,
key: Union[str, List[str]] = 'ENV',
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = None,
openai_api_base: str = OPENAI_API_BASE,
temperature: Optional[float] = None):

super().__init__(path=path,
max_seq_len=max_seq_len,
meta_template=meta_template,
query_per_second=query_per_second,
retry=retry)
import tiktoken
self.tiktoken = tiktoken
self.temperature = temperature

if isinstance(key, str):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
Expand Down Expand Up @@ -96,6 +101,9 @@ def generate(
Returns:
List[str]: A list of generated strings.
"""
if self.temperature is not None:
temperature = self.temperature

with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._generate, inputs,
Expand Down