Skip to content

Commit

Permalink
Add _parse_json_output to LiteLLM
Browse files Browse the repository at this point in the history
  • Loading branch information
henchaves committed Nov 19, 2024
1 parent 911d6e5 commit 40bede9
Showing 1 changed file with 61 additions and 10 deletions.
71 changes: 61 additions & 10 deletions giskard/llm/client/litellm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from typing import Any, Dict, Optional, Sequence

import json
import logging
import re

from ...client.python_utils import warning
from ..errors import LLMImportError
from . import LLMClient
from .base import ChatMessage

logger = logging.getLogger(__name__)

try:
import litellm
except ImportError as err:
Expand All @@ -22,18 +28,53 @@ def _get_response_format(format):
return None


def _json_trim(response_message: str):
# Dumb trim for when model response message in addition to the JSON response
def _trim_json(response_message: str):
if "{" not in response_message or "}" not in response_message:
raise ValueError("The model output doesn't contain any JSON")
return response_message

json_start = response_message.index("{")
json_end = len(response_message) - response_message[::-1].index("}")

if json_start > json_end:
raise ValueError("The model output doesn't contain any JSON")
return response_message if json_start > json_end else response_message[json_start:json_end]


def _parse_json_output(
    raw_json: str, llm_client: LLMClient, keys: Optional[Sequence[str]] = None, caller_id: Optional[str] = None
) -> dict:
    """Parse a model response as JSON, repairing it if necessary.

    Three strategies are tried in order:

    1. Trim the text with ``_trim_json`` and decode it directly.
    2. Decode the contents of a markdown ``json`` code fence, if one is found.
    3. Ask ``llm_client`` to rewrite the text as a single JSON object and
       decode its (trimmed) answer.

    Args:
        raw_json: Raw text returned by the model.
        llm_client: Client used for the repair request in step 3.
        keys: Keys that must be present in the result. Only the output of
            step 3 is checked against them; results from steps 1 and 2 are
            returned without this check.
        caller_id: Forwarded to ``llm_client.complete`` for the repair request.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the repaired text still cannot be decoded
            (a subclass of ``ValueError``).
        ValueError: If ``keys`` is given and one of them is missing from the
            repaired output.
    """
    try:
        return json.loads(_trim_json(raw_json), strict=False)
    except json.JSONDecodeError:
        logger.debug("JSON decoding error, trying to fix the JSON string.")

    logger.debug("Raw output: %s", raw_json)
    # Let's see if it's just a matter of markdown format (```json ... ```)
    match = re.search(r"```json\s{0,5}(.*?)\s{0,5}```", raw_json, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1), strict=False)
        except json.JSONDecodeError:
            logger.debug("String matching didn't fix the format, trying to fix it with the LLM itself.")

    # Final attempt, let's try to fix the JSON with the LLM itself
    out = llm_client.complete(
        messages=[
            ChatMessage(
                role="system",
                content="Fix the following text so it contains a single valid JSON object. You answer MUST start and end with curly brackets.",
            ),
            ChatMessage(role="user", content=raw_json),
        ],
        temperature=0,  # keep the repair as deterministic as possible
        caller_id=caller_id,
    )

    parsed_dict = json.loads(_trim_json(out.content), strict=False)

    if keys is not None and any(k not in parsed_dict for k in keys):
        raise ValueError(f"Keys {keys} not found in the JSON output: {parsed_dict}")

    return parsed_dict


class LiteLLMClient(LLMClient):
Expand Down Expand Up @@ -91,7 +132,17 @@ def complete(

response_message = completion.choices[0].message

return ChatMessage(
role=response_message.role,
content=response_message.content if format is None else _json_trim(response_message.content),
)
if format:
# Max 3 attempts to parse the JSON output
for i in range(3):
try:
json_dict = _parse_json_output(response_message.content, self, caller_id=caller_id)
response_message.content = json.dumps(json_dict)
break
except ValueError as e:
if i == 2:
raise e
response_message = completion.choices[i + 1].message
continue

return ChatMessage(role=response_message.role, content=response_message.content)

0 comments on commit 40bede9

Please sign in to comment.