community[patch]: support bind_tools for ChatMlflow (#24547)

Thank you for contributing to LangChain!

- [x] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- **Description:** Add a `bind_tools` method to `ChatMlflow`, enabling OpenAI-compatible tool calling against Databricks model-serving endpoints.
Tested in Databricks:
<img width="836" alt="image"
src="https://github.com/user-attachments/assets/fa28ef50-0110-4698-8eda-4faf6f0b9ef8">



- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
Serena Ruan committed on 2024-08-01 23:43:07 +08:00 (committed by GitHub)
commit 1827bb4042, parent 769c3bb838
5 changed files with 689 additions and 20 deletions

View File

@@ -36,7 +36,7 @@
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
"| | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
"\n",
"### Supported Methods\n",
"\n",
@@ -395,6 +395,66 @@
"chat_model_external.invoke(\"How to use Databricks?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Function calling on Databricks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Databricks Function Calling is OpenAI-compatible and is only available during model serving as part of Foundation Model APIs.\n",
"\n",
"See [Databricks function calling introduction](https://docs.databricks.com/en/machine-learning/model-serving/function-calling.html#supported-models) for supported models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.databricks import ChatDatabricks\n",
"\n",
"llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
" },\n",
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
" },\n",
" },\n",
" },\n",
" }\n",
"]\n",
"\n",
"# supported tool_choice values: \"auto\", \"required\", \"none\", function name in string format,\n",
"# or a dictionary as {\"type\": \"function\", \"function\": {\"name\": <<tool_name>>}}\n",
"model = llm.bind_tools(tools, tool_choice=\"auto\")\n",
"\n",
"messages = [{\"role\": \"user\", \"content\": \"What is the current temperature of Chicago?\"}]\n",
"print(model.invoke(messages))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See [Databricks Unity Catalog](docs/integrations/tools/databricks.ipynb) about how to use UC functions in chains."
]
},
{
"cell_type": "markdown",
"metadata": {},

View File

@@ -38,7 +38,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet databricks-sdk langchain-community langchain-openai"
"%pip install --upgrade --quiet databricks-sdk langchain-community mlflow"
]
},
{
@@ -47,9 +47,9 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"from langchain_community.chat_models.databricks import ChatDatabricks\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\")"
"llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")"
]
},
{

View File

@@ -91,3 +91,4 @@ vdms>=0.0.20
xata>=1.0.0a7,<2
xmltodict>=0.13.0,<0.14
nanopq==0.2.1
mlflow[genai]>=2.14.0

View File

@@ -1,5 +1,19 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
from urllib.parse import urlparse
from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -15,15 +29,27 @@ from langchain_core.messages import (
FunctionMessage,
HumanMessage,
HumanMessageChunk,
InvalidToolCall,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
PrivateAttr,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
logger = logging.getLogger(__name__)
@@ -228,11 +254,32 @@ class ChatMlflow(BaseChatModel):
@staticmethod
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
content = _dict["content"]
content = cast(str, _dict.get("content"))
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
return AIMessage(content=content)
content = content or ""
additional_kwargs: Dict = {}
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(
parse_tool_call(raw_tool_call, return_id=True)
)
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
id=_dict.get("id"),
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=content)
else:
@@ -243,13 +290,38 @@
_dict: Mapping[str, Any], default_role: str
) -> BaseMessageChunk:
role = _dict.get("role", default_role)
content = _dict["content"]
content = _dict.get("content") or ""
if role == "user":
return HumanMessageChunk(content=content)
elif role == "assistant":
return AIMessageChunk(content=content)
additional_kwargs: Dict = {}
tool_call_chunks = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc["index"],
)
for rtc in raw_tool_calls
]
except KeyError:
pass
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
id=_dict.get("id"),
tool_call_chunks=tool_call_chunks,
)
elif role == "system":
return SystemMessageChunk(content=content)
elif role == "tool":
return ToolMessageChunk(
content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
)
else:
return ChatMessageChunk(content=content, role=role)
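
For context on why each chunk records an `index` (a sketch, not part of the diff): LangChain merges streamed `AIMessageChunk`s with `+`, and tool-call entries that share an index have their partial JSON argument strings concatenated, so a tool call split across deltas reassembles:

```python
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import tool_call_chunk

# Two deltas for the same streamed tool call (index 0): the first
# carries the function name, the second a fragment of the arguments.
first = AIMessageChunk(
    content="",
    tool_call_chunks=[tool_call_chunk(name="f", args='{"x": ', index=0)],
)
second = AIMessageChunk(
    content="",
    tool_call_chunks=[tool_call_chunk(args="1}", index=0)],
)

# Chunk addition merges entries with matching indices, so the
# argument string reassembles into complete JSON: '{"x": 1}'.
merged = first + second
print(merged.tool_call_chunks)
```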
@@ -262,14 +334,47 @@
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"content": message.content}
if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
message_dict["role"] = message.role
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
message_dict["role"] = "user"
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
message_dict["role"] = "assistant"
if message.tool_calls or message.invalid_tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
] + [
_lc_invalid_tool_call_to_openai_tool_call(tc)
for tc in message.invalid_tool_calls
] # type: ignore[assignment]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
tool_call_supported_props = {"id", "type", "function"}
message_dict["tool_calls"] = [
{
k: v
for k, v in tool_call.items() # type: ignore[union-attr]
if k in tool_call_supported_props
}
for tool_call in message_dict["tool_calls"]
]
else:
pass
# If tool calls are present, content should be None (JSON null), not an empty string.
if "tool_calls" in message_dict:
message_dict["content"] = message_dict["content"] or None # type: ignore[assignment]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
message_dict["role"] = "system"
elif isinstance(message, ToolMessage):
message_dict["role"] = "tool"
message_dict["tool_call_id"] = message.tool_call_id
supported_props = {"content", "role", "tool_call_id"}
message_dict = {
k: v for k, v in message_dict.items() if k in supported_props
}
elif isinstance(message, FunctionMessage):
raise ValueError(
"Function messages are not supported by Databricks. Please"
@@ -280,12 +385,6 @@ class ChatMlflow(BaseChatModel):
if "function_call" in message.additional_kwargs:
ChatMlflow._raise_functions_not_supported()
if message.additional_kwargs:
logger.warning(
"Additional message arguments are unsupported by Databricks"
" and will be ignored: %s",
message.additional_kwargs,
)
return message_dict
@staticmethod
@@ -302,3 +401,89 @@
usage = response.get("usage", {})
return ChatResult(generations=generations, llm_output=usage)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Assumes the model is compatible with the OpenAI tool-calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
tool_choice: Which tool to require the model to call.
Options are:
name of the tool (str): calls corresponding tool;
"auto": automatically selects a tool (including no tool);
"none": model does not generate any tool calls and instead must
generate a standard assistant message;
"required": the model picks the most relevant tool in tools and
must generate a tool call;
or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice:
if isinstance(tool_choice, str):
# tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "required"):
tool_choice = {
"type": "function",
"function": {"name": tool_choice},
}
elif isinstance(tool_choice, dict):
tool_names = [
formatted_tool["function"]["name"]
for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"]
for tool_name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tools were {tool_names}."
)
else:
raise ValueError(
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
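
The accepted `tool_choice` forms, restated as a quick sketch (`llm` and `tools` as in the notebook example; `"get_current_weather"` stands in for any bound tool's name):

```python
llm.bind_tools(tools)                          # model decides whether to call a tool
llm.bind_tools(tools, tool_choice="auto")      # same, but explicit
llm.bind_tools(tools, tool_choice="none")      # forbid tool calls
llm.bind_tools(tools, tool_choice="required")  # force some tool call

# A bare tool name is rewritten into the dict form below; only the
# dict form is validated against the names of the bound tools.
llm.bind_tools(tools, tool_choice="get_current_weather")
llm.bind_tools(
    tools,
    tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
)
```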
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}
def _lc_invalid_tool_call_to_openai_tool_call(
invalid_tool_call: InvalidToolCall,
) -> dict:
return {
"type": "function",
"id": invalid_tool_call["id"],
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
},
}
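
A round-trip sketch of the helper above: a LangChain `ToolCall` serialized back to the OpenAI wire format (the `tool_call` factory is from `langchain_core.messages.tool`):

```python
from langchain_core.messages.tool import tool_call

tc = tool_call(name="get_current_weather", args={"location": "Chicago"}, id="call_1")
print(_lc_tool_call_to_openai_tool_call(tc))
# -> {'type': 'function', 'id': 'call_1',
#     'function': {'name': 'get_current_weather',
#                  'arguments': '{"location": "Chicago"}'}}
```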

View File

@@ -0,0 +1,423 @@
import json
from typing import Any, Dict, List
from unittest.mock import MagicMock
import pytest
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolCallChunk,
ToolMessageChunk,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, BaseModel
from langchain_core.tools import StructuredTool
from langchain_community.chat_models.mlflow import ChatMlflow
@pytest.fixture
def llm() -> ChatMlflow:
return ChatMlflow(
endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
)
@pytest.fixture
def model_input() -> List[BaseMessage]:
data = [
{
"role": "system",
"content": "You are a helpful assistant.",
},
{"role": "user", "content": "36939 * 8922.4"},
]
return [ChatMlflow._convert_dict_to_message(value) for value in data]
@pytest.fixture
def mock_prediction() -> dict:
return {
"id": "chatcmpl_id",
"object": "chat.completion",
"created": 1721875529,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "To calculate the result of 36939 multiplied by 8922.4, "
"I get:\n\n36939 x 8922.4 = 329,511,111.6",
},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
}
@pytest.fixture
def mock_predict_stream_result() -> List[dict]:
return [
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "36939"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "x"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "8922.4"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": " = "},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "329,511,111.6"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
},
]
@pytest.mark.requires("mlflow")
def test_chat_mlflow_predict(
llm: ChatMlflow, model_input: List[BaseMessage], mock_prediction: dict
) -> None:
mock_client = MagicMock()
llm._client = mock_client
def mock_predict(*args: Any, **kwargs: Any) -> Any:
return mock_prediction
mock_client.predict = mock_predict
res = llm.invoke(model_input)
assert res.content == mock_prediction["choices"][0]["message"]["content"]
@pytest.mark.requires("mlflow")
def test_chat_mlflow_stream(
llm: ChatMlflow,
model_input: List[BaseMessage],
mock_predict_stream_result: List[dict],
) -> None:
mock_client = MagicMock()
llm._client = mock_client
def mock_stream(*args: Any, **kwargs: Any) -> Any:
yield from mock_predict_stream_result
mock_client.predict_stream = mock_stream
for i, res in enumerate(llm.stream(model_input)):
assert (
res.content
== mock_predict_stream_result[i]["choices"][0]["delta"]["content"]
)
@pytest.mark.requires("mlflow")
@pytest.mark.skipif(
_PYDANTIC_MAJOR_VERSION < 2,
reason="The tool mock is not compatible with pydantic 1.x",
)
def test_chat_mlflow_bind_tools(
llm: ChatMlflow, mock_predict_stream_result: List[dict]
) -> None:
mock_client = MagicMock()
llm._client = mock_client
def mock_stream(*args: Any, **kwargs: Any) -> Any:
yield from mock_predict_stream_result
mock_client.predict_stream = mock_stream
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful assistant. Make sure to use tool for information.",
),
("placeholder", "{chat_history}"),
("human", "{input}"),
("placeholder", "{agent_scratchpad}"),
]
)
def mock_func(*args: Any, **kwargs: Any) -> str:
return "36939 x 8922.4 = 329,511,111.6"
tools = [
StructuredTool(
name="name",
description="description",
args_schema=BaseModel,
func=mock_func,
)
]
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) # type: ignore[arg-type]
result = agent_executor.invoke({"input": "36939 * 8922.4"})
assert result["output"] == "36939x8922.4 = 329,511,111.6"
def test_convert_dict_to_message_human() -> None:
message = {"role": "user", "content": "foo"}
result = ChatMlflow._convert_dict_to_message(message)
expected_output = HumanMessage(content="foo")
assert result == expected_output
def test_convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = ChatMlflow._convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
assert result == expected_output
tool_calls = [
{
"id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
"type": "function",
"function": {
"name": "main__test__python_exec",
"arguments": '{"code": "result = 36939 * 8922.4" }',
},
}
]
message_with_tools: Dict[str, Any] = {
"role": "assistant",
"content": None,
"tool_calls": tool_calls,
}
result = ChatMlflow._convert_dict_to_message(message_with_tools)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": tool_calls},
id="call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
tool_calls=[
{
"name": tool_calls[0]["function"]["name"], # type: ignore[index]
"args": json.loads(tool_calls[0]["function"]["arguments"]), # type: ignore[index]
"id": "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12",
"type": "tool_call",
}
],
)
assert result == expected_output
def test_convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = ChatMlflow._convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output
def test_convert_dict_to_message_chat() -> None:
message = {"role": "any_role", "content": "foo"}
result = ChatMlflow._convert_dict_to_message(message)
expected_output = ChatMessage(content="foo", role="any_role")
assert result == expected_output
def test_convert_delta_to_message_chunk_ai() -> None:
delta = {"role": "assistant", "content": "foo"}
result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
expected_output = AIMessageChunk(content="foo")
assert result == expected_output
delta_with_tools: Dict[str, Any] = {
"role": "assistant",
"content": None,
"tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
}
result = ChatMlflow._convert_delta_to_message_chunk(delta_with_tools, "role")
expected_output = AIMessageChunk(
content="",
additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
id=None,
tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
)
assert result == expected_output
def test_convert_delta_to_message_chunk_tool() -> None:
delta = {
"role": "tool",
"content": "foo",
"tool_call_id": "tool_call_id",
"id": "some_id",
}
result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
expected_output = ToolMessageChunk(
content="foo", id="some_id", tool_call_id="tool_call_id"
)
assert result == expected_output
def test_convert_delta_to_message_chunk_human() -> None:
delta = {
"role": "user",
"content": "foo",
}
result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
expected_output = HumanMessageChunk(content="foo")
assert result == expected_output
def test_convert_delta_to_message_chunk_system() -> None:
delta = {
"role": "system",
"content": "foo",
}
result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
expected_output = SystemMessageChunk(content="foo")
assert result == expected_output
def test_convert_delta_to_message_chunk_chat() -> None:
delta = {
"role": "any_role",
"content": "foo",
}
result = ChatMlflow._convert_delta_to_message_chunk(delta, "default_role")
expected_output = ChatMessageChunk(content="foo", role="any_role")
assert result == expected_output
def test_convert_message_to_dict_human() -> None:
human_message = HumanMessage(content="foo")
result = ChatMlflow._convert_message_to_dict(human_message)
expected_output = {"role": "user", "content": "foo"}
assert result == expected_output
def test_convert_message_to_dict_system() -> None:
system_message = SystemMessage(content="foo")
result = ChatMlflow._convert_message_to_dict(system_message)
expected_output = {"role": "system", "content": "foo"}
assert result == expected_output
def test_convert_message_to_dict_ai() -> None:
ai_message = AIMessage(content="foo")
result = ChatMlflow._convert_message_to_dict(ai_message)
expected_output = {"role": "assistant", "content": "foo"}
assert result == expected_output
ai_message = AIMessage(
content="",
tool_calls=[{"name": "name", "args": {}, "id": "id", "type": "tool_call"}],
)
result = ChatMlflow._convert_message_to_dict(ai_message)
expected_output_with_tools: Dict[str, Any] = {
"content": None,
"role": "assistant",
"tool_calls": [
{
"type": "function",
"id": "id",
"function": {"name": "name", "arguments": "{}"},
}
],
}
assert result == expected_output_with_tools
def test_convert_message_to_dict_tool() -> None:
tool_message = ToolMessageChunk(
content="foo", id="some_id", tool_call_id="tool_call_id"
)
result = ChatMlflow._convert_message_to_dict(tool_message)
expected_output = {
"role": "tool",
"content": "foo",
"tool_call_id": "tool_call_id",
}
assert result == expected_output
def test_convert_message_to_dict_function() -> None:
with pytest.raises(ValueError):
ChatMlflow._convert_message_to_dict(FunctionMessage(content="", name="name"))