From b1e38439310fd14369c2c55ef83c764f2b2f02d1 Mon Sep 17 00:00:00 2001 From: billytrend-cohere <144115527+billytrend-cohere@users.noreply.github.com> Date: Mon, 30 Oct 2023 18:20:55 +0000 Subject: [PATCH] Add client_name="langchain" to Cohere usage (#11328) Hey, we're looking to invest more in adding cohere integrations to langchain so would love to get more of an idea for how it's used. Hopefully this pr is acceptable. This week I'm also going to be looking into adding our new [retrieval augmented generation product](https://txt.cohere.com/chat-with-rag/) to langchain. --------- Co-authored-by: Bagatur --- libs/langchain/langchain/embeddings/cohere.py | 13 +++++++++++-- libs/langchain/langchain/llms/cohere.py | 10 ++++++++-- .../document_compressors/cohere_rerank.py | 12 +++++++----- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/libs/langchain/langchain/embeddings/cohere.py b/libs/langchain/langchain/embeddings/cohere.py index 1a6b3cdca0..e64019dd20 100644 --- a/libs/langchain/langchain/embeddings/cohere.py +++ b/libs/langchain/langchain/embeddings/cohere.py @@ -37,6 +37,8 @@ class CohereEmbeddings(BaseModel, Embeddings): """Maximum number of retries to make when generating.""" request_timeout: Optional[float] = None """Timeout in seconds for the Cohere API request.""" + user_agent: str = "langchain" + """Identifier for the application making the request.""" class Config: """Configuration for this pydantic object.""" @@ -55,11 +57,18 @@ class CohereEmbeddings(BaseModel, Embeddings): try: import cohere + client_name = values["user_agent"] values["client"] = cohere.Client( - cohere_api_key, max_retries=max_retries, timeout=request_timeout + cohere_api_key, + max_retries=max_retries, + timeout=request_timeout, + client_name=client_name, ) values["async_client"] = cohere.AsyncClient( - cohere_api_key, max_retries=max_retries, timeout=request_timeout + cohere_api_key, + max_retries=max_retries, + timeout=request_timeout, + client_name=client_name, ) except ImportError: raise ValueError( diff --git a/libs/langchain/langchain/llms/cohere.py b/libs/langchain/langchain/llms/cohere.py index c902b4c047..b24e10adb6 100644 --- a/libs/langchain/langchain/llms/cohere.py +++ b/libs/langchain/langchain/llms/cohere.py @@ -80,6 +80,9 @@ class BaseCohere(Serializable): streaming: bool = Field(default=False) """Whether to stream the results.""" + user_agent: str = "langchain" + """Identifier for the application making the request.""" + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" @@ -94,8 +97,11 @@ class BaseCohere(Serializable): cohere_api_key = get_from_dict_or_env( values, "cohere_api_key", "COHERE_API_KEY" ) - values["client"] = cohere.Client(cohere_api_key) - values["async_client"] = cohere.AsyncClient(cohere_api_key) + client_name = values["user_agent"] + values["client"] = cohere.Client(cohere_api_key, client_name=client_name) + values["async_client"] = cohere.AsyncClient( + cohere_api_key, client_name=client_name + ) return values diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index 3db3f7c92f..d8790b3239 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -30,6 +30,8 @@ class CohereRerank(BaseDocumentCompressor): """Model to use for reranking.""" cohere_api_key: Optional[str] = None + user_agent: str = "langchain" + """Identifier for the application making the request.""" class Config: """Configuration for this pydantic object.""" @@ -40,18 +42,18 @@ class CohereRerank(BaseDocumentCompressor): @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - cohere_api_key = get_from_dict_or_env( - values, "cohere_api_key", "COHERE_API_KEY" - ) try: import cohere - - values["client"] = cohere.Client(cohere_api_key) except ImportError: raise ImportError( "Could not import cohere python package. " "Please install it with `pip install cohere`." ) + cohere_api_key = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY" + ) + client_name = values["user_agent"] + values["client"] = cohere.Client(cohere_api_key, client_name=client_name) return values def compress_documents(