feat(community): add tools support for litellm (#23906)
I used the following example to validate the behavior:

```python
from langchain_core.prompts import ChatPromptTemplate
from langchain_anthropic import ChatAnthropic
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.tools import tool
from langchain.agents import create_tool_calling_agent, AgentExecutor


@tool
def multiply(x: float, y: float) -> float:
    """Multiply 'x' times 'y'."""
    return x * y


@tool
def exponentiate(x: float, y: float) -> float:
    """Raise 'x' to the 'y'."""
    return x**y


@tool
def add(x: float, y: float) -> float:
    """Add 'x' and 'y'."""
    return x + y


prompt = ChatPromptTemplate.from_messages([
    ("system", "you're a helpful assistant"),
    ("human", "{input}"),
    ("placeholder", "{agent_scratchpad}"),
])

tools = [multiply, exponentiate, add]

llm = ChatAnthropic(model="claude-3-sonnet-20240229", temperature=0)
# llm = ChatLiteLLM(model="claude-3-sonnet-20240229", temperature=0)

agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

agent_executor.invoke({
    "input": "what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241",
})
```

The `ChatAnthropic` version works:

```
> Entering new AgentExecutor chain...

Invoking: `exponentiate` with `{'x': 5, 'y': 2.743}`
responded: [{'text': 'To calculate 3 + 5^2.743, we can use the "exponentiate" and "add" tools:', 'type': 'text', 'index': 0}, {'id': 'toolu_01Gf54DFTkfLMJQX3TXffmxe', 'input': {}, 'name': 'exponentiate', 'type': 'tool_use', 'index': 1, 'partial_json': '{"x": 5, "y": 2.743}'}]

82.65606421491815
Invoking: `add` with `{'x': 3, 'y': 82.65606421491815}`
responded: [{'id': 'toolu_01XUq9S56GT3Yv2N1KmNmmWp', 'input': {}, 'name': 'add', 'type': 'tool_use', 'index': 0, 'partial_json': '{"x": 3, "y": 82.65606421491815}'}]

85.65606421491815
Invoking: `add` with `{'x': 17.24, 'y': -918.1241}`
responded: [{'text': '\n\nSo 3 + 5^2.743 = 85.66\n\nTo calculate 17.24 - 918.1241, we can use:', 'type': 'text', 'index': 0}, {'id': 'toolu_01BkXTwP7ec9JKYtZPy5JKjm', 'input': {}, 'name': 'add', 'type': 'tool_use', 'index': 1, 'partial_json': '{"x": 17.24, "y": -918.1241}'}]

-900.8841[{'text': '\n\nTherefore, 17.24 - 918.1241 = -900.88', 'type': 'text', 'index': 0}]

> Finished chain.
```

The `ChatLiteLLM` version doesn't. But with the changes in this PR, along with

- https://github.com/langchain-ai/langchain/pull/23823
- https://github.com/BerriAI/litellm/pull/4554

the result is _almost_ the same:

```
> Entering new AgentExecutor chain...

Invoking: `exponentiate` with `{'x': 5, 'y': 2.743}`
responded: To calculate 3 + 5^2.743, we can use the "exponentiate" and "add" tools:

82.65606421491815
Invoking: `add` with `{'x': 3, 'y': 82.65606421491815}`

85.65606421491815
Invoking: `add` with `{'x': 17.24, 'y': -918.1241}`
responded: So 3 + 5^2.743 = 85.66

To calculate 17.24 - 918.1241, we can use:

-900.8841
Therefore, 17.24 - 918.1241 = -900.88

> Finished chain.
```

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

Co-authored-by: ccurme <chester.curme@gmail.com>
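For quick experimentation outside the agent loop, the same tool-calling path can be driven directly. This is a minimal sketch, not part of this PR's diff: it assumes `ChatLiteLLM` picks up the standard `bind_tools` interface once this PR and langchain-ai/langchain#23823 are both in place, and the model name is only illustrative.

```python
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.tools import tool


@tool
def add(x: float, y: float) -> float:
    """Add 'x' and 'y'."""
    return x + y


# Assumption: bind_tools is available on ChatLiteLLM once this PR and
# langchain-ai/langchain#23823 have both landed.
llm = ChatLiteLLM(model="claude-3-sonnet-20240229", temperature=0)
llm_with_tools = llm.bind_tools([add])

ai_msg = llm_with_tools.invoke("what's 3 plus 5?")
# Expected shape: [{'name': 'add', 'args': {'x': 3, 'y': 5}, 'id': ...}]
print(ai_msg.tool_calls)
```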
This commit is contained in: parent bfb7f8d40a, commit c2706cfb9e
```diff
@@ -40,6 +40,7 @@ jinja2>=3,<4
 jq>=1.4.1,<2
 jsonschema>1
 keybert>=0.8.5
+litellm>=1.30,<=1.39.5
 lxml>=4.9.3,<6.0
 markdownify>=0.11.6,<0.12
 motor>=3.3.1,<4
```
```diff
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import logging
 from typing import (
     Any,
```
```diff
@@ -42,6 +43,9 @@ from langchain_core.messages import (
     HumanMessageChunk,
     SystemMessage,
     SystemMessageChunk,
+    ToolCall,
+    ToolCallChunk,
+    ToolMessage,
 )
 from langchain_core.outputs import (
     ChatGeneration,
```
```diff
@@ -132,10 +136,30 @@ def _convert_delta_to_message_chunk(
     else:
         additional_kwargs = {}
 
+    tool_call_chunks = []
+    if raw_tool_calls := _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = raw_tool_calls
+        try:
+            tool_call_chunks = [
+                ToolCallChunk(
+                    name=rtc["function"].get("name"),
+                    args=rtc["function"].get("arguments"),
+                    id=rtc.get("id"),
+                    index=rtc["index"],
+                )
+                for rtc in raw_tool_calls
+            ]
+        except KeyError:
+            pass
+
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
-        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+        return AIMessageChunk(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            tool_call_chunks=tool_call_chunks,
+        )
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
     elif role == "function" or default_class == FunctionMessageChunk:
```
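To make the new streaming path concrete, here is a small standalone sketch of what the `tool_call_chunks` construction above does. The `delta` dict is a hand-written assumption of the OpenAI-style streaming payload litellm emits, not captured output:

```python
from langchain_core.messages import ToolCallChunk

# Assumed OpenAI-style streaming delta (illustrative values).
delta = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "id": "call_abc123",
            "index": 0,
            "function": {"name": "add", "arguments": '{"x": 3'},
        }
    ],
}

# Mirrors the conversion added above: the partial JSON in "arguments" is kept
# as a raw string so chunks can be accumulated across the stream.
tool_call_chunks = [
    ToolCallChunk(
        name=rtc["function"].get("name"),
        args=rtc["function"].get("arguments"),
        id=rtc.get("id"),
        index=rtc["index"],
    )
    for rtc in delta["tool_calls"]
]
print(tool_call_chunks)
```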
```diff
@@ -146,23 +170,41 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content)  # type: ignore[call-arg]
 
 
+def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
+    return {
+        "type": "function",
+        "id": tool_call["id"],
+        "function": {
+            "name": tool_call["name"],
+            "arguments": json.dumps(tool_call["args"]),
+        },
+    }
+
+
 def _convert_message_to_dict(message: BaseMessage) -> dict:
+    message_dict: Dict[str, Any] = {"content": message.content}
     if isinstance(message, ChatMessage):
-        message_dict = {"role": message.role, "content": message.content}
+        message_dict["role"] = message.role
     elif isinstance(message, HumanMessage):
-        message_dict = {"role": "user", "content": message.content}
+        message_dict["role"] = "user"
     elif isinstance(message, AIMessage):
-        message_dict = {"role": "assistant", "content": message.content}
+        message_dict["role"] = "assistant"
         if "function_call" in message.additional_kwargs:
             message_dict["function_call"] = message.additional_kwargs["function_call"]
+        if message.tool_calls:
+            message_dict["tool_calls"] = [
+                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+            ]
+        elif "tool_calls" in message.additional_kwargs:
+            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
     elif isinstance(message, SystemMessage):
-        message_dict = {"role": "system", "content": message.content}
+        message_dict["role"] = "system"
     elif isinstance(message, FunctionMessage):
-        message_dict = {
-            "role": "function",
-            "content": message.content,
-            "name": message.name,
-        }
+        message_dict["role"] = "function"
+        message_dict["name"] = message.name
+    elif isinstance(message, ToolMessage):
+        message_dict["role"] = "tool"
+        message_dict["tool_call_id"] = message.tool_call_id
     else:
         raise ValueError(f"Got unknown type {message}")
     if "name" in message.additional_kwargs:
```
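The effect of `_lc_tool_call_to_openai_tool_call` and the new `ToolMessage` branch can be checked in isolation. A sketch with illustrative values; the dict literal reproduces the helper's output shape rather than calling the private function:

```python
import json

from langchain_core.messages import AIMessage, ToolMessage

ai_msg = AIMessage(
    content="",
    tool_calls=[{"name": "add", "args": {"x": 3.0, "y": 5.0}, "id": "call_abc123"}],
)

# Same shape the helper produces: args are JSON-encoded to a string, matching
# the OpenAI-style payload litellm forwards to providers.
wire_tool_calls = [
    {
        "type": "function",
        "id": tc["id"],
        "function": {"name": tc["name"], "arguments": json.dumps(tc["args"])},
    }
    for tc in ai_msg.tool_calls
]
print(wire_tool_calls)

# The tool's result travels back as a ToolMessage keyed by the same id, which
# the new branch serializes as {"role": "tool", "tool_call_id": ...}.
tool_msg = ToolMessage(content="8.0", tool_call_id="call_abc123")
```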
```diff
@@ -360,6 +402,8 @@ class ChatLiteLLM(BaseChatModel):
         for chunk in self.completion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         ):
+            if not isinstance(chunk, dict):
+                chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
```
```diff
@@ -384,6 +428,8 @@ class ChatLiteLLM(BaseChatModel):
         async for chunk in await acompletion_with_retry(
             self, messages=message_dicts, run_manager=run_manager, **params
         ):
+            if not isinstance(chunk, dict):
+                chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
```
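Both streaming methods gain the same guard: litellm may yield Pydantic model objects rather than plain dicts, and `model_dump()` normalizes them before indexing. A sketch of the idea, where `FakeChunk` is a stand-in and not litellm's actual response class:

```python
from pydantic import BaseModel


class FakeChunk(BaseModel):
    """Stand-in for litellm's Pydantic streaming chunk type (assumption)."""

    choices: list


chunk = FakeChunk(choices=[{"delta": {"content": "hi"}}])

# Same normalization as in _stream/_astream: accept dicts or Pydantic
# models, then index uniformly.
if not isinstance(chunk, dict):
    chunk = chunk.model_dump()
assert chunk["choices"][0]["delta"]["content"] == "hi"
```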
```diff
@@ -0,0 +1,23 @@
+"""Standard LangChain interface tests"""
+
+from typing import Type
+
+import pytest
+from langchain_core.language_models import BaseChatModel
+from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
+
+from langchain_community.chat_models.litellm import ChatLiteLLM
+
+
+class TestLiteLLMStandard(ChatModelIntegrationTests):
+    @property
+    def chat_model_class(self) -> Type[BaseChatModel]:
+        return ChatLiteLLM
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"model": "ollama/mistral"}
+
+    @pytest.mark.xfail(reason="Not yet implemented.")
+    def test_usage_metadata(self, model: BaseChatModel) -> None:
+        super().test_usage_metadata(model)
```
libs/community/tests/unit_tests/chat_models/test_litellm.py (new file, 24 lines)
```diff
@@ -0,0 +1,24 @@
+"""Standard LangChain interface tests"""
+
+from typing import Type
+
+import pytest
+from langchain_core.language_models import BaseChatModel
+from langchain_standard_tests.unit_tests import ChatModelUnitTests
+
+from langchain_community.chat_models.litellm import ChatLiteLLM
+
+
+@pytest.mark.requires("litellm")
+class TestLiteLLMStandard(ChatModelUnitTests):
+    @property
+    def chat_model_class(self) -> Type[BaseChatModel]:
+        return ChatLiteLLM
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"api_key": "test_api_key"}
+
+    @pytest.mark.xfail(reason="Not yet implemented.")
+    def test_standard_params(self, model: BaseChatModel) -> None:
+        super().test_standard_params(model)
```