diff --git a/libs/partners/cohere/langchain_cohere/embeddings.py b/libs/partners/cohere/langchain_cohere/embeddings.py index 2bf0933286..bd5193d317 100644 --- a/libs/partners/cohere/langchain_cohere/embeddings.py +++ b/libs/partners/cohere/langchain_cohere/embeddings.py @@ -45,6 +45,9 @@ class CohereEmbeddings(BaseModel, Embeddings): user_agent: str = "langchain" """Identifier for the application making the request.""" + base_url: Optional[str] = None + """Override the default Cohere API URL.""" + class Config: """Configuration for this pydantic object.""" @@ -64,11 +67,13 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key, timeout=request_timeout, client_name=client_name, + base_url=values["base_url"], ) values["async_client"] = cohere.AsyncClient( cohere_api_key, timeout=request_timeout, client_name=client_name, + base_url=values["base_url"], ) return values diff --git a/libs/partners/cohere/langchain_cohere/llms.py b/libs/partners/cohere/langchain_cohere/llms.py index 4cf30e42a3..a95a5c09b5 100644 --- a/libs/partners/cohere/langchain_cohere/llms.py +++ b/libs/partners/cohere/langchain_cohere/llms.py @@ -69,6 +69,9 @@ class BaseCohere(Serializable): user_agent: str = "langchain" """Identifier for the application making the request.""" + base_url: Optional[str] = None + """Override the default Cohere API URL.""" + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" @@ -79,10 +82,12 @@ class BaseCohere(Serializable): values["client"] = cohere.Client( api_key=values["cohere_api_key"].get_secret_value(), client_name=client_name, + base_url=values["base_url"], ) values["async_client"] = cohere.AsyncClient( api_key=values["cohere_api_key"].get_secret_value(), client_name=client_name, + base_url=values["base_url"], ) return values