
Added way to disable structured output
kevinmessiaen committed Nov 20, 2024
1 parent 40bede9 · commit 77e6a4f
Showing 5 changed files with 30 additions and 20 deletions.
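The change replaces the previous blanket Ollama workaround (a hard-coded removal of response_format plus a warning) with an opt-in flag: callers whose model lacks structured-output support pass disable_structured_output=True when registering it. The docs also move the Gemini example from gemini-pro to gemini-1.5-pro. A minimal usage sketch, reusing the Ollama values from the docs changes below:

```python
import giskard

api_base = "http://localhost:11434"  # default api_base for a local Ollama server

# llama3 served through Ollama does not reliably support OpenAI-style
# structured output, so tell Giskard to omit response_format from completions.
giskard.llm.set_llm_model("ollama/llama3", disable_structured_output=True, api_base=api_base)
giskard.llm.set_embedding_model("ollama/nomic-embed-text", api_base=api_base)
```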
docs/open_source/scan/scan_llm/index.md (4 changes: 2 additions & 2 deletions)

````diff
@@ -111,7 +111,7 @@ import giskard
 api_base = "http://localhost:11434" # default api_base for local Ollama
 
 # See supported models here: https://docs.litellm.ai/docs/providers/ollama#ollama-models
-giskard.llm.set_llm_model("ollama/llama3", api_base=api_base)
+giskard.llm.set_llm_model("ollama/llama3", disable_structured_output=True, api_base=api_base)
 giskard.llm.set_embedding_model("ollama/nomic-embed-text", api_base=api_base)
 ```
@@ -145,7 +145,7 @@ import giskard
 os.environ["GEMINI_API_KEY"] = "" # "my-gemini-api-key"
 
-giskard.llm.set_llm_model("gemini/gemini-pro")
+giskard.llm.set_llm_model("gemini/gemini-1.5-pro")
 giskard.llm.set_embedding_model("gemini/text-embedding-004")
 ```
````
docs/open_source/setting_up/index.md (4 changes: 2 additions & 2 deletions)

````diff
@@ -104,7 +104,7 @@ import giskard
 api_base = "http://localhost:11434" # default api_base for local Ollama
 
 # See supported models here: https://docs.litellm.ai/docs/providers/ollama#ollama-models
-giskard.llm.set_llm_model("ollama/llama3", api_base=api_base)
+giskard.llm.set_llm_model("ollama/llama3", disable_structured_output=True, api_base=api_base)
 giskard.llm.set_embedding_model("ollama/nomic-embed-text", api_base=api_base)
 ```
@@ -145,7 +145,7 @@ import giskard
 os.environ["GEMINI_API_KEY"] = "" # "my-gemini-api-key"
 
-giskard.llm.set_llm_model("gemini/gemini-pro")
+giskard.llm.set_llm_model("gemini/gemini-1.5-pro")
 giskard.llm.set_embedding_model("gemini/text-embedding-004")
 ```
````
(third changed docs file; name not shown in this view)

````diff
@@ -149,7 +149,7 @@ import giskard
 api_base = "http://localhost:11434" # default api_base for local Ollama
 
 # See supported models here: https://docs.litellm.ai/docs/providers/ollama#ollama-models
-giskard.llm.set_llm_model("ollama/llama3", api_base=api_base)
+giskard.llm.set_llm_model("ollama/llama3", disable_structured_output=True, api_base=api_base)
 giskard.llm.set_embedding_model("ollama/nomic-embed-text", api_base=api_base)
 ```
@@ -183,7 +183,7 @@ import giskard
 os.environ["GEMINI_API_KEY"] = "" # "my-gemini-api-key"
 
-giskard.llm.set_llm_model("gemini/gemini-pro")
+giskard.llm.set_llm_model("gemini/gemini-1.5-pro")
 giskard.llm.set_embedding_model("gemini/text-embedding-004")
 ```
````
giskard/llm/client/__init__.py (21 changes: 13 additions & 8 deletions)

````diff
@@ -12,6 +12,7 @@
 _default_client = None
 
 _default_llm_model = os.getenv("GSK_LLM_MODEL", "gpt-4o")
+_disable_structured_output = False
 _default_completion_params = dict()
 _default_llm_api = None
@@ -61,19 +62,22 @@ def set_llm_api(llm_api: str):
 @deprecated("set_default_client is deprecated: https://docs.giskard.ai/en/latest/open_source/setting_up/index.html")
 def set_llm_base_url(llm_base_url: Optional[str]):
     global _default_completion_params
-    _default_completion_params["api_base"] = os.getenv("GSK_LLM_BASE_URL")
+    _default_completion_params["api_base"] = llm_base_url
 
 
-def set_llm_model(llm_model: str, **kwargs):
+def set_llm_model(llm_model: str, disable_structured_output=False, **kwargs):
+    """
+
+    :param llm_model: The model to be used
+    :param disable_structured_output: Set this to True when the used model doesn't support structured output
+    :param kwargs: Additional fixed params to be passed during completion
+    """
     global _default_llm_model
+    global _disable_structured_output
     global _default_completion_params
 
-    if llm_model.startswith("ollama/"):
-        logger.warning(
-            "Giskard might not work properly with ollama. Please consider switching to another model provider."
-        )
-
     _default_llm_model = llm_model
+    _disable_structured_output = disable_structured_output
     _default_completion_params = kwargs
 
     # If the model is set, we unset the default client
@@ -84,6 +88,7 @@ def get_default_client() -> LLMClient:
     global _default_client
     global _default_llm_api
     global _default_llm_model
+    global _disable_structured_output
 
     if _default_client is not None:
         return _default_client
@@ -102,7 +107,7 @@ def get_default_client() -> LLMClient:
         if _default_llm_api is not None and "/" not in _default_llm_model:
             _default_llm_model = f"{_default_llm_api}/{_default_llm_model}"
 
-        _default_client = LiteLLMClient(_default_llm_model, _default_completion_params)
+        _default_client = LiteLLMClient(_default_llm_model, _disable_structured_output, _default_completion_params)
     except ImportError:
         raise ValueError(f"LLM scan using {_default_llm_model} requires litellm")
````
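For reference, a sketch of how the flag travels through the module-level state above (assuming litellm is installed, so get_default_client() builds a LiteLLMClient):

```python
from giskard.llm import set_llm_model
from giskard.llm.client import get_default_client

# set_llm_model() stores the flag in the module-level _disable_structured_output;
# get_default_client() then forwards it to the LiteLLMClient constructor.
set_llm_model("ollama/llama3", disable_structured_output=True)

client = get_default_client()
print(type(client).__name__)             # LiteLLMClient
print(client.disable_structured_output)  # True
```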
giskard/llm/client/litellm.py (17 changes: 11 additions & 6 deletions)

````diff
@@ -78,7 +78,12 @@ def _parse_json_output(
 
 
 class LiteLLMClient(LLMClient):
-    def __init__(self, model: str = "gpt-4o", completion_params: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self,
+        model: str = "gpt-4o",
+        disable_structured_output: bool = False,
+        completion_params: Optional[Dict[str, Any]] = None,
+    ):
         """Initialize a LiteLLM completion client
 
         Parameters
@@ -89,15 +94,12 @@
         A dictionary containing params for the completion.
         """
         self.model = model
+        self.disable_structured_output = disable_structured_output
         self.completion_params = completion_params or dict()
 
     def _build_supported_completion_params(self, **kwargs):
         supported_params = litellm.get_supported_openai_params(model=self.model)
 
-        # response_format causes issues with ollama: https://github.com/BerriAI/litellm/issues/6359
-        if self.model.startswith("ollama/"):
-            supported_params.remove("response_format")
-
         return {
             param_name: param_value
             for param_name, param_value in kwargs.items()
@@ -117,7 +119,10 @@ def complete(
             model=self.model,
             messages=[{"role": message.role, "content": message.content} for message in messages],
             **self._build_supported_completion_params(
-                temperature=temperature, max_tokens=max_tokens, seed=seed, response_format=_get_response_format(format)
+                temperature=temperature,
+                max_tokens=max_tokens,
+                seed=seed,
+                response_format=None if self.disable_structured_output else _get_response_format(format),
             ),
             **self.completion_params,
         )
````
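The behavioral core of the litellm.py change is the conditional around response_format. A standalone sketch of that pattern follows; _get_response_format here is a simplified stand-in for the module's helper, not its exact implementation:

```python
from typing import Any, Dict, Optional

def _get_response_format(fmt: Optional[str]) -> Optional[Dict[str, Any]]:
    # Simplified stand-in: request JSON output when a JSON format is asked for.
    return {"type": "json_object"} if fmt == "json" else None

def response_format_param(disable_structured_output: bool, fmt: Optional[str]) -> Optional[Dict[str, Any]]:
    # Mirrors the new logic in complete(): when structured output is disabled,
    # response_format is forced to None regardless of the requested format.
    return None if disable_structured_output else _get_response_format(fmt)

assert response_format_param(True, "json") is None                      # flag wins
assert response_format_param(False, "json") == {"type": "json_object"}  # normal path
```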
