diff --git a/docs/extras/integrations/chat/cohere.ipynb b/docs/extras/integrations/chat/cohere.ipynb
new file mode 100644
index 0000000000..e48da19d87
--- /dev/null
+++ b/docs/extras/integrations/chat/cohere.ipynb
@@ -0,0 +1,174 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "bf733a38-db84-4363-89e2-de6735c37230",
+   "metadata": {},
+   "source": [
+    "# Cohere\n",
+    "\n",
+    "This notebook covers how to get started with Cohere chat models."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from langchain.chat_models import ChatCohere\n",
+    "from langchain.schema import AIMessage, HumanMessage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 55,
+   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "chat = ChatCohere()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 56,
+   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content=\"Who's there?\")"
+      ]
+     },
+     "execution_count": 56,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "messages = [\n",
+    "    HumanMessage(\n",
+    "        content=\"knock knock\"\n",
+    "    )\n",
+    "]\n",
+    "chat(messages)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c",
+   "metadata": {},
+   "source": [
+    "## `ChatCohere` also supports async and streaming functionality:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from langchain.callbacks.manager import CallbackManager\n",
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Who's there?"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "LLMResult(generations=[[ChatGenerationChunk(text=\"Who's there?\", message=AIMessageChunk(content=\"Who's there?\"))]], llm_output={}, run=[RunInfo(run_id=UUID('1e9eaefc-9c99-4fa9-8297-ef9975d4751e'))])"
+      ]
+     },
+     "execution_count": 64,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "await chat.agenerate([messages])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Who's there?"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "AIMessageChunk(content=\"Who's there?\")"
+      ]
+     },
+     "execution_count": 63,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chat = ChatCohere(\n",
+    "    streaming=True,\n",
+    "    verbose=True,\n",
+    "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
+    ")\n",
+    "chat(messages)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py
index b3980b162d..92fb6728e0 100644
--- a/libs/langchain/langchain/chat_models/__init__.py
+++ b/libs/langchain/langchain/chat_models/__init__.py
@@ -22,6 +22,7 @@ from langchain.chat_models.anyscale import ChatAnyscale
 from langchain.chat_models.azure_openai import AzureChatOpenAI
 from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
 from langchain.chat_models.bedrock import BedrockChat
+from langchain.chat_models.cohere import ChatCohere
 from langchain.chat_models.ernie import ErnieBotChat
 from langchain.chat_models.fake import FakeListChatModel
 from langchain.chat_models.fireworks import ChatFireworks
@@ -45,6 +46,7 @@ __all__ = [
     "FakeListChatModel",
     "PromptLayerChatOpenAI",
     "ChatAnthropic",
+    "ChatCohere",
     "ChatGooglePalm",
     "ChatMLflowAIGateway",
     "ChatOllama",
diff --git a/libs/langchain/langchain/chat_models/cohere.py b/libs/langchain/langchain/chat_models/cohere.py
new file mode 100644
index 0000000000..2ae754e644
--- /dev/null
+++ b/libs/langchain/langchain/chat_models/cohere.py
@@ -0,0 +1,162 @@
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
+from langchain.llms.cohere import BaseCohere
+from langchain.schema.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
+
+
+def get_role(message: BaseMessage) -> str:
+    if isinstance(message, (ChatMessage, HumanMessage)):
+        return "User"
+    elif isinstance(message, AIMessage):
+        return "Chatbot"
+    elif isinstance(message, SystemMessage):
+        return "System"
+    else:
+        raise ValueError(f"Got unknown type {message}")
+
+
+class ChatCohere(BaseChatModel, BaseCohere):
+    """`Cohere` chat large language models.
+
+    To use, you should have the ``cohere`` python package installed, and the
+    environment variable ``COHERE_API_KEY`` set with your API key, or pass
+    it as a named parameter to the constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import ChatCohere
+            from langchain.schema import HumanMessage
+
+            chat = ChatCohere(model="foo")
+            result = chat([HumanMessage(content="Hello")])
+            print(result.content)
+    """
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+        arbitrary_types_allowed = True
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "cohere-chat"
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling Cohere API."""
+        return {
+            "temperature": self.temperature,
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model": self.model}, **self._default_params}
+
+    def get_cohere_chat_request(
+        self, messages: List[BaseMessage], **kwargs: Any
+    ) -> Dict[str, Any]:
+        # The last message is the current turn; everything before it is history.
+        return {
+            "message": messages[-1].content,
+            "chat_history": [
+                {"role": get_role(x), "message": x.content} for x in messages[:-1]
+            ],
+            **self._default_params,
+            **kwargs,
+        }
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        request = self.get_cohere_chat_request(messages, **kwargs)
+        stream = self.client.chat(**request, stream=True)
+
+        for data in stream:
+            if data.event_type == "text-generation":
+                delta = data.text
+                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                if run_manager:
+                    run_manager.on_llm_new_token(delta)
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        request = self.get_cohere_chat_request(messages, **kwargs)
+        stream = await self.async_client.chat(**request, stream=True)
+
+        async for data in stream:
+            if data.event_type == "text-generation":
+                delta = data.text
+                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                if run_manager:
+                    await run_manager.on_llm_new_token(delta)
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
+
+        request = self.get_cohere_chat_request(messages, **kwargs)
+        response = self.client.chat(**request)
+
+        message = AIMessage(content=response.text)
+        return ChatResult(generations=[ChatGeneration(message=message)])
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await _agenerate_from_stream(stream_iter)
+
+        request = self.get_cohere_chat_request(messages, **kwargs)
+        response = self.client.chat(**request, stream=False)
+
+        message = AIMessage(content=response.text)
+        return ChatResult(generations=[ChatGeneration(message=message)])
+
+    def get_num_tokens(self, text: str) -> int:
+        """Calculate number of tokens."""
+        return len(self.client.tokenize(text).tokens)
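Note on the request mapping above: Cohere's chat endpoint takes the current turn as `message` and all earlier turns as `chat_history`, which is why the payload is built from `messages[-1]` and `messages[:-1]` (the original draft used `messages[0]`/`messages[1:]`, which reverses that order for multi-turn input). A minimal sketch of what `get_cohere_chat_request` produces, assuming the `cohere` package is installed and `COHERE_API_KEY` is set; the conversation content is made up for illustration:

    from langchain.chat_models import ChatCohere
    from langchain.schema import AIMessage, HumanMessage

    chat = ChatCohere()

    # Two earlier turns plus the new user message.
    messages = [
        HumanMessage(content="knock knock"),
        AIMessage(content="Who's there?"),
        HumanMessage(content="Tank"),
    ]

    request = chat.get_cohere_chat_request(messages)
    print(request["message"])       # "Tank" -- the current turn
    print(request["chat_history"])  # prior turns, tagged "User" / "Chatbot"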
diff --git a/libs/langchain/langchain/llms/cohere.py b/libs/langchain/langchain/llms/cohere.py
index 257a69c9ff..454692c0e7 100644
--- a/libs/langchain/langchain/llms/cohere.py
+++ b/libs/langchain/langchain/llms/cohere.py
@@ -17,7 +17,8 @@ from langchain.callbacks.manager import (
 )
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.load.serializable import Serializable
+from langchain.pydantic_v1 import Extra, Field, root_validator
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
@@ -61,7 +62,42 @@ def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
     return _completion_with_retry(**kwargs)
 
 
-class Cohere(LLM):
+class BaseCohere(Serializable):
+    client: Any  #: :meta private:
+    async_client: Any  #: :meta private:
+    model: Optional[str] = Field(default=None)
+    """Model name to use."""
+
+    temperature: float = 0.75
+    """A non-negative float that tunes the degree of randomness in generation."""
+
+    cohere_api_key: Optional[str] = None
+
+    stop: Optional[List[str]] = None
+
+    streaming: bool = Field(default=False)
+    """Whether to stream the results."""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "Could not import cohere python package. "
+                "Please install it with `pip install cohere`."
+            )
+        else:
+            cohere_api_key = get_from_dict_or_env(
+                values, "cohere_api_key", "COHERE_API_KEY"
+            )
+            values["client"] = cohere.Client(cohere_api_key)
+            values["async_client"] = cohere.AsyncClient(cohere_api_key)
+        return values
+
+
+class Cohere(LLM, BaseCohere):
     """Cohere large language models.
 
     To use, you should have the ``cohere`` python package installed, and the
@@ -72,20 +108,13 @@ class Cohere(LLM):
         .. code-block:: python
 
             from langchain.llms import Cohere
+
             cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key")
     """
 
-    client: Any  #: :meta private:
-    async_client: Any  #: :meta private:
-    model: Optional[str] = None
-    """Model name to use."""
-
     max_tokens: int = 256
     """Denotes the number of tokens to predict per generation."""
 
-    temperature: float = 0.75
-    """A non-negative float that tunes the degree of randomness in generation."""
-
     k: int = 0
     """Number of most likely tokens to consider at each step."""
 
@@ -105,33 +134,11 @@ class Cohere(LLM):
     max_retries: int = 10
     """Maximum number of retries to make when generating."""
 
-    cohere_api_key: Optional[str] = None
-
-    stop: Optional[List[str]] = None
-
     class Config:
         """Configuration for this pydantic object."""
 
         extra = Extra.forbid
 
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment."""
-        cohere_api_key = get_from_dict_or_env(
-            values, "cohere_api_key", "COHERE_API_KEY"
-        )
-        try:
-            import cohere
-
-            values["client"] = cohere.Client(cohere_api_key)
-            values["async_client"] = cohere.AsyncClient(cohere_api_key)
-        except ImportError:
-            raise ImportError(
-                "Could not import cohere python package. "
-                "Please install it with `pip install cohere`."
-            )
-        return values
-
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Cohere API."""
@@ -145,6 +152,10 @@ class Cohere(LLM):
             "truncate": self.truncate,
         }
 
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"cohere_api_key": "COHERE_API_KEY"}
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
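Taken together, the patch registers `ChatCohere` in `langchain.chat_models`, layers it on the shared `BaseCohere` client fields, and routes streamed tokens through the callback machinery. A minimal end-to-end sketch of how the pieces fit, assuming the `cohere` package is installed and `COHERE_API_KEY` is set (the knock-knock content mirrors the notebook above):

    import asyncio

    from langchain.callbacks.manager import CallbackManager
    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.chat_models import ChatCohere
    from langchain.schema import HumanMessage

    messages = [HumanMessage(content="knock knock")]

    # Plain synchronous call: _generate sends one request and wraps the
    # reply text in an AIMessage.
    chat = ChatCohere()
    print(chat(messages).content)

    # Streaming call: _stream yields a chunk per "text-generation" event,
    # and the stdout handler prints each token as it arrives.
    streaming_chat = ChatCohere(
        streaming=True,
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    )
    streaming_chat(messages)

    # Async call: _agenerate mirrors _generate on cohere.AsyncClient.
    print(asyncio.run(chat.agenerate([messages])))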