From bd02b83acd3951c507541f5ac67a54793101535f Mon Sep 17 00:00:00 2001 From: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com> Date: Fri, 29 Mar 2024 21:22:30 +0000 Subject: [PATCH] cohere[patch]: Allow overriding of the base URL in Cohere Client (#19766) This PR adds the ability for a user to override the base API url for the Cohere client for embeddings and chat llm. --- libs/partners/cohere/langchain_cohere/embeddings.py | 5 +++++ libs/partners/cohere/langchain_cohere/llms.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/libs/partners/cohere/langchain_cohere/embeddings.py b/libs/partners/cohere/langchain_cohere/embeddings.py index 2bf0933286..bd5193d317 100644 --- a/libs/partners/cohere/langchain_cohere/embeddings.py +++ b/libs/partners/cohere/langchain_cohere/embeddings.py @@ -45,6 +45,9 @@ class CohereEmbeddings(BaseModel, Embeddings): user_agent: str = "langchain" """Identifier for the application making the request.""" + base_url: Optional[str] = None + """Override the default Cohere API URL.""" + class Config: """Configuration for this pydantic object.""" @@ -64,11 +67,13 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key, timeout=request_timeout, client_name=client_name, + base_url=values["base_url"], ) values["async_client"] = cohere.AsyncClient( cohere_api_key, timeout=request_timeout, client_name=client_name, + base_url=values["base_url"], ) return values diff --git a/libs/partners/cohere/langchain_cohere/llms.py b/libs/partners/cohere/langchain_cohere/llms.py index 4cf30e42a3..a95a5c09b5 100644 --- a/libs/partners/cohere/langchain_cohere/llms.py +++ b/libs/partners/cohere/langchain_cohere/llms.py @@ -69,6 +69,9 @@ class BaseCohere(Serializable): user_agent: str = "langchain" """Identifier for the application making the request.""" + base_url: Optional[str] = None + """Override the default Cohere API URL.""" + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" @@ -79,10 +82,12 @@ class BaseCohere(Serializable): values["client"] = cohere.Client( api_key=values["cohere_api_key"].get_secret_value(), client_name=client_name, + base_url=values["base_url"], ) values["async_client"] = cohere.AsyncClient( api_key=values["cohere_api_key"].get_secret_value(), client_name=client_name, + base_url=values["base_url"], ) return values