openai[patch]: pass message name (#17537)

This commit is contained in:
Erick Friis 2024-03-19 12:57:27 -07:00 committed by GitHub
parent e5d7e455dc
commit 69e9610f62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 85 additions and 30 deletions

View File

@@ -92,9 +92,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
         The LangChain message.
     """
     role = _dict.get("role")
+    name = _dict.get("name")
     id_ = _dict.get("id")
     if role == "user":
-        return HumanMessage(content=_dict.get("content", ""), id=id_)
+        return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
     elif role == "assistant":
         # Fix for azure
         # Also OpenAI returns None for tool invocations
@@ -104,12 +105,14 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             additional_kwargs["function_call"] = dict(function_call)
         if tool_calls := _dict.get("tool_calls"):
             additional_kwargs["tool_calls"] = tool_calls
-        return AIMessage(content=content, additional_kwargs=additional_kwargs, id=id_)
+        return AIMessage(
+            content=content, additional_kwargs=additional_kwargs, name=name, id=id_
+        )
     elif role == "system":
-        return SystemMessage(content=_dict.get("content", ""), id=id_)
+        return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
     elif role == "function":
         return FunctionMessage(
-            content=_dict.get("content", ""), name=_dict.get("name"), id=id_
+            content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
         )
     elif role == "tool":
         additional_kwargs = {}
@@ -117,8 +120,9 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             additional_kwargs["name"] = _dict["name"]
         return ToolMessage(
             content=_dict.get("content", ""),
-            tool_call_id=_dict.get("tool_call_id"),
+            tool_call_id=cast(str, _dict.get("tool_call_id")),
             additional_kwargs=additional_kwargs,
+            name=name,
             id=id_,
         )
     else:
@@ -134,13 +138,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     Returns:
         The dictionary.
     """
-    message_dict: Dict[str, Any]
+    message_dict: Dict[str, Any] = {
+        "content": message.content,
+        "name": message.name,
+    }
     if isinstance(message, ChatMessage):
-        message_dict = {"role": message.role, "content": message.content}
+        message_dict["role"] = message.role
     elif isinstance(message, HumanMessage):
-        message_dict = {"role": "user", "content": message.content}
+        message_dict["role"] = "user"
     elif isinstance(message, AIMessage):
-        message_dict = {"role": "assistant", "content": message.content}
+        message_dict["role"] = "assistant"
         if "function_call" in message.additional_kwargs:
             message_dict["function_call"] = message.additional_kwargs["function_call"]
         # If function call only, content is None not empty string
@@ -152,19 +159,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         if message_dict["content"] == "":
             message_dict["content"] = None
     elif isinstance(message, SystemMessage):
-        message_dict = {"role": "system", "content": message.content}
+        message_dict["role"] = "system"
     elif isinstance(message, FunctionMessage):
-        message_dict = {
-            "role": "function",
-            "content": message.content,
-            "name": message.name,
-        }
+        message_dict["role"] = "function"
     elif isinstance(message, ToolMessage):
-        message_dict = {
-            "role": "tool",
-            "content": message.content,
-            "tool_call_id": message.tool_call_id,
-        }
+        message_dict["role"] = "tool"
+        message_dict["tool_call_id"] = message.tool_call_id
+
+        # tool message doesn't have name: https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
+        if message_dict["name"] is None:
+            del message_dict["name"]
     else:
         raise TypeError(f"Got unknown type {message}")
     if "name" in message.additional_kwargs:

View File

@@ -318,7 +318,7 @@ files = [

 [[package]]
 name = "langchain-core"
-version = "0.1.29"
+version = "0.1.33-rc.1"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -344,13 +344,13 @@ url = "../../core"

 [[package]]
 name = "langsmith"
-version = "0.1.22"
+version = "0.1.29"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langsmith-0.1.22-py3-none-any.whl", hash = "sha256:b877d302bd4cf7c79e9e6e24bedf669132abf0659143390a29350eda0945544f"},
-    {file = "langsmith-0.1.22.tar.gz", hash = "sha256:2921ae2297c2fb23baa2641b9cf416914ac7fd65f4a9dd5a573bc30efb54b693"},
+    {file = "langsmith-0.1.29-py3-none-any.whl", hash = "sha256:5439f5bf25b00a43602aa1ddaba0a31d413ed920e7b20494070328f7e1ecbb86"},
+    {file = "langsmith-0.1.29.tar.gz", hash = "sha256:60ba0bd889c6a2683d123f66dc5043368eb2f103c4eb69e382abf7ce69a9f7d6"},
 ]

 [package.dependencies]
@@ -458,13 +458,13 @@ files = [

 [[package]]
 name = "openai"
-version = "1.13.3"
+version = "1.14.2"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.13.3-py3-none-any.whl", hash = "sha256:5769b62abd02f350a8dd1a3a242d8972c947860654466171d60fb0972ae0a41c"},
-    {file = "openai-1.13.3.tar.gz", hash = "sha256:ff6c6b3bc7327e715e4b3592a923a5a1c7519ff5dd764a83d69f633d49e77a7b"},
+    {file = "openai-1.14.2-py3-none-any.whl", hash = "sha256:a48b3c4d635b603952189ac5a0c0c9b06c025b80eb2900396939f02bb2104ac3"},
+    {file = "openai-1.14.2.tar.gz", hash = "sha256:e5642f7c02cf21994b08477d7bb2c1e46d8f335d72c26f0396c5f89b15b5b153"},
 ]

 [package.dependencies]
@@ -566,13 +566,13 @@ testing = ["pytest", "pytest-benchmark"]

 [[package]]
 name = "pydantic"
-version = "2.6.3"
+version = "2.6.4"
 description = "Data validation using Python type hints"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"},
-    {file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"},
+    {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"},
+    {file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"},
 ]

 [package.dependencies]

View File

@@ -1,4 +1,5 @@
 """Test OpenAI Chat API wrapper."""
+
 import json
 from typing import Any
 from unittest.mock import AsyncMock, MagicMock, patch
@@ -44,6 +45,13 @@ def test__convert_dict_to_message_human() -> None:
     assert result == expected_output


+def test__convert_dict_to_message_human_with_name() -> None:
+    message = {"role": "user", "content": "foo", "name": "test"}
+    result = _convert_dict_to_message(message)
+    expected_output = HumanMessage(content="foo", name="test")
+    assert result == expected_output
+
+
 def test__convert_dict_to_message_ai() -> None:
     message = {"role": "assistant", "content": "foo"}
     result = _convert_dict_to_message(message)
@@ -51,6 +59,13 @@ def test__convert_dict_to_message_ai() -> None:
     assert result == expected_output


+def test__convert_dict_to_message_ai_with_name() -> None:
+    message = {"role": "assistant", "content": "foo", "name": "test"}
+    result = _convert_dict_to_message(message)
+    expected_output = AIMessage(content="foo", name="test")
+    assert result == expected_output
+
+
 def test__convert_dict_to_message_system() -> None:
     message = {"role": "system", "content": "foo"}
     result = _convert_dict_to_message(message)
@@ -58,6 +73,13 @@ def test__convert_dict_to_message_system() -> None:
     assert result == expected_output


+def test__convert_dict_to_message_system_with_name() -> None:
+    message = {"role": "system", "content": "foo", "name": "test"}
+    result = _convert_dict_to_message(message)
+    expected_output = SystemMessage(content="foo", name="test")
+    assert result == expected_output
+
+
 @pytest.fixture
 def mock_completion() -> dict:
     return {
@@ -71,6 +93,7 @@ def mock_completion() -> dict:
                 "message": {
                     "role": "assistant",
                     "content": "Bar Baz",
+                    "name": "Erick",
                 },
                 "finish_reason": "stop",
             }
@@ -134,3 +157,31 @@ def test__get_encoding_model(model: str) -> None:
 def test__get_encoding_model(model: str) -> None:
     ChatOpenAI(model=model)._get_encoding_model()
     return
+
+
+def test_openai_invoke_name(mock_completion: dict) -> None:
+    llm = ChatOpenAI()
+
+    mock_client = MagicMock()
+    mock_client.create.return_value = mock_completion
+
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        messages = [
+            HumanMessage(content="Foo", name="Katie"),
+        ]
+        res = llm.invoke(messages)
+        call_args, call_kwargs = mock_client.create.call_args
+        assert len(call_args) == 0  # no positional args
+        call_messages = call_kwargs["messages"]
+        assert len(call_messages) == 1
+        assert call_messages[0]["role"] == "user"
+        assert call_messages[0]["content"] == "Foo"
+        assert call_messages[0]["name"] == "Katie"
+
+        # check return type has name
+        assert res.content == "Bar Baz"
+        assert res.name == "Erick"