Skip to content

Commit

Permalink
community[minor]: fix failing Predibase integration (langchain-ai#19776)
Browse files Browse the repository at this point in the history
- [x] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [x] **PR message**: ***Delete this entire checklist*** and replace
with
- **Description:** The LangChain–Predibase integration was failing because
it was not up to date with the Predibase SDK; in addition, the Predibase
integration tests were instantiating the LangChain Community `Predibase`
class with a required argument (`model`) missing. This change updates
the Predibase SDK usage and fixes the integration tests.
    - **Twitter handle:** `@alexsherstinsky`


---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
  • Loading branch information
alexsherstinsky and baskaryan authored Mar 30, 2024
1 parent e9caa22 commit a9bc212
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
34 changes: 29 additions & 5 deletions libs/community/langchain_community/llms/predibase.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
Expand All @@ -15,6 +15,13 @@ class Predibase(LLM):
model: str
predibase_api_key: SecretStr
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
default_options_for_generation: dict = Field(
{
"max_new_tokens": 256,
"temperature": 0.1,
},
const=True,
)

@property
def _llm_type(self) -> str:
Expand All @@ -29,18 +36,35 @@ def _call(
) -> str:
try:
from predibase import PredibaseClient
from predibase.pql import get_session
from predibase.pql.api import Session
from predibase.resource.llm.interface import LLMDeployment
from predibase.resource.llm.response import GeneratedResponse

pc = PredibaseClient(token=self.predibase_api_key.get_secret_value())
session: Session = get_session(
token=self.predibase_api_key.get_secret_value(),
gateway="https://api.app.predibase.com/v1",
serving_endpoint="serving.app.predibase.com",
)
pc: PredibaseClient = PredibaseClient(session=session)
except ImportError as e:
raise ImportError(
"Could not import Predibase Python package. "
"Please install it with `pip install predibase`."
) from e
except ValueError as e:
raise ValueError("Your API key is not correct. Please try again") from e
# load model and version
results = pc.prompt(prompt, model_name=self.model)
return results[0].response
options: Dict[str, Union[str, float]] = (
kwargs or self.default_options_for_generation
)
base_llm_deployment: LLMDeployment = pc.LLM(
uri=f"pb://deployments/{self.model}"
)
result: GeneratedResponse = base_llm_deployment.generate(
prompt=prompt,
options=options,
)
return result.response

@property
def _identifying_params(self) -> Mapping[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions libs/community/tests/integration_tests/llms/test_predibase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@


def test_api_key_is_string() -> None:
    """Verify the API key is coerced to a ``SecretStr``, never stored as plain text.

    ``model`` is a required constructor argument of the community
    ``Predibase`` class, so it must be supplied here even though the
    test only inspects the API-key field.
    """
    # The diff fragment showed both the old call (missing ``model``) and the
    # fixed one; only the fixed instantiation is kept.
    llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
    assert isinstance(llm.predibase_api_key, SecretStr)


def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
llm = Predibase(predibase_api_key="secret-api-key")
llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
print(llm.predibase_api_key, end="") # noqa: T201
captured = capsys.readouterr()

Expand Down

0 comments on commit a9bc212

Please sign in to comment.