You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/langchain_community/embeddings/gradient_ai.py

167 lines
5.2 KiB
Python

from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from packaging.version import parse
__all__ = ["GradientEmbeddings"]
class GradientEmbeddings(BaseModel, Embeddings):
"""Gradient.ai Embedding models.
GradientLLM is a class to interact with Embedding Models on gradient.ai
To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
or alternatively provide them as keywords to the constructor of this class.
Example:
.. code-block:: python
from langchain_community.embeddings import GradientEmbeddings
GradientEmbeddings(
model="bge-large",
gradient_workspace_id="12345614fc0_workspace",
gradient_access_token="gradientai-access_token",
)
"""
model: str
"Underlying gradient.ai model id."
gradient_workspace_id: Optional[str] = None
"Underlying gradient.ai workspace_id."
gradient_access_token: Optional[str] = None
"""gradient.ai API Token, which can be generated by going to
https://auth.gradient.ai/select-workspace
and selecting "Access tokens" under the profile drop-down.
"""
gradient_api_url: str = "https://api.gradient.ai/api"
"""Endpoint URL to use."""
query_prompt_for_retrieval: Optional[str] = None
"""Query pre-prompt"""
client: Any = None #: :meta private:
"""Gradient client."""
# LLM call kwargs
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["gradient_access_token"] = get_from_dict_or_env(
values, "gradient_access_token", "GRADIENT_ACCESS_TOKEN"
)
values["gradient_workspace_id"] = get_from_dict_or_env(
values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
)
values["gradient_api_url"] = get_from_dict_or_env(
values, "gradient_api_url", "GRADIENT_API_URL"
)
try:
import gradientai
except ImportError:
raise ImportError(
'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
)
if parse(gradientai.__version__) < parse("1.4.0"):
raise ImportError(
'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
)
gradient = gradientai.Gradient(
access_token=values["gradient_access_token"],
workspace_id=values["gradient_workspace_id"],
host=values["gradient_api_url"],
)
values["client"] = gradient.get_embeddings_model(slug=values["model"])
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to Gradient's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
inputs = [{"input": text} for text in texts]
result = self.client.embed(inputs=inputs).embeddings
return [e.embedding for e in result]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Async call out to Gradient's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
inputs = [{"input": text} for text in texts]
result = (await self.client.aembed(inputs=inputs)).embeddings
return [e.embedding for e in result]
def embed_query(self, text: str) -> List[float]:
"""Call out to Gradient's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
query = (
f"{self.query_prompt_for_retrieval} {text}"
if self.query_prompt_for_retrieval
else text
)
return self.embed_documents([query])[0]
async def aembed_query(self, text: str) -> List[float]:
"""Async call out to Gradient's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
query = (
f"{self.query_prompt_for_retrieval} {text}"
if self.query_prompt_for_retrieval
else text
)
embeddings = await self.aembed_documents([query])
return embeddings[0]
class TinyAsyncGradientEmbeddingClient: #: :meta private:
"""Deprecated, TinyAsyncGradientEmbeddingClient was removed.
This class is just for backwards compatibility with older versions
of langchain_community.
It might be entirely removed in the future.
"""
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")