diff --git a/docs/extras/integrations/chat/litellm.ipynb b/docs/extras/integrations/chat/litellm.ipynb new file mode 100644 index 0000000000..977f5f1554 --- /dev/null +++ b/docs/extras/integrations/chat/litellm.ipynb @@ -0,0 +1,185 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "bf733a38-db84-4363-89e2-de6735c37230", + "metadata": {}, + "source": [ + "# 🚅 LiteLLM\n", + "\n", + "[LiteLLM](https://github.com/BerriAI/litellm) is a library that simplifies calling Anthropic, Azure, Huggingface, Replicate, etc. \n", + "\n", + "This notebook covers how to get started with using Langchain + the LiteLLM I/O library. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatLiteLLM\n", + "from langchain.prompts.chat import (\n", + " ChatPromptTemplate,\n", + " SystemMessagePromptTemplate,\n", + " AIMessagePromptTemplate,\n", + " HumanMessagePromptTemplate,\n", + ")\n", + "from langchain.schema import AIMessage, HumanMessage, SystemMessage" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "chat = ChatLiteLLM(model=\"gpt-3.5-turbo\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = [\n", + " HumanMessage(\n", + " content=\"Translate this sentence from English to French. I love programming.\"\n", + " )\n", + "]\n", + "chat(messages)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c361ab1e-8c0c-4206-9e3c-9d1424a12b9c", + "metadata": {}, + "source": [ + "## `ChatLiteLLM` also supports async and streaming functionality:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.callbacks.manager import CallbackManager\n", + "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "LLMResult(generations=[[ChatGeneration(text=\" J'aime programmer.\", generation_info=None, message=AIMessage(content=\" J'aime programmer.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('8cc8fb68-1c35-439c-96a0-695036a93652'))])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "await chat.agenerate([messages])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "025be980-e50d-4a68-93dc-c9c7b500ce34", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " J'aime la programmation." 
+ ] + }, + { + "data": { + "text/plain": [ + "AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat = ChatLiteLLM(\n", + " streaming=True,\n", + " verbose=True,\n", + " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n", + ")\n", + "chat(messages)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c253883f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index f3b1cbe161..a115ce816f 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -24,6 +24,7 @@ from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.google_palm import ChatGooglePalm from langchain.chat_models.human import HumanInputChatModel from langchain.chat_models.jinachat import JinaChat +from langchain.chat_models.litellm import ChatLiteLLM from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI @@ -41,4 +42,5 @@ __all__ = [ "JinaChat", "HumanInputChatModel", "ChatAnyscale", + "ChatLiteLLM", ] diff --git a/libs/langchain/langchain/chat_models/litellm.py b/libs/langchain/langchain/chat_models/litellm.py new file mode 100644 index 0000000000..165e1d1a6e --- /dev/null +++ b/libs/langchain/langchain/chat_models/litellm.py @@ -0,0 +1,449 @@ +"""Wrapper around LiteLLM's model I/O library.""" +from __future__ import annotations + +import logging +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from pydantic import BaseModel, Field, root_validator + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.llms.base import create_base_retry_decorator +from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, +) +from langchain.schema.output import ChatGenerationChunk +from langchain.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +class ChatLiteLLMException(Exception): + """Error raised when there is an issue with the LiteLLM I/O Library""" + + +def _truncate_at_stop_tokens( + text: str, + stop: Optional[List[str]], +) -> str: + """Truncates text at the earliest stop token found.""" + if stop is None: + return text + + for stop_token in stop: + stop_token_idx = text.find(stop_token) + if stop_token_idx != -1: + text = text[:stop_token_idx] + return text + + +class FunctionMessage(BaseMessage): + """A Message for passing the 
result of executing a function back to a model.""" + + name: str + """The name of the function that was executed.""" + + @property + def type(self) -> str: + """Type of the message, used for serialization.""" + return "function" + + +class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): + pass + + +def _create_retry_decorator( + llm: ChatLiteLLM, + run_manager: Optional[ + Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, +) -> Callable[[Any], Any]: + """Returns a tenacity retry decorator, preconfigured to handle OpenAI exceptions.""" + import openai + + errors = [ + openai.error.Timeout, + openai.error.APIError, + openai.error.APIConnectionError, + openai.error.RateLimitError, + openai.error.ServiceUnavailableError, + ] + return create_base_retry_decorator( + error_types=errors, max_retries=llm.max_retries, run_manager=run_manager + ) + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + # Fix for azure + # Also OpenAI returns None for tool invocations + content = _dict.get("content", "") or "" + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict["content"]) + elif role == "function": + return FunctionMessage(content=_dict["content"], name=_dict["name"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +async def acompletion_with_retry( + llm: ChatLiteLLM, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the async completion call.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + async def _completion_with_retry(**kwargs: Any) -> Any: + # llm.client is the litellm module; use its async completion API + return await llm.client.acreate(**kwargs) + + return await _completion_with_retry(**kwargs) + + +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] +) -> BaseMessageChunk: + role = _dict.get("role") + content = _dict.get("content") or "" + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + elif role == "function" or default_class == FunctionMessageChunk: + return FunctionMessageChunk(content=content, name=_dict["name"]) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) + else: + return default_class(content=content) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in 
message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +class ChatLiteLLM(BaseChatModel, BaseModel): + """Wrapper around the LiteLLM Model I/O library. + + To use, you must have the ``litellm`` Python package installed and the + API key for the provider you want to call configured, either: + + 1. As an environment variable (for example ``OPENAI_API_KEY`` or + ``ANTHROPIC_API_KEY``), or + 2. Passed as a keyword argument (for example ``openai_api_key``) to the + ``ChatLiteLLM`` constructor. + + Example: + .. code-block:: python + + from langchain.chat_models import ChatLiteLLM + chat = ChatLiteLLM(model="gpt-3.5-turbo") + + """ + + client: Any #: :meta private: + model_name: str = "gpt-3.5-turbo" + """Model name to use.""" + openai_api_key: Optional[str] = None + azure_api_key: Optional[str] = None + anthropic_api_key: Optional[str] = None + replicate_api_key: Optional[str] = None + cohere_api_key: Optional[str] = None + openrouter_api_key: Optional[str] = None + streaming: bool = False + api_base: Optional[str] = None + organization: Optional[str] = None + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + temperature: Optional[float] = None + """Run inference with this temperature. Must be in the closed + interval [0.0, 1.0].""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for the API call not explicitly specified.""" + top_p: Optional[float] = None + """Decode using nucleus sampling: consider the smallest set of tokens whose + probability sum is at least top_p. Must be in the closed interval [0.0, 1.0].""" + top_k: Optional[int] = None + """Decode using top-k sampling: consider the set of top_k most probable tokens. + Must be positive.""" + n: int = 1 + """Number of chat completions to generate for each prompt. 
Note that the API may + not return the full n completions if duplicates are generated.""" + max_tokens: int = 256 + + max_retries: int = 6 + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling the LiteLLM API.""" + return { + "model": self.model_name, + "force_timeout": self.request_timeout, + "max_tokens": self.max_tokens, + "stream": self.streaming, + "n": self.n, + "temperature": self.temperature, + **self.model_kwargs, + } + + @property + def _client_params(self) -> Dict[str, Any]: + """Get the parameters used for the litellm client.""" + self.client.api_base = self.api_base + self.client.organization = self.organization + creds: Dict[str, Any] = { + "model": self.model_name, + "force_timeout": self.request_timeout, + } + return {**self._default_params, **creds} + + def completion_with_retry( + self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any + ) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return self.client.completion(**kwargs) + + return _completion_with_retry(**kwargs) + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate api key, python package exists, temperature, top_p, and top_k.""" + try: + import litellm + except ImportError: + raise ChatLiteLLMException( + "Could not import litellm python package. " + "Please install it with `pip install litellm`" + ) + + values["openai_api_key"] = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY", default="" + ) + values["azure_api_key"] = get_from_dict_or_env( + values, "azure_api_key", "AZURE_API_KEY", default="" + ) + values["anthropic_api_key"] = get_from_dict_or_env( + values, "anthropic_api_key", "ANTHROPIC_API_KEY", default="" + ) + values["replicate_api_key"] = get_from_dict_or_env( + values, "replicate_api_key", "REPLICATE_API_KEY", default="" + ) + values["openrouter_api_key"] = get_from_dict_or_env( + values, "openrouter_api_key", "OPENROUTER_API_KEY", default="" + ) + values["client"] = litellm + + if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: + raise ValueError("temperature must be in the range [0.0, 1.0]") + + if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: + raise ValueError("top_p must be in the range [0.0, 1.0]") + + if values["top_k"] is not None and values["top_k"] <= 0: + raise ValueError("top_k must be positive") + + return values + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> ChatResult: + if stream if stream is not None else self.streaming: + generation: Optional[ChatGenerationChunk] = None + for chunk in self._stream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation]) + + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + response = self.completion_with_retry( + messages=message_dicts, run_manager=run_manager, **params + ) + return self._create_chat_result(response) + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + for res in 
response["choices"]: + message = _convert_dict_to_message(res["message"]) + gen = ChatGeneration( + message=message, + generation_info=dict(finish_reason=res.get("finish_reason")), + ) + generations.append(gen) + token_usage = response.get("usage", {}) + llm_output = {"token_usage": token_usage, "model_name": self.model_name} + return ChatResult(generations=generations, llm_output=llm_output) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = self._client_params + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + message_dicts = [_convert_message_to_dict(m) for m in messages] + return message_dicts, params + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + + default_chunk_class = AIMessageChunk + for chunk in self.completion_with_retry( + messages=message_dicts, run_manager=run_manager, **params + ): + if len(chunk["choices"]) == 0: + continue + delta = chunk["choices"][0]["delta"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + default_chunk_class = chunk.__class__ + yield ChatGenerationChunk(message=chunk) + if run_manager: + run_manager.on_llm_new_token(chunk.content) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + + default_chunk_class = AIMessageChunk + async for chunk in await acompletion_with_retry( + self, messages=message_dicts, run_manager=run_manager, **params + ): + if len(chunk["choices"]) == 0: + continue + delta = chunk["choices"][0]["delta"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + default_chunk_class = chunk.__class__ + yield ChatGenerationChunk(message=chunk) + if run_manager: + await run_manager.on_llm_new_token(chunk.content) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> ChatResult: + if stream if stream is not None else self.streaming: + generation: Optional[ChatGenerationChunk] = None + async for chunk in self._astream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation]) + + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + response = await acompletion_with_retry( + self, messages=message_dicts, run_manager=run_manager, **params + ) + return self._create_chat_result(response) + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + return { + "model_name": self.model_name, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "n": self.n, + } + + @property + def _llm_type(self) -> str: + return 
"litellm-chat" diff --git a/libs/langchain/tests/integration_tests/chat_models/test_litellm.py b/libs/langchain/tests/integration_tests/chat_models/test_litellm.py new file mode 100644 index 0000000000..4f252453da --- /dev/null +++ b/libs/langchain/tests/integration_tests/chat_models/test_litellm.py @@ -0,0 +1,64 @@ +"""Test Anthropic API wrapper.""" +from typing import List + +from langchain.callbacks.manager import ( + CallbackManager, +) +from langchain.chat_models.litellm import ChatLiteLLM +from langchain.schema import ( + ChatGeneration, + LLMResult, +) +from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + + +def test_litellm_call() -> None: + """Test valid call to litellm.""" + chat = ChatLiteLLM( + model="test", + ) + message = HumanMessage(content="Hello") + response = chat([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_litellm_generate() -> None: + """Test generate method of anthropic.""" + chat = ChatLiteLLM(model="test") + chat_messages: List[List[BaseMessage]] = [ + [HumanMessage(content="How many toes do dogs have?")] + ] + messages_copy = [messages.copy() for messages in chat_messages] + result: LLMResult = chat.generate(chat_messages) + assert isinstance(result, LLMResult) + for response in result.generations[0]: + assert isinstance(response, ChatGeneration) + assert isinstance(response.text, str) + assert response.text == response.message.content + assert chat_messages == messages_copy + + +def test_litellm_streaming() -> None: + """Test streaming tokens from anthropic.""" + chat = ChatLiteLLM(model="test", streaming=True) + message = HumanMessage(content="Hello") + response = chat([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_litellm_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chat = ChatLiteLLM( + model="test", + streaming=True, + callback_manager=callback_manager, + verbose=True, + ) + message = HumanMessage(content="Write me a sentence with 10 words.") + chat([message]) + assert callback_handler.llm_streams > 1