mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
143 lines
5.1 KiB
Python
143 lines
5.1 KiB
Python
|
from typing import Dict, Generator, List
|
||
|
|
||
|
import requests
|
||
|
from langchain_core.embeddings import Embeddings
|
||
|
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||
|
from langchain_core.utils import get_from_dict_or_env
|
||
|
|
||
|
|
||
|
class SambaStudioEmbeddings(BaseModel, Embeddings):
    """SambaNova embedding models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``,
    ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``, ``SAMBASTUDIO_EMBEDDINGS_API_KEY``,
    set with your personal sambastudio variable or pass it as a named parameter
    to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import SambaStudioEmbeddings
            embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
                                          sambastudio_embeddings_project_id=project_id,
                                          sambastudio_embeddings_endpoint_id=endpoint_id,
                                          sambastudio_embeddings_api_key=api_key)
             (or)
            embeddings = SambaStudioEmbeddings()
    """

    API_BASE_PATH = "/api/predict/nlp/"
    """Base path to use for the API usage"""

    sambastudio_embeddings_base_url: str = ""
    """Base url to use"""

    sambastudio_embeddings_project_id: str = ""
    """Project id on sambastudio for model"""

    sambastudio_embeddings_endpoint_id: str = ""
    """endpoint id on sambastudio for model"""

    sambastudio_embeddings_api_key: str = ""
    """sambastudio api key"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the connection settings from arguments or the environment.

        Each of the four ``sambastudio_embeddings_*`` settings is looked up
        with ``get_from_dict_or_env``, using the upper-cased field name as the
        environment variable name, and written back into ``values``.

        Args:
            values (Dict): field values collected for the model.

        Returns:
            Dict: the same mapping with the four settings resolved.
        """
        values["sambastudio_embeddings_base_url"] = get_from_dict_or_env(
            values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL"
        )
        values["sambastudio_embeddings_project_id"] = get_from_dict_or_env(
            values,
            "sambastudio_embeddings_project_id",
            "SAMBASTUDIO_EMBEDDINGS_PROJECT_ID",
        )
        values["sambastudio_embeddings_endpoint_id"] = get_from_dict_or_env(
            values,
            "sambastudio_embeddings_endpoint_id",
            "SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID",
        )
        values["sambastudio_embeddings_api_key"] = get_from_dict_or_env(
            values, "sambastudio_embeddings_api_key", "SAMBASTUDIO_EMBEDDINGS_API_KEY"
        )
        return values

    def _get_full_url(self, path: str) -> str:
        """
        Return the full API URL for a given path.

        :param str path: the sub-path
        :returns: the full API URL for the sub-path
        :rtype: str
        """
        return f"{self.sambastudio_embeddings_base_url}{self.API_BASE_PATH}{path}"

    def _endpoint_url(self) -> str:
        """Return the full URL of the configured ``<project id>/<endpoint id>``."""
        return self._get_full_url(
            f"{self.sambastudio_embeddings_project_id}/"
            f"{self.sambastudio_embeddings_endpoint_id}"
        )

    def _post_inputs(
        self, http_session: requests.Session, url: str, inputs: List[str]
    ) -> List[List[float]]:
        """POST one batch of texts and return the ``"data"`` entry of the reply.

        Args:
            http_session (requests.Session): session used to send the request.
            url (str): endpoint URL to post to.
            inputs (List[str]): texts sent as the ``"inputs"`` JSON field.

        Returns:
            List[List[float]]: value stored under ``"data"`` in the JSON
                response (presumably one embedding per input, in input
                order -- TODO confirm against the endpoint contract).
        """
        response = http_session.post(
            url,
            headers={"key": self.sambastudio_embeddings_api_key},
            json={"inputs": inputs},
        )
        return response.json()["data"]

    def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
        """Generator for creating batches in the embed documents method

        Args:
            texts (List[str]): list of strings to embed
            batch_size (int): batch size to be used for the embedding model.
                Will depend on the RDU endpoint used.

        Yields:
            List[str]: consecutive slices of ``texts`` holding at most
                ``batch_size`` strings each (the last one may be shorter).
        """
        for start in range(0, len(texts), batch_size):
            yield texts[start : start + batch_size]

    def embed_documents(
        self, texts: List[str], batch_size: int = 32
    ) -> List[List[float]]:
        """Return the embeddings for a list of texts.

        The texts are sent to the endpoint in batches of ``batch_size`` over
        a single HTTP session, which is closed before returning.

        Args:
            texts (`List[str]`): List of texts to encode
            batch_size (`int`): Batch size for the encoding

        Returns:
            `List[List[float]]`: the embeddings returned by the endpoint for
                each batch, concatenated in batch order.

        Raises:
            ValueError: if ``batch_size`` is not a positive integer.
        """
        # A negative step would make the batch range empty and silently return
        # no embeddings at all; reject it the same way a zero step is rejected.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        url = self._endpoint_url()
        embeddings: List[List[float]] = []

        # The context manager releases the pooled connections when done.
        with requests.Session() as http_session:
            for batch in self._iterate_over_batches(texts, batch_size):
                embeddings.extend(self._post_inputs(http_session, url, batch))

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding for a single text.

        Args:
            text (`str`): text to encode

        Returns:
            `List[float]`: first entry of the ``"data"`` list returned by the
                endpoint for the one-element input ``[text]``.
        """
        # The context manager releases the pooled connections when done.
        with requests.Session() as http_session:
            return self._post_inputs(http_session, self._endpoint_url(), [text])[0]
|