Add client_name="langchain" to Cohere usage (#11328)

Hey, we're looking to invest more in adding cohere integrations to
langchain so would love to get more of an idea for how it's used.
Hopefully this pr is acceptable. This week I'm also going to be looking
into adding our new [retrieval augmented generation
product](https://txt.cohere.com/chat-with-rag/) to langchain.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/12580/head
billytrend-cohere 10 months ago committed by GitHub
parent 37aec1e050
commit b1e3843931
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -37,6 +37,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
request_timeout: Optional[float] = None request_timeout: Optional[float] = None
"""Timeout in seconds for the Cohere API request.""" """Timeout in seconds for the Cohere API request."""
user_agent: str = "langchain"
"""Identifier for the application making the request."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -55,11 +57,18 @@ class CohereEmbeddings(BaseModel, Embeddings):
try: try:
import cohere import cohere
client_name = values["user_agent"]
values["client"] = cohere.Client( values["client"] = cohere.Client(
cohere_api_key, max_retries=max_retries, timeout=request_timeout cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
) )
values["async_client"] = cohere.AsyncClient( values["async_client"] = cohere.AsyncClient(
cohere_api_key, max_retries=max_retries, timeout=request_timeout cohere_api_key,
max_retries=max_retries,
timeout=request_timeout,
client_name=client_name,
) )
except ImportError: except ImportError:
raise ValueError( raise ValueError(

@ -80,6 +80,9 @@ class BaseCohere(Serializable):
streaming: bool = Field(default=False) streaming: bool = Field(default=False)
"""Whether to stream the results.""" """Whether to stream the results."""
user_agent: str = "langchain"
"""Identifier for the application making the request."""
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
@ -94,8 +97,11 @@ class BaseCohere(Serializable):
cohere_api_key = get_from_dict_or_env( cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY" values, "cohere_api_key", "COHERE_API_KEY"
) )
values["client"] = cohere.Client(cohere_api_key) client_name = values["user_agent"]
values["async_client"] = cohere.AsyncClient(cohere_api_key) values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
values["async_client"] = cohere.AsyncClient(
cohere_api_key, client_name=client_name
)
return values return values

@ -30,6 +30,8 @@ class CohereRerank(BaseDocumentCompressor):
"""Model to use for reranking.""" """Model to use for reranking."""
cohere_api_key: Optional[str] = None cohere_api_key: Optional[str] = None
user_agent: str = "langchain"
"""Identifier for the application making the request."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -40,18 +42,18 @@ class CohereRerank(BaseDocumentCompressor):
@root_validator(pre=True) @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
try: try:
import cohere import cohere
values["client"] = cohere.Client(cohere_api_key)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import cohere python package. " "Could not import cohere python package. "
"Please install it with `pip install cohere`." "Please install it with `pip install cohere`."
) )
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
client_name = values["user_agent"]
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
return values return values
def compress_documents( def compress_documents(

Loading…
Cancel
Save