mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
cohere[patch]: Allow overriding of the base URL in Cohere Client (#19766)
This PR adds the ability for a user to override the base API url for the Cohere client for embeddings and chat llm.
This commit is contained in:
parent
1252ccce6f
commit
bd02b83acd
@ -45,6 +45,9 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
user_agent: str = "langchain"
|
||||
"""Identifier for the application making the request."""
|
||||
|
||||
base_url: Optional[str] = None
|
||||
"""Override the default Cohere API URL."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@ -64,11 +67,13 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
cohere_api_key,
|
||||
timeout=request_timeout,
|
||||
client_name=client_name,
|
||||
base_url=values["base_url"],
|
||||
)
|
||||
values["async_client"] = cohere.AsyncClient(
|
||||
cohere_api_key,
|
||||
timeout=request_timeout,
|
||||
client_name=client_name,
|
||||
base_url=values["base_url"],
|
||||
)
|
||||
|
||||
return values
|
||||
|
@ -69,6 +69,9 @@ class BaseCohere(Serializable):
|
||||
user_agent: str = "langchain"
|
||||
"""Identifier for the application making the request."""
|
||||
|
||||
base_url: Optional[str] = None
|
||||
"""Override the default Cohere API URL."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
@ -79,10 +82,12 @@ class BaseCohere(Serializable):
|
||||
values["client"] = cohere.Client(
|
||||
api_key=values["cohere_api_key"].get_secret_value(),
|
||||
client_name=client_name,
|
||||
base_url=values["base_url"],
|
||||
)
|
||||
values["async_client"] = cohere.AsyncClient(
|
||||
api_key=values["cohere_api_key"].get_secret_value(),
|
||||
client_name=client_name,
|
||||
base_url=values["base_url"],
|
||||
)
|
||||
return values
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user