You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/langchain_community/embeddings/cohere.py

148 lines
4.6 KiB
Python

from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
class CohereEmbeddings(BaseModel, Embeddings):
    """Cohere embedding models.

    To use, you should have the ``cohere`` python package installed, and the
    environment variable ``COHERE_API_KEY`` set with your API key or pass it
    as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import CohereEmbeddings
            cohere = CohereEmbeddings(
                model="embed-english-light-v3.0",
                cohere_api_key="my-api-key"
            )
    """

    client: Any  #: :meta private:
    """Cohere client."""
    async_client: Any  #: :meta private:
    """Cohere async client."""
    model: str = "embed-english-v2.0"
    """Model name to use."""

    truncate: Optional[str] = None
    """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""

    cohere_api_key: Optional[str] = None
    """Cohere API key; looked up from ``COHERE_API_KEY`` when not passed."""

    max_retries: Optional[int] = 3
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[float] = None
    """Timeout in seconds for the Cohere API request."""
    user_agent: str = "langchain"
    """Identifier for the application making the request."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves the API key from ``values`` or the ``COHERE_API_KEY``
        environment variable, imports ``cohere`` and stores a sync and an
        async client in ``values["client"]`` / ``values["async_client"]``.

        Args:
            values: The field values collected for this model.

        Returns:
            The same ``values`` dict with both clients added.

        Raises:
            ValueError: If the ``cohere`` package cannot be imported.
        """
        cohere_api_key = get_from_dict_or_env(
            values, "cohere_api_key", "COHERE_API_KEY"
        )

        # Guard only the import: an ImportError raised while building the
        # clients must not be reported as "package not installed".
        try:
            import cohere
        except ImportError as e:
            # Chain the cause so the original import failure stays visible.
            raise ValueError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            ) from e

        # Both clients are configured identically; build the options once.
        client_kwargs = {
            "max_retries": values.get("max_retries"),
            "timeout": values.get("request_timeout"),
            "client_name": values["user_agent"],
        }
        values["client"] = cohere.Client(cohere_api_key, **client_kwargs)
        values["async_client"] = cohere.AsyncClient(cohere_api_key, **client_kwargs)
        return values

    def embed(
        self, texts: List[str], *, input_type: Optional[str] = None
    ) -> List[List[float]]:
        """Embed ``texts`` using the synchronous client.

        Args:
            texts: The list of texts to embed.
            input_type: Forwarded unchanged as the ``input_type`` argument of
                the client's ``embed`` call.

        Returns:
            One list of floats per entry in the response's ``embeddings``.
        """
        embeddings = self.client.embed(
            model=self.model,
            texts=texts,
            input_type=input_type,
            truncate=self.truncate,
        ).embeddings
        # Coerce every component to a plain Python float.
        return [list(map(float, e)) for e in embeddings]

    async def aembed(
        self, texts: List[str], *, input_type: Optional[str] = None
    ) -> List[List[float]]:
        """Embed ``texts`` using the asynchronous client.

        Args:
            texts: The list of texts to embed.
            input_type: Forwarded unchanged as the ``input_type`` argument of
                the client's ``embed`` call.

        Returns:
            One list of floats per entry in the response's ``embeddings``.
        """
        response = await self.async_client.embed(
            model=self.model,
            texts=texts,
            input_type=input_type,
            truncate=self.truncate,
        )
        # Coerce every component to a plain Python float.
        return [list(map(float, e)) for e in response.embeddings]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self.embed(texts, input_type="search_document")

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Cohere's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return await self.aembed(texts, input_type="search_document")

    def embed_query(self, text: str) -> List[float]:
        """Call out to Cohere's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed([text], input_type="search_query")[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Cohere's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return (await self.aembed([text], input_type="search_query"))[0]