diff --git a/libs/langchain/langchain/embeddings/cohere.py b/libs/langchain/langchain/embeddings/cohere.py index 1a6b3cdca0..e64019dd20 100644 --- a/libs/langchain/langchain/embeddings/cohere.py +++ b/libs/langchain/langchain/embeddings/cohere.py @@ -37,6 +37,8 @@ class CohereEmbeddings(BaseModel, Embeddings): """Maximum number of retries to make when generating.""" request_timeout: Optional[float] = None """Timeout in seconds for the Cohere API request.""" + user_agent: str = "langchain" + """Identifier for the application making the request.""" class Config: """Configuration for this pydantic object.""" @@ -55,11 +57,18 @@ class CohereEmbeddings(BaseModel, Embeddings): try: import cohere + client_name = values["user_agent"] values["client"] = cohere.Client( - cohere_api_key, max_retries=max_retries, timeout=request_timeout + cohere_api_key, + max_retries=max_retries, + timeout=request_timeout, + client_name=client_name, ) values["async_client"] = cohere.AsyncClient( - cohere_api_key, max_retries=max_retries, timeout=request_timeout + cohere_api_key, + max_retries=max_retries, + timeout=request_timeout, + client_name=client_name, ) except ImportError: raise ValueError( diff --git a/libs/langchain/langchain/llms/cohere.py b/libs/langchain/langchain/llms/cohere.py index c902b4c047..b24e10adb6 100644 --- a/libs/langchain/langchain/llms/cohere.py +++ b/libs/langchain/langchain/llms/cohere.py @@ -80,6 +80,9 @@ class BaseCohere(Serializable): streaming: bool = Field(default=False) """Whether to stream the results.""" + user_agent: str = "langchain" + """Identifier for the application making the request.""" + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" @@ -94,8 +97,11 @@ class BaseCohere(Serializable): cohere_api_key = get_from_dict_or_env( values, "cohere_api_key", "COHERE_API_KEY" ) - values["client"] = cohere.Client(cohere_api_key) - values["async_client"] = cohere.AsyncClient(cohere_api_key) + client_name = values["user_agent"] + values["client"] = cohere.Client(cohere_api_key, client_name=client_name) + values["async_client"] = cohere.AsyncClient( + cohere_api_key, client_name=client_name + ) return values diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index 3db3f7c92f..d8790b3239 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -30,6 +30,8 @@ class CohereRerank(BaseDocumentCompressor): """Model to use for reranking.""" cohere_api_key: Optional[str] = None + user_agent: str = "langchain" + """Identifier for the application making the request.""" class Config: """Configuration for this pydantic object.""" @@ -40,18 +42,18 @@ class CohereRerank(BaseDocumentCompressor): @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - cohere_api_key = get_from_dict_or_env( - values, "cohere_api_key", "COHERE_API_KEY" - ) try: import cohere - - values["client"] = cohere.Client(cohere_api_key) except ImportError: raise ImportError( "Could not import cohere python package. " "Please install it with `pip install cohere`." ) + cohere_api_key = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY" + ) + client_name = values["user_agent"] + values["client"] = cohere.Client(cohere_api_key, client_name=client_name) return values def compress_documents(