Add cohere /chat integration (#11389)

Add a Cohere /chat integration and a Jupyter (IPython) notebook that
demonstrates it.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/11443/head
billytrend-cohere 10 months ago committed by GitHub
parent ca346011b7
commit 2ff91a46c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "bf733a38-db84-4363-89e2-de6735c37230",
"metadata": {},
"source": [
"# Cohere\n",
"\n",
"This notebook covers how to get started with Cohere chat models."
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chat_models import ChatCohere\n",
"from langchain.schema import AIMessage, HumanMessage"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat = ChatCohere()"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"Who's there?\")"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" HumanMessage(\n",
" content=\"knock knock\"\n",
" )\n",
"]\n",
"chat(messages)"
]
},
{
"cell_type": "markdown",
"id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
"metadata": {},
"source": [
"## `ChatCohere` also supports async and streaming functionality:"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.callbacks.manager import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Who's there?"
]
},
{
"data": {
"text/plain": [
"LLMResult(generations=[[ChatGenerationChunk(text=\"Who's there?\", message=AIMessageChunk(content=\"Who's there?\"))]], llm_output={}, run=[RunInfo(run_id=UUID('1e9eaefc-9c99-4fa9-8297-ef9975d4751e'))])"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await chat.agenerate([messages])"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Who's there?"
]
},
{
"data": {
"text/plain": [
"AIMessageChunk(content=\"Who's there?\")"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat = ChatCohere(\n",
" streaming=True,\n",
" verbose=True,\n",
" callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
")\n",
"chat(messages)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -22,6 +22,7 @@ from langchain.chat_models.anyscale import ChatAnyscale
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.cohere import ChatCohere
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.fireworks import ChatFireworks
@ -45,6 +46,7 @@ __all__ = [
"FakeListChatModel",
"PromptLayerChatOpenAI",
"ChatAnthropic",
"ChatCohere",
"ChatGooglePalm",
"ChatMLflowAIGateway",
"ChatOllama",

@ -0,0 +1,162 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import (
BaseChatModel,
_agenerate_from_stream,
_generate_from_stream,
)
from langchain.llms.cohere import BaseCohere
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
def get_role(message: BaseMessage) -> str:
    """Return the Cohere role label for a LangChain message.

    Args:
        message: The message whose role should be determined.

    Returns:
        ``"User"`` for ``ChatMessage`` and ``HumanMessage`` instances,
        ``"Chatbot"`` for ``AIMessage`` instances and ``"System"`` for
        ``SystemMessage`` instances.

    Raises:
        ValueError: If the message is none of the types listed above.
    """
    # Checks run in the same order as before: user-like types first.
    if isinstance(message, (ChatMessage, HumanMessage)):
        return "User"
    if isinstance(message, AIMessage):
        return "Chatbot"
    if isinstance(message, SystemMessage):
        return "System"
    raise ValueError(f"Got unknown type {message}")
class ChatCohere(BaseChatModel, BaseCohere):
    """`Cohere` chat large language models.

    To use, you should have the ``cohere`` python package installed, and the
    environment variable ``COHERE_API_KEY`` set with your API key, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatCohere
            from langchain.schema import HumanMessage

            chat = ChatCohere(model="foo")
            result = chat([HumanMessage(content="Hello")])
            print(result.content)
    """

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "cohere-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Cohere API."""
        # NOTE(review): only ``temperature`` is forwarded; ``self.model`` is
        # reported in ``_identifying_params`` but never sent with the request.
        # Confirm whether that is intentional.
        return {
            "temperature": self.temperature,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    def get_cohere_chat_request(
        self, messages: List[BaseMessage], **kwargs: Any
    ) -> Dict[str, Any]:
        """Build the keyword arguments for a Cohere ``chat`` call.

        The last message is sent as the current ``message``; every message
        before it becomes an entry of ``chat_history`` in its original order.

        Args:
            messages: The conversation so far, oldest first. Must not be empty.
            **kwargs: Extra request parameters; these override the defaults
                from ``_default_params`` on key collisions.

        Returns:
            A dict with ``message``, ``chat_history``, the default parameters
            and any extra ``kwargs``.

        Raises:
            ValueError: If ``messages`` is empty, or if a history message has
                a type ``get_role`` does not recognise.
        """
        if not messages:
            raise ValueError("Cohere chat requires at least one message.")

        # The newest turn is the prompt; everything before it is context.
        *history, latest = messages
        return {
            "message": latest.content,
            "chat_history": [
                {"role": get_role(m), "message": m.content} for m in history
            ],
            **self._default_params,
            **kwargs,
        }

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the reply, yielding one chunk per ``text-generation`` event.

        Args:
            messages: The conversation so far, oldest first.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Optional callback manager notified of each new token.
            **kwargs: Extra request parameters forwarded to the client.

        Yields:
            A ``ChatGenerationChunk`` for each piece of generated text.
        """
        request = self.get_cohere_chat_request(messages, **kwargs)
        stream = self.client.chat(**request, stream=True)

        for data in stream:
            # Other event types carry no text and are skipped.
            if data.event_type == "text-generation":
                delta = data.text
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronous counterpart of ``_stream`` using ``async_client``.

        Args:
            messages: The conversation so far, oldest first.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Optional callback manager notified of each new token.
            **kwargs: Extra request parameters forwarded to the client.

        Yields:
            A ``ChatGenerationChunk`` for each piece of generated text.
        """
        request = self.get_cohere_chat_request(messages, **kwargs)
        stream = await self.async_client.chat(**request, stream=True)

        async for data in stream:
            # Other event types carry no text and are skipped.
            if data.event_type == "text-generation":
                delta = data.text
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    await run_manager.on_llm_new_token(delta)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a single chat reply.

        When ``self.streaming`` is set, the reply is assembled from
        ``_stream``; otherwise one non-streaming request is made.

        Args:
            messages: The conversation so far, oldest first.
            stop: Passed through to ``_stream`` when streaming; otherwise
                not used here.
            run_manager: Optional callback manager, used when streaming.
            **kwargs: Extra request parameters forwarded to the client.

        Returns:
            A ``ChatResult`` holding one ``ChatGeneration``.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return _generate_from_stream(stream_iter)

        request = self.get_cohere_chat_request(messages, **kwargs)
        response = self.client.chat(**request)

        message = AIMessage(content=response.text)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a single chat reply.

        When ``self.streaming`` is set, the reply is assembled from
        ``_astream``; otherwise one non-streaming request is awaited on
        ``async_client`` so the event loop is not blocked.

        Args:
            messages: The conversation so far, oldest first.
            stop: Passed through to ``_astream`` when streaming; otherwise
                not used here.
            run_manager: Optional callback manager, used when streaming.
            **kwargs: Extra request parameters forwarded to the client.

        Returns:
            A ``ChatResult`` holding one ``ChatGeneration``.
        """
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await _agenerate_from_stream(stream_iter)

        request = self.get_cohere_chat_request(messages, **kwargs)
        # Use the async client here, matching ``_astream``; the sync client
        # would block the running event loop for the whole request.
        response = await self.async_client.chat(**request, stream=False)

        message = AIMessage(content=response.text)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens."""
        return len(self.client.tokenize(text).tokens)

@ -17,7 +17,8 @@ from langchain.callbacks.manager import (
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, root_validator
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
@ -61,7 +62,42 @@ def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
return _completion_with_retry(**kwargs)
class Cohere(LLM):
class BaseCohere(Serializable):
    """Fields and client setup shared by the Cohere model wrappers."""

    client: Any  #: :meta private:
    async_client: Any  #: :meta private:

    model: Optional[str] = Field(default=None)
    """Model name to use."""

    temperature: float = 0.75
    """A non-negative float that tunes the degree of randomness in generation."""

    cohere_api_key: Optional[str] = None

    stop: Optional[List[str]] = None

    streaming: bool = Field(default=False)
    """Whether to stream the results."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that the ``cohere`` package is importable and build clients.

        The API key is read from ``values["cohere_api_key"]`` or the
        ``COHERE_API_KEY`` environment variable via ``get_from_dict_or_env``.
        Both a sync and an async client are stored back into ``values``.

        Raises:
            ImportError: If the ``cohere`` package cannot be imported.
        """
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            )

        # Reached only when the import succeeded.
        api_key = get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY")
        values["client"] = cohere.Client(api_key)
        values["async_client"] = cohere.AsyncClient(api_key)
        return values
class Cohere(LLM, BaseCohere):
"""Cohere large language models.
To use, you should have the ``cohere`` python package installed, and the
@ -72,20 +108,13 @@ class Cohere(LLM):
.. code-block:: python
from langchain.llms import Cohere
cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key")
"""
client: Any #: :meta private:
async_client: Any #: :meta private:
model: Optional[str] = None
"""Model name to use."""
max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""
temperature: float = 0.75
"""A non-negative float that tunes the degree of randomness in generation."""
k: int = 0
"""Number of most likely tokens to consider at each step."""
@ -105,33 +134,11 @@ class Cohere(LLM):
max_retries: int = 10
"""Maximum number of retries to make when generating."""
cohere_api_key: Optional[str] = None
stop: Optional[List[str]] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
try:
import cohere
values["client"] = cohere.Client(cohere_api_key)
values["async_client"] = cohere.AsyncClient(cohere_api_key)
except ImportError:
raise ImportError(
"Could not import cohere python package. "
"Please install it with `pip install cohere`."
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API."""
@ -145,6 +152,10 @@ class Cohere(LLM):
"truncate": self.truncate,
}
@property
def lc_secrets(self) -> Dict[str, str]:
return {"cohere_api_key": "COHERE_API_KEY"}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""

Loading…
Cancel
Save