add support for o1 but don't enable it (it's really slow and not everyone has access to it)
abi committed Dec 26, 2024
1 parent 3b3b5c1 commit adf50c9
Showing 2 changed files with 30 additions and 23 deletions.
49 changes: 28 additions & 21 deletions backend/llm.py
@@ -50,35 +50,42 @@ async def stream_openai_response(
     params = {
         "model": model.value,
         "messages": messages,
-        "stream": True,
         "timeout": 600,
-        "temperature": 0.0,
     }
 
-    # Add 'max_tokens' only if the model is a GPT4 vision or Turbo model
-    if (
-        model == Llm.GPT_4_VISION
-        or model == Llm.GPT_4_TURBO_2024_04_09
-        or model == Llm.GPT_4O_2024_05_13
-    ):
+    # O1 doesn't support streaming or temperature
+    if model != Llm.O1_2024_12_17:
+        params["temperature"] = 0
+        params["stream"] = True
+
+    # Add 'max_tokens' corresponding to the model
+    if model == Llm.GPT_4O_2024_05_13:
         params["max_tokens"] = 4096
 
     if model == Llm.GPT_4O_2024_11_20:
         params["max_tokens"] = 16384
 
-    stream = await client.chat.completions.create(**params)  # type: ignore
-    full_response = ""
-    async for chunk in stream:  # type: ignore
-        assert isinstance(chunk, ChatCompletionChunk)
-        if (
-            chunk.choices
-            and len(chunk.choices) > 0
-            and chunk.choices[0].delta
-            and chunk.choices[0].delta.content
-        ):
-            content = chunk.choices[0].delta.content or ""
-            full_response += content
-            await callback(content)
+    if model == Llm.O1_2024_12_17:
+        params["max_completion_tokens"] = 20000
+
+    # O1 doesn't support streaming
+    if model == Llm.O1_2024_12_17:
+        response = await client.chat.completions.create(**params)  # type: ignore
+        full_response = response.choices[0].message.content  # type: ignore
+    else:
+        stream = await client.chat.completions.create(**params)  # type: ignore
+        full_response = ""
+        async for chunk in stream:  # type: ignore
+            assert isinstance(chunk, ChatCompletionChunk)
+            if (
+                chunk.choices
+                and len(chunk.choices) > 0
+                and chunk.choices[0].delta
+                and chunk.choices[0].delta.content
+            ):
+                content = chunk.choices[0].delta.content or ""
+                full_response += content
+                await callback(content)
 
     await client.close()
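
For context, here is a minimal, self-contained sketch of the non-streaming call path this commit adds. The model name and the 20000-token ceiling come from the diff above; the prompt and the environment-based API key lookup are illustrative assumptions, not part of the commit.

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # AsyncOpenAI reads OPENAI_API_KEY from the environment by default
    client = AsyncOpenAI()

    # o1 rejects 'temperature' and 'stream' and takes 'max_completion_tokens'
    # in place of 'max_tokens', so the request is a single blocking call
    response = await client.chat.completions.create(
        model="o1-2024-12-17",
        messages=[{"role": "user", "content": "Say hello"}],
        max_completion_tokens=20000,
        timeout=600,
    )
    print(response.choices[0].message.content)
    await client.close()


asyncio.run(main())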

4 changes: 2 additions & 2 deletions backend/routes/generate_code.py
@@ -285,7 +285,7 @@ async def process_chunk(content: str, variantIndex: int):
 
     tasks: List[Coroutine[Any, Any, Completion]] = []
     for index, model in enumerate(variant_models):
-        if model == Llm.GPT_4O_2024_11_20:
+        if model == Llm.GPT_4O_2024_11_20 or model == Llm.O1_2024_12_17:
             if openai_api_key is None:
                 await throw_error("OpenAI API key is missing.")
                 raise Exception("OpenAI API key is missing.")
@@ -296,7 +296,7 @@ async def process_chunk(content: str, variantIndex: int):
                     api_key=openai_api_key,
                     base_url=openai_base_url,
                     callback=lambda x, i=index: process_chunk(x, i),
-                    model=Llm.GPT_4O_2024_11_20,
+                    model=model,
                 )
             )
         elif model == Llm.GEMINI_2_0_FLASH_EXP and GEMINI_API_KEY:
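
The model=model fix matters beyond o1: stream_openai_response was previously always called with a hardcoded Llm.GPT_4O_2024_11_20, so any other OpenAI model in variant_models would have silently run as GPT-4o. Below is a hypothetical sketch of the corrected dispatch pattern, where fake_generate stands in for the real completion coroutine.

import asyncio
from enum import Enum


class Llm(Enum):
    GPT_4O_2024_11_20 = "gpt-4o-2024-11-20"
    O1_2024_12_17 = "o1-2024-12-17"


async def fake_generate(model: Llm) -> str:
    # Stand-in for stream_openai_response; echoes the model it was given
    return f"completion from {model.value}"


async def main() -> None:
    variant_models = [Llm.GPT_4O_2024_11_20, Llm.O1_2024_12_17]
    # Forward each variant's own model, rather than hardcoding one
    tasks = [fake_generate(model) for model in variant_models]
    print(await asyncio.gather(*tasks))


asyncio.run(main())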
