openai[patch]: tool use integration test (#19460)

pull/19448/head^2
Bagatur 6 months ago committed by GitHub
parent a99e644913
commit d93d49bc43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,6 +1,5 @@
"""Test ChatOpenAI chat model.""" """Test ChatOpenAI chat model."""
from typing import Any, List, Optional, cast
from typing import Any, Optional, cast
import pytest import pytest
from langchain_core.callbacks import CallbackManager from langchain_core.callbacks import CallbackManager
@ -10,6 +9,7 @@ from langchain_core.messages import (
BaseMessageChunk, BaseMessageChunk,
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
ToolMessage,
) )
from langchain_core.outputs import ( from langchain_core.outputs import (
ChatGeneration, ChatGeneration,
@ -470,6 +470,25 @@ async def test_async_response_metadata_streaming() -> None:
assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"] assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
class GenerateUsername(BaseModel):
"Get a username based on someone's name and hair color."
name: str
hair_color: str
def test_tool_use() -> None:
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
msgs: List = [HumanMessage("Sally has green hair, what would her username be?")]
ai_msg = llm_with_tool.invoke(msgs)
tool_msg = ToolMessage(
"sally_green_hair", tool_call_id=ai_msg.additional_kwargs["tool_calls"][0]["id"]
)
msgs.extend([ai_msg, tool_msg])
llm_with_tool.invoke(msgs)
def test_openai_structured_output() -> None: def test_openai_structured_output() -> None:
class MyModel(BaseModel): class MyModel(BaseModel):
"""A Person""" """A Person"""

Loading…
Cancel
Save