groq: handle streaming tool call case (#19978)

pull/19981/head
Erick Friis 2 months ago committed by GitHub
parent 5acb564d6f
commit 51bdfe04e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -225,11 +225,9 @@ class ChatGroq(BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
@ -237,7 +235,6 @@ class ChatGroq(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = self.client.create(messages=message_dicts, **params)
@ -248,11 +245,9 @@ class ChatGroq(BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
@ -261,7 +256,6 @@ class ChatGroq(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = await self.async_client.create(messages=message_dicts, **params)
@ -275,6 +269,31 @@ class ChatGroq(BaseChatModel):
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
# Groq API does not support streaming with tools yet: make one non-streaming call and yield the result as a single chunk
if "tools" in kwargs:
response = self.client.create(
messages=message_dicts, **{**params, **kwargs}
)
chat_result = self._create_chat_result(response)
generation = chat_result.generations[0]
message = generation.message
chunk_ = ChatGenerationChunk(
message=AIMessageChunk(
content=message.content, additional_kwargs=message.additional_kwargs
),
generation_info=generation.generation_info,
)
if run_manager:
geninfo = chunk_.generation_info or {}
run_manager.on_llm_new_token(
chunk_.text,
chunk=chunk_,
logprobs=geninfo.get("logprobs"),
)
yield chunk_
return
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
@ -310,6 +329,31 @@ class ChatGroq(BaseChatModel):
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
# Groq API does not support streaming with tools yet: make one non-streaming async call and yield the result as a single chunk
if "tools" in kwargs:
response = await self.async_client.create(
messages=message_dicts, **{**params, **kwargs}
)
chat_result = self._create_chat_result(response)
generation = chat_result.generations[0]
message = generation.message
chunk_ = ChatGenerationChunk(
message=AIMessageChunk(
content=message.content, additional_kwargs=message.additional_kwargs
),
generation_info=generation.generation_info,
)
if run_manager:
geninfo = chunk_.generation_info or {}
await run_manager.on_llm_new_token(
chunk_.text,
chunk=chunk_,
logprobs=geninfo.get("logprobs"),
)
yield chunk_
return
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk

@ -6,6 +6,7 @@ from typing import Any
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
@ -272,6 +273,64 @@ def test_tool_choice_bool() -> None:
assert tool_call["type"] == "function"
def test_streaming_tool_call() -> None:
    """Stream a forced tool call and verify the tool-call payload it emits."""

    class MyTool(BaseModel):
        name: str
        age: int

    model = ChatGroq().bind_tools([MyTool], tool_choice="MyTool")

    # Remember the kwargs of the most recent chunk; None means nothing streamed.
    last_kwargs = None
    for piece in model.stream("Who was the 27 year old named Erick?"):
        assert isinstance(piece, AIMessageChunk)
        # The response should consist of the tool call only, with no text.
        assert piece.content == ""
        last_kwargs = piece.additional_kwargs

    assert last_kwargs is not None
    calls = last_kwargs["tool_calls"]
    assert len(calls) == 1

    call = calls[0]
    function = call["function"]
    assert function["name"] == "MyTool"
    assert json.loads(function["arguments"]) == {"age": 27, "name": "Erick"}
    assert call["type"] == "function"
async def test_astreaming_tool_call() -> None:
    """Async-stream a forced tool call and verify the tool-call payload."""

    class MyTool(BaseModel):
        name: str
        age: int

    model = ChatGroq().bind_tools([MyTool], tool_choice="MyTool")

    # Remember the kwargs of the most recent chunk; None means nothing streamed.
    last_kwargs = None
    async for piece in model.astream("Who was the 27 year old named Erick?"):
        assert isinstance(piece, AIMessageChunk)
        # The response should consist of the tool call only, with no text.
        assert piece.content == ""
        last_kwargs = piece.additional_kwargs

    assert last_kwargs is not None
    calls = last_kwargs["tool_calls"]
    assert len(calls) == 1

    call = calls[0]
    function = call["function"]
    assert function["name"] == "MyTool"
    assert json.loads(function["arguments"]) == {"age": 27, "name": "Erick"}
    assert call["type"] == "function"
@pytest.mark.scheduled
def test_json_mode_structured_output() -> None:
"""Test with_structured_output with json"""

Loading…
Cancel
Save