From e162893d7fc38abc3cf46323162e8480a904499c Mon Sep 17 00:00:00 2001 From: Jorge Piedrahita Ortiz Date: Wed, 19 Jun 2024 12:26:56 -0500 Subject: [PATCH] community[patch]: update sambastudio embeddings (#23133) Description: update sambastudio embeddings integration, now compatible with generic endpoints and CoE endpoints --- .../text_embedding/sambanova.ipynb | 46 +++++ .../embeddings/sambanova.py | 160 +++++++++++++++--- 2 files changed, 179 insertions(+), 27 deletions(-) diff --git a/docs/docs/integrations/text_embedding/sambanova.ipynb b/docs/docs/integrations/text_embedding/sambanova.ipynb index f0e4131aa2..aa1da928d6 100644 --- a/docs/docs/integrations/text_embedding/sambanova.ipynb +++ b/docs/docs/integrations/text_embedding/sambanova.ipynb @@ -43,12 +43,14 @@ "import os\n", "\n", "sambastudio_base_url = \"\"\n", + "sambastudio_base_uri = \"\"\n", "sambastudio_project_id = \"\"\n", "sambastudio_endpoint_id = \"\"\n", "sambastudio_api_key = \"\"\n", "\n", "# Set the environment variables\n", "os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URL\"] = sambastudio_base_url\n", + "os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URI\"] = sambastudio_base_uri\n", "os.environ[\"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID\"] = sambastudio_project_id\n", "os.environ[\"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID\"] = sambastudio_endpoint_id\n", "os.environ[\"SAMBASTUDIO_EMBEDDINGS_API_KEY\"] = sambastudio_api_key" @@ -79,6 +81,50 @@ "results = embeddings.embed_documents(texts)\n", "print(results)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can manually pass the endpoint parameters and manually set the batch size you have in your SambaStudio embeddings endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = SambaStudioEmbeddings(\n", + " sambastudio_embeddings_base_url=sambastudio_base_url,\n", + " sambastudio_embeddings_base_uri=sambastudio_base_uri,\n", + " sambastudio_embeddings_project_id=sambastudio_project_id,\n", + " sambastudio_embeddings_endpoint_id=sambastudio_endpoint_id,\n", + " sambastudio_embeddings_api_key=sambastudio_api_key,\n", + " batch_size=32,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or You can use an embedding model expert included in your deployed CoE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = SambaStudioEmbeddings(\n", + " batch_size=1,\n", + " model_kwargs={\n", + " \"select_expert\": \"e5-mistral-7b-instruct\",\n", + " },\n", + ")" + ] } ], "metadata": { diff --git a/libs/community/langchain_community/embeddings/sambanova.py b/libs/community/langchain_community/embeddings/sambanova.py index a0efa3b685..1b360f2547 100644 --- a/libs/community/langchain_community/embeddings/sambanova.py +++ b/libs/community/langchain_community/embeddings/sambanova.py @@ -1,4 +1,5 @@ -from typing import Dict, Generator, List +import json +from typing import Dict, Generator, List, Optional import requests from langchain_core.embeddings import Embeddings @@ -10,8 +11,9 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): """SambaNova embedding models. To use, you should have the environment variables - ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, - ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``, ``SAMBASTUDIO_EMBEDDINGS_API_KEY``, + ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URI`` + ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``, + ``SAMBASTUDIO_EMBEDDINGS_API_KEY`` set with your personal sambastudio variable or pass it as a named parameter to the constructor. @@ -19,20 +21,34 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): .. code-block:: python from langchain_community.embeddings import SambaStudioEmbeddings + embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url, + sambastudio_embeddings_base_uri=base_uri, sambastudio_embeddings_project_id=project_id, sambastudio_embeddings_endpoint_id=endpoint_id, - sambastudio_embeddings_api_key=api_key) - (or) - embeddings = SambaStudioEmbeddings() - """ + sambastudio_embeddings_api_key=api_key, + batch_size=32) + (or) - API_BASE_PATH = "/api/predict/nlp/" - """Base path to use for the API usage""" + embeddings = SambaStudioEmbeddings(batch_size=32) + + (or) + + # CoE example + embeddings = SambaStudioEmbeddings( + batch_size=1, + model_kwargs={ + 'select_expert':'e5-mistral-7b-instruct' + } + ) + """ sambastudio_embeddings_base_url: str = "" """Base url to use""" + sambastudio_embeddings_base_uri: str = "" + """endpoint base uri""" + sambastudio_embeddings_project_id: str = "" """Project id on sambastudio for model""" @@ -42,12 +58,24 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): sambastudio_embeddings_api_key: str = "" """sambastudio api key""" + model_kwargs: dict = {} + """Key word arguments to pass to the model.""" + + batch_size: int = 32 + """Batch size for the embedding models""" + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["sambastudio_embeddings_base_url"] = get_from_dict_or_env( values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL" ) + values["sambastudio_embeddings_base_uri"] = get_from_dict_or_env( + values, + "sambastudio_embeddings_base_uri", + "SAMBASTUDIO_EMBEDDINGS_BASE_URI", + default="api/predict/generic", + ) values["sambastudio_embeddings_project_id"] = get_from_dict_or_env( values, "sambastudio_embeddings_project_id", @@ -63,6 +91,20 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): ) return values + def _get_tuning_params(self) -> str: + """ + Get the tuning parameters to use when calling the model + + Returns: + The tuning parameters as a JSON string. + """ + tuning_params_dict = { + k: {"type": type(v).__name__, "value": str(v)} + for k, v in (self.model_kwargs.items()) + } + tuning_params = json.dumps(tuning_params_dict) + return tuning_params + def _get_full_url(self, path: str) -> str: """ Return the full API URL for a given path. @@ -71,7 +113,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): :returns: the full API URL for the sub-path :rtype: str """ - return f"{self.sambastudio_embeddings_base_url}{self.API_BASE_PATH}{path}" + return f"{self.sambastudio_embeddings_base_url}/{self.sambastudio_embeddings_base_uri}/{path}" # noqa: E501 def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator: """Generator for creating batches in the embed documents method @@ -86,7 +128,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): yield texts[i : i + batch_size] def embed_documents( - self, texts: List[str], batch_size: int = 32 + self, texts: List[str], batch_size: Optional[int] = None ) -> List[List[float]]: """Returns a list of embeddings for the given sentences. Args: @@ -97,22 +139,56 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences """ + if batch_size is None: + batch_size = self.batch_size http_session = requests.Session() url = self._get_full_url( f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}" ) - + params = json.loads(self._get_tuning_params()) embeddings = [] - for batch in self._iterate_over_batches(texts, batch_size): - data = {"inputs": batch} - response = http_session.post( - url, - headers={"key": self.sambastudio_embeddings_api_key}, - json=data, + if "nlp" in self.sambastudio_embeddings_base_uri: + for batch in self._iterate_over_batches(texts, batch_size): + data = {"inputs": batch, "params": params} + response = http_session.post( + url, + headers={"key": self.sambastudio_embeddings_api_key}, + json=data, + ) + try: + embedding = response.json()["data"] + embeddings.extend(embedding) + except KeyError: + raise KeyError( + "'data' not found in endpoint response", + response.json(), + ) + + elif "generic" in self.sambastudio_embeddings_base_uri: + for batch in self._iterate_over_batches(texts, batch_size): + data = {"instances": batch, "params": params} + response = http_session.post( + url, + headers={"key": self.sambastudio_embeddings_api_key}, + json=data, + ) + try: + if params.get("select_expert"): + embedding = response.json()["predictions"][0] + else: + embedding = response.json()["predictions"] + embeddings.extend(embedding) + except KeyError: + raise KeyError( + "'predictions' not found in endpoint response", + response.json(), + ) + + else: + raise ValueError( + f"handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented" # noqa: E501 ) - embedding = response.json()["data"] - embeddings.extend(embedding) return embeddings @@ -129,14 +205,44 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): url = self._get_full_url( f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}" ) + params = json.loads(self._get_tuning_params()) - data = {"inputs": [text]} + if "nlp" in self.sambastudio_embeddings_base_uri: + data = {"inputs": [text], "params": params} + response = http_session.post( + url, + headers={"key": self.sambastudio_embeddings_api_key}, + json=data, + ) + try: + embedding = response.json()["data"][0] + except KeyError: + raise KeyError( + "'data' not found in endpoint response", + response.json(), + ) - response = http_session.post( - url, - headers={"key": self.sambastudio_embeddings_api_key}, - json=data, - ) - embedding = response.json()["data"][0] + elif "generic" in self.sambastudio_embeddings_base_uri: + data = {"instances": [text], "params": params} + response = http_session.post( + url, + headers={"key": self.sambastudio_embeddings_api_key}, + json=data, + ) + try: + if params.get("select_expert"): + embedding = response.json()["predictions"][0][0] + else: + embedding = response.json()["predictions"][0] + except KeyError: + raise KeyError( + "'predictions' not found in endpoint response", + response.json(), + ) + + else: + raise ValueError( + f"handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented" # noqa: E501 + ) return embedding