From a9bc212bf271ab9e9f24351bb1a0ff71d7cbec76 Mon Sep 17 00:00:00 2001 From: Alex Sherstinsky Date: Fri, 29 Mar 2024 17:38:13 -0700 Subject: [PATCH] community[minor]: fix failing Predibase integration (#19776) - [x] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** Langchain-Predibase integration was failing, because it was not current with the Predibase SDK; in addition, Predibase integration tests were instantiating the Langchain Community `Predibase` class with one required argument (`model`) missing. This change updates the Predibase SDK usage and fixes the integration tests. - **Twitter handle:** `@alexsherstinsky` --------- Co-authored-by: Bagatur --- .../langchain_community/llms/predibase.py | 34 ++++++++++++++++--- .../integration_tests/llms/test_predibase.py | 4 +-- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/libs/community/langchain_community/llms/predibase.py b/libs/community/langchain_community/llms/predibase.py index f04ce49398..182ee0acd3 100644 --- a/libs/community/langchain_community/llms/predibase.py +++ b/libs/community/langchain_community/llms/predibase.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, Union from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM @@ -15,6 +15,13 @@ class Predibase(LLM): model: str predibase_api_key: SecretStr model_kwargs: Dict[str, Any] = Field(default_factory=dict) + default_options_for_generation: dict = Field( + { + "max_new_tokens": 256, + "temperature": 0.1, + }, + const=True, + ) @property def _llm_type(self) -> str: @@ -29,8 +36,17 @@ class Predibase(LLM): ) -> str: try: from predibase import PredibaseClient + from predibase.pql import get_session + from predibase.pql.api import Session + from predibase.resource.llm.interface import LLMDeployment + from predibase.resource.llm.response import GeneratedResponse - pc = PredibaseClient(token=self.predibase_api_key.get_secret_value()) + session: Session = get_session( + token=self.predibase_api_key.get_secret_value(), + gateway="https://api.app.predibase.com/v1", + serving_endpoint="serving.app.predibase.com", + ) + pc: PredibaseClient = PredibaseClient(session=session) except ImportError as e: raise ImportError( "Could not import Predibase Python package. " @@ -38,9 +54,17 @@ class Predibase(LLM): ) from e except ValueError as e: raise ValueError("Your API key is not correct. Please try again") from e - # load model and version - results = pc.prompt(prompt, model_name=self.model) - return results[0].response + options: Dict[str, Union[str, float]] = ( + kwargs or self.default_options_for_generation + ) + base_llm_deployment: LLMDeployment = pc.LLM( + uri=f"pb://deployments/{self.model}" + ) + result: GeneratedResponse = base_llm_deployment.generate( + prompt=prompt, + options=options, + ) + return result.response @property def _identifying_params(self) -> Mapping[str, Any]: diff --git a/libs/community/tests/integration_tests/llms/test_predibase.py b/libs/community/tests/integration_tests/llms/test_predibase.py index 5e1d19e084..88ac72cfc8 100644 --- a/libs/community/tests/integration_tests/llms/test_predibase.py +++ b/libs/community/tests/integration_tests/llms/test_predibase.py @@ -5,14 +5,14 @@ from langchain_community.llms.predibase import Predibase def test_api_key_is_string() -> None: - llm = Predibase(predibase_api_key="secret-api-key") + llm = Predibase(model="my_llm", predibase_api_key="secret-api-key") assert isinstance(llm.predibase_api_key, SecretStr) def test_api_key_masked_when_passed_via_constructor( capsys: CaptureFixture, ) -> None: - llm = Predibase(predibase_api_key="secret-api-key") + llm = Predibase(model="my_llm", predibase_api_key="secret-api-key") print(llm.predibase_api_key, end="") # noqa: T201 captured = capsys.readouterr()