mirror of
https://github.com/hwchase17/langchain
synced 2024-11-20 03:25:56 +00:00
openai[patch]: pass message name (#17537)
This commit is contained in:
parent
e5d7e455dc
commit
69e9610f62
@ -92,9 +92,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
The LangChain message.
|
The LangChain message.
|
||||||
"""
|
"""
|
||||||
role = _dict.get("role")
|
role = _dict.get("role")
|
||||||
|
name = _dict.get("name")
|
||||||
id_ = _dict.get("id")
|
id_ = _dict.get("id")
|
||||||
if role == "user":
|
if role == "user":
|
||||||
return HumanMessage(content=_dict.get("content", ""), id=id_)
|
return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
|
||||||
elif role == "assistant":
|
elif role == "assistant":
|
||||||
# Fix for azure
|
# Fix for azure
|
||||||
# Also OpenAI returns None for tool invocations
|
# Also OpenAI returns None for tool invocations
|
||||||
@ -104,12 +105,14 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
additional_kwargs["function_call"] = dict(function_call)
|
additional_kwargs["function_call"] = dict(function_call)
|
||||||
if tool_calls := _dict.get("tool_calls"):
|
if tool_calls := _dict.get("tool_calls"):
|
||||||
additional_kwargs["tool_calls"] = tool_calls
|
additional_kwargs["tool_calls"] = tool_calls
|
||||||
return AIMessage(content=content, additional_kwargs=additional_kwargs, id=id_)
|
return AIMessage(
|
||||||
|
content=content, additional_kwargs=additional_kwargs, name=name, id=id_
|
||||||
|
)
|
||||||
elif role == "system":
|
elif role == "system":
|
||||||
return SystemMessage(content=_dict.get("content", ""), id=id_)
|
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
|
||||||
elif role == "function":
|
elif role == "function":
|
||||||
return FunctionMessage(
|
return FunctionMessage(
|
||||||
content=_dict.get("content", ""), name=_dict.get("name"), id=id_
|
content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
|
||||||
)
|
)
|
||||||
elif role == "tool":
|
elif role == "tool":
|
||||||
additional_kwargs = {}
|
additional_kwargs = {}
|
||||||
@ -117,8 +120,9 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
additional_kwargs["name"] = _dict["name"]
|
additional_kwargs["name"] = _dict["name"]
|
||||||
return ToolMessage(
|
return ToolMessage(
|
||||||
content=_dict.get("content", ""),
|
content=_dict.get("content", ""),
|
||||||
tool_call_id=_dict.get("tool_call_id"),
|
tool_call_id=cast(str, _dict.get("tool_call_id")),
|
||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
|
name=name,
|
||||||
id=id_,
|
id=id_,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -134,13 +138,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
Returns:
|
Returns:
|
||||||
The dictionary.
|
The dictionary.
|
||||||
"""
|
"""
|
||||||
message_dict: Dict[str, Any]
|
message_dict: Dict[str, Any] = {
|
||||||
|
"content": message.content,
|
||||||
|
"name": message.name,
|
||||||
|
}
|
||||||
if isinstance(message, ChatMessage):
|
if isinstance(message, ChatMessage):
|
||||||
message_dict = {"role": message.role, "content": message.content}
|
message_dict["role"] = message.role
|
||||||
elif isinstance(message, HumanMessage):
|
elif isinstance(message, HumanMessage):
|
||||||
message_dict = {"role": "user", "content": message.content}
|
message_dict["role"] = "user"
|
||||||
elif isinstance(message, AIMessage):
|
elif isinstance(message, AIMessage):
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
message_dict["role"] = "assistant"
|
||||||
if "function_call" in message.additional_kwargs:
|
if "function_call" in message.additional_kwargs:
|
||||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||||
# If function call only, content is None not empty string
|
# If function call only, content is None not empty string
|
||||||
@ -152,19 +159,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
if message_dict["content"] == "":
|
if message_dict["content"] == "":
|
||||||
message_dict["content"] = None
|
message_dict["content"] = None
|
||||||
elif isinstance(message, SystemMessage):
|
elif isinstance(message, SystemMessage):
|
||||||
message_dict = {"role": "system", "content": message.content}
|
message_dict["role"] = "system"
|
||||||
elif isinstance(message, FunctionMessage):
|
elif isinstance(message, FunctionMessage):
|
||||||
message_dict = {
|
message_dict["role"] = "function"
|
||||||
"role": "function",
|
|
||||||
"content": message.content,
|
|
||||||
"name": message.name,
|
|
||||||
}
|
|
||||||
elif isinstance(message, ToolMessage):
|
elif isinstance(message, ToolMessage):
|
||||||
message_dict = {
|
message_dict["role"] = "tool"
|
||||||
"role": "tool",
|
message_dict["tool_call_id"] = message.tool_call_id
|
||||||
"content": message.content,
|
|
||||||
"tool_call_id": message.tool_call_id,
|
# tool message doesn't have name: https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||||
}
|
if message_dict["name"] is None:
|
||||||
|
del message_dict["name"]
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Got unknown type {message}")
|
raise TypeError(f"Got unknown type {message}")
|
||||||
if "name" in message.additional_kwargs:
|
if "name" in message.additional_kwargs:
|
||||||
|
20
libs/partners/openai/poetry.lock
generated
20
libs/partners/openai/poetry.lock
generated
@ -318,7 +318,7 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.1.29"
|
version = "0.1.33-rc.1"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
@ -344,13 +344,13 @@ url = "../../core"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langsmith"
|
name = "langsmith"
|
||||||
version = "0.1.22"
|
version = "0.1.29"
|
||||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "langsmith-0.1.22-py3-none-any.whl", hash = "sha256:b877d302bd4cf7c79e9e6e24bedf669132abf0659143390a29350eda0945544f"},
|
{file = "langsmith-0.1.29-py3-none-any.whl", hash = "sha256:5439f5bf25b00a43602aa1ddaba0a31d413ed920e7b20494070328f7e1ecbb86"},
|
||||||
{file = "langsmith-0.1.22.tar.gz", hash = "sha256:2921ae2297c2fb23baa2641b9cf416914ac7fd65f4a9dd5a573bc30efb54b693"},
|
{file = "langsmith-0.1.29.tar.gz", hash = "sha256:60ba0bd889c6a2683d123f66dc5043368eb2f103c4eb69e382abf7ce69a9f7d6"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -458,13 +458,13 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openai"
|
name = "openai"
|
||||||
version = "1.13.3"
|
version = "1.14.2"
|
||||||
description = "The official Python library for the openai API"
|
description = "The official Python library for the openai API"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7.1"
|
python-versions = ">=3.7.1"
|
||||||
files = [
|
files = [
|
||||||
{file = "openai-1.13.3-py3-none-any.whl", hash = "sha256:5769b62abd02f350a8dd1a3a242d8972c947860654466171d60fb0972ae0a41c"},
|
{file = "openai-1.14.2-py3-none-any.whl", hash = "sha256:a48b3c4d635b603952189ac5a0c0c9b06c025b80eb2900396939f02bb2104ac3"},
|
||||||
{file = "openai-1.13.3.tar.gz", hash = "sha256:ff6c6b3bc7327e715e4b3592a923a5a1c7519ff5dd764a83d69f633d49e77a7b"},
|
{file = "openai-1.14.2.tar.gz", hash = "sha256:e5642f7c02cf21994b08477d7bb2c1e46d8f335d72c26f0396c5f89b15b5b153"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -566,13 +566,13 @@ testing = ["pytest", "pytest-benchmark"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pydantic"
|
name = "pydantic"
|
||||||
version = "2.6.3"
|
version = "2.6.4"
|
||||||
description = "Data validation using Python type hints"
|
description = "Data validation using Python type hints"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"},
|
{file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"},
|
||||||
{file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"},
|
{file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Test OpenAI Chat API wrapper."""
|
"""Test OpenAI Chat API wrapper."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
@ -44,6 +45,13 @@ def test__convert_dict_to_message_human() -> None:
|
|||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_human_with_name() -> None:
|
||||||
|
message = {"role": "user", "content": "foo", "name": "test"}
|
||||||
|
result = _convert_dict_to_message(message)
|
||||||
|
expected_output = HumanMessage(content="foo", name="test")
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test__convert_dict_to_message_ai() -> None:
|
def test__convert_dict_to_message_ai() -> None:
|
||||||
message = {"role": "assistant", "content": "foo"}
|
message = {"role": "assistant", "content": "foo"}
|
||||||
result = _convert_dict_to_message(message)
|
result = _convert_dict_to_message(message)
|
||||||
@ -51,6 +59,13 @@ def test__convert_dict_to_message_ai() -> None:
|
|||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_ai_with_name() -> None:
|
||||||
|
message = {"role": "assistant", "content": "foo", "name": "test"}
|
||||||
|
result = _convert_dict_to_message(message)
|
||||||
|
expected_output = AIMessage(content="foo", name="test")
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test__convert_dict_to_message_system() -> None:
|
def test__convert_dict_to_message_system() -> None:
|
||||||
message = {"role": "system", "content": "foo"}
|
message = {"role": "system", "content": "foo"}
|
||||||
result = _convert_dict_to_message(message)
|
result = _convert_dict_to_message(message)
|
||||||
@ -58,6 +73,13 @@ def test__convert_dict_to_message_system() -> None:
|
|||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test__convert_dict_to_message_system_with_name() -> None:
|
||||||
|
message = {"role": "system", "content": "foo", "name": "test"}
|
||||||
|
result = _convert_dict_to_message(message)
|
||||||
|
expected_output = SystemMessage(content="foo", name="test")
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_completion() -> dict:
|
def mock_completion() -> dict:
|
||||||
return {
|
return {
|
||||||
@ -71,6 +93,7 @@ def mock_completion() -> dict:
|
|||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "Bar Baz",
|
"content": "Bar Baz",
|
||||||
|
"name": "Erick",
|
||||||
},
|
},
|
||||||
"finish_reason": "stop",
|
"finish_reason": "stop",
|
||||||
}
|
}
|
||||||
@ -134,3 +157,31 @@ async def test_openai_ainvoke(mock_completion: dict) -> None:
|
|||||||
def test__get_encoding_model(model: str) -> None:
|
def test__get_encoding_model(model: str) -> None:
|
||||||
ChatOpenAI(model=model)._get_encoding_model()
|
ChatOpenAI(model=model)._get_encoding_model()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_invoke_name(mock_completion: dict) -> None:
|
||||||
|
llm = ChatOpenAI()
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.create.return_value = mock_completion
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
llm,
|
||||||
|
"client",
|
||||||
|
mock_client,
|
||||||
|
):
|
||||||
|
messages = [
|
||||||
|
HumanMessage(content="Foo", name="Katie"),
|
||||||
|
]
|
||||||
|
res = llm.invoke(messages)
|
||||||
|
call_args, call_kwargs = mock_client.create.call_args
|
||||||
|
assert len(call_args) == 0 # no positional args
|
||||||
|
call_messages = call_kwargs["messages"]
|
||||||
|
assert len(call_messages) == 1
|
||||||
|
assert call_messages[0]["role"] == "user"
|
||||||
|
assert call_messages[0]["content"] == "Foo"
|
||||||
|
assert call_messages[0]["name"] == "Katie"
|
||||||
|
|
||||||
|
# check return type has name
|
||||||
|
assert res.content == "Bar Baz"
|
||||||
|
assert res.name == "Erick"
|
||||||
|
Loading…
Reference in New Issue
Block a user