community[patch]: support bind_tools for ChatMlflow (#24547)
Thank you for contributing to LangChain!

- [x] **PR title**: "package: description"
  - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"

- **Description:** Support ChatMlflow.bind_tools method

  Tested in Databricks:
  <img width="836" alt="image" src="https://github.com/user-attachments/assets/fa28ef50-0110-4698-8eda-4faf6f0b9ef8">

- [x] **Add tests and docs**: If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access,
  2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory.

- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
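For reviewers skimming the diff, a minimal sketch of the surface this PR adds, condensed from the notebook cell further down; the endpoint name is the Databricks example used throughout the diff, and the weather tool schema is illustrative:

```python
from langchain_community.chat_models.mlflow import ChatMlflow

llm = ChatMlflow(
    endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
)

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
            },
        },
    }
]

# tool_choice accepts "auto", "none", "required", a bare tool name, or an
# OpenAI-style {"type": "function", "function": {"name": ...}} dict.
model = llm.bind_tools(tools, tool_choice="auto")
msg = model.invoke("What is the current temperature of Chicago?")
print(msg.tool_calls)  # parsed tool calls, when the model decides to call one
```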
This commit is contained in:
parent
769c3bb838
commit
1827bb4042
@@ -36,7 +36,7 @@
 "### Model features\n",
 "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
 "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
-"| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
+"| ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
 "\n",
 "### Supported Methods\n",
 "\n",
@@ -395,6 +395,66 @@
 "chat_model_external.invoke(\"How to use Databricks?\")"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Function calling on Databricks"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Databricks Function Calling is OpenAI-compatible and is only available during model serving as part of Foundation Model APIs.\n",
+"\n",
+"See [Databricks function calling introduction](https://docs.databricks.com/en/machine-learning/model-serving/function-calling.html#supported-models) for supported models."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from langchain_community.chat_models.databricks import ChatDatabricks\n",
+"\n",
+"llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")\n",
+"tools = [\n",
+"    {\n",
+"        \"type\": \"function\",\n",
+"        \"function\": {\n",
+"            \"name\": \"get_current_weather\",\n",
+"            \"description\": \"Get the current weather in a given location\",\n",
+"            \"parameters\": {\n",
+"                \"type\": \"object\",\n",
+"                \"properties\": {\n",
+"                    \"location\": {\n",
+"                        \"type\": \"string\",\n",
+"                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
+"                    },\n",
+"                    \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
+"                },\n",
+"            },\n",
+"        },\n",
+"    }\n",
+"]\n",
+"\n",
+"# supported tool_choice values: \"auto\", \"required\", \"none\", function name in string format,\n",
+"# or a dictionary as {\"type\": \"function\", \"function\": {\"name\": <<tool_name>>}}\n",
+"model = llm.bind_tools(tools, tool_choice=\"auto\")\n",
+"\n",
+"messages = [{\"role\": \"user\", \"content\": \"What is the current temperature of Chicago?\"}]\n",
+"print(model.invoke(messages))"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"See [Databricks Unity Catalog](/docs/integrations/tools/databricks.ipynb) for how to use UC functions in chains."
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -38,7 +38,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade --quiet databricks-sdk langchain-community langchain-openai"
+"%pip install --upgrade --quiet databricks-sdk langchain-community mlflow"
 ]
 },
 {
@@ -47,9 +47,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from langchain_openai import ChatOpenAI\n",
+"from langchain_community.chat_models.databricks import ChatDatabricks\n",
 "\n",
-"llm = ChatOpenAI(model=\"gpt-3.5-turbo\")"
+"llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")"
 ]
 },
 {
@@ -91,3 +91,4 @@ vdms>=0.0.20
 xata>=1.0.0a7,<2
 xmltodict>=0.13.0,<0.14
 nanopq==0.2.1
+mlflow[genai]>=2.14.0
@@ -1,5 +1,19 @@
+import json
 import logging
-from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+    cast,
+)
 from urllib.parse import urlparse

 from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -15,15 +29,27 @@ from langchain_core.messages import (
     FunctionMessage,
     HumanMessage,
     HumanMessageChunk,
+    InvalidToolCall,
     SystemMessage,
     SystemMessageChunk,
+    ToolCall,
+    ToolMessage,
+    ToolMessageChunk,
 )
+from langchain_core.messages.tool import tool_call_chunk
+from langchain_core.output_parsers.openai_tools import (
+    make_invalid_tool_call,
+    parse_tool_call,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import (
     BaseModel,
     Field,
     PrivateAttr,
 )
-from langchain_core.runnables import RunnableConfig
+from langchain_core.runnables import Runnable, RunnableConfig
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool

 logger = logging.getLogger(__name__)
@@ -228,11 +254,32 @@ class ChatMlflow(BaseChatModel):
     @staticmethod
     def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
         role = _dict["role"]
-        content = _dict["content"]
+        content = cast(str, _dict.get("content"))
         if role == "user":
             return HumanMessage(content=content)
         elif role == "assistant":
-            return AIMessage(content=content)
+            content = content or ""
+            additional_kwargs: Dict = {}
+            tool_calls = []
+            invalid_tool_calls = []
+            if raw_tool_calls := _dict.get("tool_calls"):
+                additional_kwargs["tool_calls"] = raw_tool_calls
+                for raw_tool_call in raw_tool_calls:
+                    try:
+                        tool_calls.append(
+                            parse_tool_call(raw_tool_call, return_id=True)
+                        )
+                    except Exception as e:
+                        invalid_tool_calls.append(
+                            make_invalid_tool_call(raw_tool_call, str(e))
+                        )
+            return AIMessage(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                id=_dict.get("id"),
+                tool_calls=tool_calls,
+                invalid_tool_calls=invalid_tool_calls,
+            )
         elif role == "system":
             return SystemMessage(content=content)
         else:
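In effect, the hunk above lets an OpenAI-style assistant dict that carries `tool_calls` (and a null `content`) round-trip into a structured `AIMessage`. A small sketch of the intended behavior, with illustrative values:

```python
from langchain_community.chat_models.mlflow import ChatMlflow

raw = {
    "role": "assistant",
    "content": None,  # OpenAI-compatible servers send null content with tool calls
    "tool_calls": [
        {
            "id": "call_123",
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": '{"location": "Chicago"}',
            },
        }
    ],
}

msg = ChatMlflow._convert_dict_to_message(raw)
assert msg.content == ""  # null content is normalized to an empty string
assert msg.tool_calls[0]["name"] == "get_current_weather"
assert msg.tool_calls[0]["args"] == {"location": "Chicago"}  # arguments JSON-decoded
```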
@@ -243,13 +290,38 @@ class ChatMlflow(BaseChatModel):
         _dict: Mapping[str, Any], default_role: str
     ) -> BaseMessageChunk:
         role = _dict.get("role", default_role)
-        content = _dict["content"]
+        content = _dict.get("content") or ""
         if role == "user":
             return HumanMessageChunk(content=content)
         elif role == "assistant":
-            return AIMessageChunk(content=content)
+            additional_kwargs: Dict = {}
+            tool_call_chunks = []
+            if raw_tool_calls := _dict.get("tool_calls"):
+                additional_kwargs["tool_calls"] = raw_tool_calls
+                try:
+                    tool_call_chunks = [
+                        tool_call_chunk(
+                            name=rtc["function"].get("name"),
+                            args=rtc["function"].get("arguments"),
+                            id=rtc.get("id"),
+                            index=rtc["index"],
+                        )
+                        for rtc in raw_tool_calls
+                    ]
+                except KeyError:
+                    pass
+            return AIMessageChunk(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                id=_dict.get("id"),
+                tool_call_chunks=tool_call_chunks,
+            )
         elif role == "system":
             return SystemMessageChunk(content=content)
+        elif role == "tool":
+            return ToolMessageChunk(
+                content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
+            )
         else:
             return ChatMessageChunk(content=content, role=role)

@@ -262,14 +334,47 @@ class ChatMlflow(BaseChatModel):

     @staticmethod
     def _convert_message_to_dict(message: BaseMessage) -> dict:
+        message_dict = {"content": message.content}
+        if (name := message.name or message.additional_kwargs.get("name")) is not None:
+            message_dict["name"] = name
         if isinstance(message, ChatMessage):
-            message_dict = {"role": message.role, "content": message.content}
+            message_dict["role"] = message.role
         elif isinstance(message, HumanMessage):
-            message_dict = {"role": "user", "content": message.content}
+            message_dict["role"] = "user"
        elif isinstance(message, AIMessage):
-            message_dict = {"role": "assistant", "content": message.content}
+            message_dict["role"] = "assistant"
+            if message.tool_calls or message.invalid_tool_calls:
+                message_dict["tool_calls"] = [
+                    _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+                ] + [
+                    _lc_invalid_tool_call_to_openai_tool_call(tc)
+                    for tc in message.invalid_tool_calls
+                ]  # type: ignore[assignment]
+            elif "tool_calls" in message.additional_kwargs:
+                message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
+                tool_call_supported_props = {"id", "type", "function"}
+                message_dict["tool_calls"] = [
+                    {
+                        k: v
+                        for k, v in tool_call.items()  # type: ignore[union-attr]
+                        if k in tool_call_supported_props
+                    }
+                    for tool_call in message_dict["tool_calls"]
+                ]
+            else:
+                pass
+            # If tool calls present, content null value should be None not empty string.
+            if "tool_calls" in message_dict:
+                message_dict["content"] = message_dict["content"] or None  # type: ignore[assignment]
         elif isinstance(message, SystemMessage):
-            message_dict = {"role": "system", "content": message.content}
+            message_dict["role"] = "system"
+        elif isinstance(message, ToolMessage):
+            message_dict["role"] = "tool"
+            message_dict["tool_call_id"] = message.tool_call_id
+            supported_props = {"content", "role", "tool_call_id"}
+            message_dict = {
+                k: v for k, v in message_dict.items() if k in supported_props
+            }
         elif isinstance(message, FunctionMessage):
             raise ValueError(
                 "Function messages are not supported by Databricks. Please"
@@ -280,12 +385,6 @@ class ChatMlflow(BaseChatModel):

         if "function_call" in message.additional_kwargs:
             ChatMlflow._raise_functions_not_supported()
-        if message.additional_kwargs:
-            logger.warning(
-                "Additional message arguments are unsupported by Databricks"
-                " and will be ignored: %s",
-                message.additional_kwargs,
-            )
         return message_dict

     @staticmethod
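Conversely, in the outbound direction above, `AIMessage` tool calls are re-serialized into the OpenAI wire format (with `content` forced to `None`), and `ToolMessage` results are stripped down to the three properties the endpoint accepts. A sketch, with illustrative values:

```python
from langchain_community.chat_models.mlflow import ChatMlflow
from langchain_core.messages import AIMessage, ToolMessage

ai = AIMessage(
    content="",
    tool_calls=[
        {"name": "get_current_weather", "args": {}, "id": "call_123", "type": "tool_call"}
    ],
)
assert ChatMlflow._convert_message_to_dict(ai) == {
    "role": "assistant",
    "content": None,  # forced to None, not "", when tool calls are present
    "tool_calls": [
        {
            "type": "function",
            "id": "call_123",
            "function": {"name": "get_current_weather", "arguments": "{}"},
        }
    ],
}

tool_result = ToolMessage(content="72F", tool_call_id="call_123")
assert ChatMlflow._convert_message_to_dict(tool_result) == {
    "role": "tool",
    "content": "72F",
    "tool_call_id": "call_123",
}
```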
@@ -302,3 +401,89 @@ class ChatMlflow(BaseChatModel):

         usage = response.get("usage", {})
         return ChatResult(generations=generations, llm_output=usage)
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        *,
+        tool_choice: Optional[
+            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+        ] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model is compatible with OpenAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            tool_choice: Which tool to require the model to call.
+                Options are:
+                name of the tool (str): calls corresponding tool;
+                "auto": automatically selects a tool (including no tool);
+                "none": model does not generate any tool calls and instead must
+                    generate a standard assistant message;
+                "required": the model picks the most relevant tool in tools and
+                    must generate a tool call;
+
+                or a dict of the form:
+                    {"type": "function", "function": {"name": <<tool_name>>}}.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        if tool_choice:
+            if isinstance(tool_choice, str):
+                # tool_choice is a tool/function name
+                if tool_choice not in ("auto", "none", "required"):
+                    tool_choice = {
+                        "type": "function",
+                        "function": {"name": tool_choice},
+                    }
+            elif isinstance(tool_choice, dict):
+                tool_names = [
+                    formatted_tool["function"]["name"]
+                    for formatted_tool in formatted_tools
+                ]
+                if not any(
+                    tool_name == tool_choice["function"]["name"]
+                    for tool_name in tool_names
+                ):
+                    raise ValueError(
+                        f"Tool choice {tool_choice} was specified, but the only "
+                        f"provided tools were {tool_names}."
+                    )
+            else:
+                raise ValueError(
+                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
+                    f"Received: {tool_choice}"
+                )
+            kwargs["tool_choice"] = tool_choice
+        return super().bind(tools=formatted_tools, **kwargs)
+
+
+def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
+    return {
+        "type": "function",
+        "id": tool_call["id"],
+        "function": {
+            "name": tool_call["name"],
+            "arguments": json.dumps(tool_call["args"]),
+        },
+    }
+
+
+def _lc_invalid_tool_call_to_openai_tool_call(
+    invalid_tool_call: InvalidToolCall,
+) -> dict:
+    return {
+        "type": "function",
+        "id": invalid_tool_call["id"],
+        "function": {
+            "name": invalid_tool_call["name"],
+            "arguments": invalid_tool_call["args"],
+        },
+    }
libs/community/tests/unit_tests/chat_models/test_mlflow.py (new file, 423 lines)
@@ -0,0 +1,423 @@
import json
from typing import Any, Dict, List
from unittest.mock import MagicMock

import pytest
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolCallChunk,
    ToolMessageChunk,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, BaseModel
from langchain_core.tools import StructuredTool

from langchain_community.chat_models.mlflow import ChatMlflow


@pytest.fixture
def llm() -> ChatMlflow:
    return ChatMlflow(
        endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
    )


@pytest.fixture
def model_input() -> List[BaseMessage]:
    data = [
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {"role": "user", "content": "36939 * 8922.4"},
    ]
    return [ChatMlflow._convert_dict_to_message(value) for value in data]


@pytest.fixture
def mock_prediction() -> dict:
    return {
        "id": "chatcmpl_id",
        "object": "chat.completion",
        "created": 1721875529,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "To calculate the result of 36939 multiplied by 8922.4, "
                    "I get:\n\n36939 x 8922.4 = 329,511,111.6",
                },
                "finish_reason": "stop",
                "logprobs": None,
            }
        ],
        "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
    }

@pytest.fixture
def mock_predict_stream_result() -> List[dict]:
    return [
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "36939"},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "x"},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "8922.4"},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": " = "},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": "329,511,111.6"},
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60},
        },
        {
            "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
            "object": "chat.completion.chunk",
            "created": 1721877054,
            "model": "meta-llama-3.1-70b-instruct-072424",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": ""},
                    "finish_reason": "stop",
                    "logprobs": None,
                }
            ],
            "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
        },
    ]

@pytest.mark.requires("mlflow")
def test_chat_mlflow_predict(
    llm: ChatMlflow, model_input: List[BaseMessage], mock_prediction: dict
) -> None:
    mock_client = MagicMock()
    llm._client = mock_client

    def mock_predict(*args: Any, **kwargs: Any) -> Any:
        return mock_prediction

    mock_client.predict = mock_predict
    res = llm.invoke(model_input)
    assert res.content == mock_prediction["choices"][0]["message"]["content"]


@pytest.mark.requires("mlflow")
def test_chat_mlflow_stream(
    llm: ChatMlflow,
    model_input: List[BaseMessage],
    mock_predict_stream_result: List[dict],
) -> None:
    mock_client = MagicMock()
    llm._client = mock_client

    def mock_stream(*args: Any, **kwargs: Any) -> Any:
        yield from mock_predict_stream_result

    mock_client.predict_stream = mock_stream
    for i, res in enumerate(llm.stream(model_input)):
        assert (
            res.content
            == mock_predict_stream_result[i]["choices"][0]["delta"]["content"]
        )

@pytest.mark.requires("mlflow")
@pytest.mark.skipif(
    _PYDANTIC_MAJOR_VERSION < 2,
    reason="The tool mock is not compatible with pydantic 1.x",
)
def test_chat_mlflow_bind_tools(
    llm: ChatMlflow, mock_predict_stream_result: List[dict]
) -> None:
    mock_client = MagicMock()
    llm._client = mock_client

    def mock_stream(*args: Any, **kwargs: Any) -> Any:
        yield from mock_predict_stream_result

    mock_client.predict_stream = mock_stream

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful assistant. Make sure to use tool for information.",
            ),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )

    def mock_func(*args: Any, **kwargs: Any) -> str:
        return "36939 x 8922.4 = 329,511,111.6"

    tools = [
        StructuredTool(
            name="name",
            description="description",
            args_schema=BaseModel,
            func=mock_func,
        )
    ]
    agent = create_tool_calling_agent(llm, tools, prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)  # type: ignore[arg-type]
    result = agent_executor.invoke({"input": "36939 * 8922.4"})
    assert result["output"] == "36939x8922.4 = 329,511,111.6"

def test_convert_dict_to_message_human() -> None:
    message = {"role": "user", "content": "foo"}
    result = ChatMlflow._convert_dict_to_message(message)
    expected_output = HumanMessage(content="foo")
    assert result == expected_output


def test_convert_dict_to_message_ai() -> None:
    message = {"role": "assistant", "content": "foo"}
    result = ChatMlflow._convert_dict_to_message(message)
    expected_output = AIMessage(content="foo")
    assert result == expected_output

    tool_calls = [
        {
            "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
            "type": "function",
            "function": {
                "name": "main__test__python_exec",
                "arguments": '{"code": "result = 36939 * 8922.4" }',
            },
        }
    ]
    message_with_tools: Dict[str, Any] = {
        "role": "assistant",
        "content": None,
        "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
        "tool_calls": tool_calls,
    }
    result = ChatMlflow._convert_dict_to_message(message_with_tools)
    expected_output = AIMessage(
        content="",
        additional_kwargs={"tool_calls": tool_calls},
        id="call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
        tool_calls=[
            {
                "name": tool_calls[0]["function"]["name"],  # type: ignore[index]
                "args": json.loads(tool_calls[0]["function"]["arguments"]),  # type: ignore[index]
                "id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
                "type": "tool_call",
            }
        ],
    )
    assert result == expected_output

def test_convert_dict_to_message_system() -> None:
    message = {"role": "system", "content": "foo"}
    result = ChatMlflow._convert_dict_to_message(message)
    expected_output = SystemMessage(content="foo")
    assert result == expected_output


def test_convert_dict_to_message_chat() -> None:
    message = {"role": "any_role", "content": "foo"}
    result = ChatMlflow._convert_dict_to_message(message)
    expected_output = ChatMessage(content="foo", role="any_role")
    assert result == expected_output


def test_convert_delta_to_message_chunk_ai() -> None:
    delta = {"role": "assistant", "content": "foo"}
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = AIMessageChunk(content="foo")
    assert result == expected_output

    delta_with_tools: Dict[str, Any] = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta_with_tools, "role")
    expected_output = AIMessageChunk(
        content="",
        additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
        id=None,
        tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
    )
    assert result == expected_output


def test_convert_delta_to_message_chunk_tool() -> None:
    delta = {
        "role": "tool",
        "content": "foo",
        "tool_call_id": "tool_call_id",
        "id": "some_id",
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = ToolMessageChunk(
        content="foo", id="some_id", tool_call_id="tool_call_id"
    )
    assert result == expected_output


def test_convert_delta_to_message_chunk_human() -> None:
    delta = {
        "role": "user",
        "content": "foo",
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = HumanMessageChunk(content="foo")
    assert result == expected_output


def test_convert_delta_to_message_chunk_system() -> None:
    delta = {
        "role": "system",
        "content": "foo",
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = SystemMessageChunk(content="foo")
    assert result == expected_output


def test_convert_delta_to_message_chunk_chat() -> None:
    delta = {
        "role": "any_role",
        "content": "foo",
    }
    result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
    expected_output = ChatMessageChunk(content="foo", role="any_role")
    assert result == expected_output


def test_convert_message_to_dict_human() -> None:
    human_message = HumanMessage(content="foo")
    result = ChatMlflow._convert_message_to_dict(human_message)
    expected_output = {"role": "user", "content": "foo"}
    assert result == expected_output


def test_convert_message_to_dict_system() -> None:
    system_message = SystemMessage(content="foo")
    result = ChatMlflow._convert_message_to_dict(system_message)
    expected_output = {"role": "system", "content": "foo"}
    assert result == expected_output


def test_convert_message_to_dict_ai() -> None:
    ai_message = AIMessage(content="foo")
    result = ChatMlflow._convert_message_to_dict(ai_message)
    expected_output = {"role": "assistant", "content": "foo"}
    assert result == expected_output

    ai_message = AIMessage(
        content="",
        tool_calls=[{"name": "name", "args": {}, "id": "id", "type": "tool_call"}],
    )
    result = ChatMlflow._convert_message_to_dict(ai_message)
    expected_output_with_tools: Dict[str, Any] = {
        "content": None,
        "role": "assistant",
        "tool_calls": [
            {
                "type": "function",
                "id": "id",
                "function": {"name": "name", "arguments": "{}"},
            }
        ],
    }
    assert result == expected_output_with_tools


def test_convert_message_to_dict_tool() -> None:
    tool_message = ToolMessageChunk(
        content="foo", id="some_id", tool_call_id="tool_call_id"
    )
    result = ChatMlflow._convert_message_to_dict(tool_message)
    expected_output = {
        "role": "tool",
        "content": "foo",
        "tool_call_id": "tool_call_id",
    }
    assert result == expected_output


def test_convert_message_to_dict_function() -> None:
    with pytest.raises(ValueError):
        ChatMlflow._convert_message_to_dict(FunctionMessage(content="", name="name"))