mirror of https://github.com/hwchase17/langchain
Add cohere /chat integration (#11389)
Add cohere /chat integration and an iPython notebook to demonstrate the addition. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>pull/11443/head
parent
ca346011b7
commit
2ff91a46c0
@ -0,0 +1,174 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "bf733a38-db84-4363-89e2-de6735c37230",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Cohere\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook covers how to get started with Cohere chat models."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 54,
|
||||||
|
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chat_models import ChatCohere\n",
|
||||||
|
"from langchain.schema import AIMessage, HumanMessage"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 55,
|
||||||
|
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chat = ChatCohere()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 56,
|
||||||
|
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"AIMessage(content=\"Who's there?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 56,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"messages = [\n",
|
||||||
|
" HumanMessage(\n",
|
||||||
|
" content=\"knock knock\"\n",
|
||||||
|
" )\n",
|
||||||
|
"]\n",
|
||||||
|
"chat(messages)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## `ChatCohere` also supports async and streaming functionality:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 57,
|
||||||
|
"id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.callbacks.manager import CallbackManager\n",
|
||||||
|
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 64,
|
||||||
|
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Who's there?"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"LLMResult(generations=[[ChatGenerationChunk(text=\"Who's there?\", message=AIMessageChunk(content=\"Who's there?\"))]], llm_output={}, run=[RunInfo(run_id=UUID('1e9eaefc-9c99-4fa9-8297-ef9975d4751e'))])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 64,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"await chat.agenerate([messages])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 63,
|
||||||
|
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Who's there?"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"AIMessageChunk(content=\"Who's there?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 63,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chat = ChatCohere(\n",
|
||||||
|
" streaming=True,\n",
|
||||||
|
" verbose=True,\n",
|
||||||
|
" callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
|
||||||
|
")\n",
|
||||||
|
"chat(messages)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -0,0 +1,162 @@
|
|||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.chat_models.base import (
|
||||||
|
BaseChatModel,
|
||||||
|
_agenerate_from_stream,
|
||||||
|
_generate_from_stream,
|
||||||
|
)
|
||||||
|
from langchain.llms.cohere import BaseCohere
|
||||||
|
from langchain.schema.messages import (
|
||||||
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
|
BaseMessage,
|
||||||
|
ChatMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
|
|
||||||
|
|
||||||
|
def get_role(message: BaseMessage) -> str:
|
||||||
|
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
|
||||||
|
return "User"
|
||||||
|
elif isinstance(message, AIMessage):
|
||||||
|
return "Chatbot"
|
||||||
|
elif isinstance(message, SystemMessage):
|
||||||
|
return "System"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCohere(BaseChatModel, BaseCohere):
|
||||||
|
"""`Cohere` chat large language models.
|
||||||
|
|
||||||
|
To use, you should have the ``cohere`` python package installed, and the
|
||||||
|
environment variable ``COHERE_API_KEY`` set with your API key, or pass
|
||||||
|
it as a named parameter to the constructor.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.chat_models import ChatCohere
|
||||||
|
from langchain.schema import HumanMessage
|
||||||
|
|
||||||
|
chat = ChatCohere(model="foo")
|
||||||
|
result = chat([HumanMessage(content="Hello")])
|
||||||
|
print(result.content)
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of chat model."""
|
||||||
|
return "cohere-chat"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the default parameters for calling Cohere API."""
|
||||||
|
return {
|
||||||
|
"temperature": self.temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {**{"model": self.model}, **self._default_params}
|
||||||
|
|
||||||
|
def get_cohere_chat_request(
|
||||||
|
self, messages: List[BaseMessage], **kwargs: Any
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"message": messages[0].content,
|
||||||
|
"chat_history": [
|
||||||
|
{"role": get_role(x), "message": x.content} for x in messages[1:]
|
||||||
|
],
|
||||||
|
**self._default_params,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
|
request = self.get_cohere_chat_request(messages, **kwargs)
|
||||||
|
stream = self.client.chat(**request, stream=True)
|
||||||
|
|
||||||
|
for data in stream:
|
||||||
|
if data.event_type == "text-generation":
|
||||||
|
delta = data.text
|
||||||
|
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(delta)
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
|
request = self.get_cohere_chat_request(messages, **kwargs)
|
||||||
|
stream = await self.async_client.chat(**request, stream=True)
|
||||||
|
|
||||||
|
async for data in stream:
|
||||||
|
if data.event_type == "text-generation":
|
||||||
|
delta = data.text
|
||||||
|
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||||
|
if run_manager:
|
||||||
|
await run_manager.on_llm_new_token(delta)
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
if self.streaming:
|
||||||
|
stream_iter = self._stream(
|
||||||
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
|
)
|
||||||
|
return _generate_from_stream(stream_iter)
|
||||||
|
|
||||||
|
request = self.get_cohere_chat_request(messages, **kwargs)
|
||||||
|
response = self.client.chat(**request)
|
||||||
|
|
||||||
|
message = AIMessage(content=response.text)
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
if self.streaming:
|
||||||
|
stream_iter = self._astream(
|
||||||
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
|
)
|
||||||
|
return await _agenerate_from_stream(stream_iter)
|
||||||
|
|
||||||
|
request = self.get_cohere_chat_request(messages, **kwargs)
|
||||||
|
response = self.client.chat(**request, stream=False)
|
||||||
|
|
||||||
|
message = AIMessage(content=response.text)
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
"""Calculate number of tokens."""
|
||||||
|
return len(self.client.tokenize(text).tokens)
|
Loading…
Reference in New Issue