Add client_name="langchain" to Cohere usage (#11328)

Hey, we're looking to invest more in adding cohere integrations to
langchain so would love to get more of an idea for how it's used.
Hopefully this pr is acceptable. This week I'm also going to be looking
into adding our new [retrieval augmented generation
product](https://txt.cohere.com/chat-with-rag/) to langchain.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/12580/head
billytrend-cohere 9 months ago committed by GitHub
parent 37aec1e050
commit b1e3843931
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -37,6 +37,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
"""Maximum number of retries to make when generating."""
request_timeout: Optional[float] = None
"""Timeout in seconds for the Cohere API request."""
user_agent: str = "langchain"
"""Identifier for the application making the request."""
class Config:
"""Configuration for this pydantic object."""
@ -55,11 +57,18 @@ class CohereEmbeddings(BaseModel, Embeddings):
try:
import cohere
client_name = values["user_agent"]
values["client"] = cohere.Client(
cohere_api_key, max_retries=max_retries, timeout=request_timeout
cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
)
values["async_client"] = cohere.AsyncClient(
cohere_api_key, max_retries=max_retries, timeout=request_timeout
cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
)
except ImportError:
raise ValueError(

@ -80,6 +80,9 @@ class BaseCohere(Serializable):
streaming: bool = Field(default=False)
"""Whether to stream the results."""
user_agent: str = "langchain"
"""Identifier for the application making the request."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@ -94,8 +97,11 @@ class BaseCohere(Serializable):
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
values["client"] = cohere.Client(cohere_api_key)
values["async_client"] = cohere.AsyncClient(cohere_api_key)
client_name = values["user_agent"]
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
values["async_client"] = cohere.AsyncClient(
cohere_api_key, client_name=client_name
)
return values

@ -30,6 +30,8 @@ class CohereRerank(BaseDocumentCompressor):
"""Model to use for reranking."""
cohere_api_key: Optional[str] = None
user_agent: str = "langchain"
"""Identifier for the application making the request."""
class Config:
"""Configuration for this pydantic object."""
@ -40,18 +42,18 @@ class CohereRerank(BaseDocumentCompressor):
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
try:
import cohere
values["client"] = cohere.Client(cohere_api_key)
except ImportError:
raise ImportError(
"Could not import cohere python package. "
"Please install it with `pip install cohere`."
)
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
client_name = values["user_agent"]
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
return values
def compress_documents(

Loading…
Cancel
Save