groq: Add tool calling support (#19971)

**Description:** Add with_structured_output to groq chat models
**Issue:** 
**Dependencies:** N/A
**Twitter handle:** N/A
This commit is contained in:
Graden Rea 2024-04-03 14:40:20 -07:00 committed by GitHub
parent 6f20f140ca
commit 88cf8a2905
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 551 additions and 19 deletions

View File

@ -358,13 +358,119 @@
"model_with_structure.invoke(\"Tell me a joke about cats\")"
]
},
{
"cell_type": "markdown",
"id": "6214781d",
"metadata": {},
"source": [
"## Groq\n",
"\n",
"Groq provides an OpenAI-compatible function calling API"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3066b2af",
"execution_count": 11,
"id": "70511bc3",
"metadata": {},
"outputs": [],
"source": []
"source": [
"from langchain_groq import ChatGroq"
]
},
{
"cell_type": "markdown",
"id": "6b7e97a6",
"metadata": {},
"source": [
"### Function Calling\n",
"\n",
"By default, we will use `function_calling`"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "be9fdf04",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/reag/src/langchain/libs/core/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: The function `with_structured_output` is in beta. It is actively being worked on, so the API may change.\n",
" warn_beta(\n"
]
}
],
"source": [
"model = ChatGroq()\n",
"model_with_structure = model.with_structured_output(Joke)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e13f4676",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Joke(setup=\"Why don't cats play poker in the jungle?\", punchline='Too many cheetahs!')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_with_structure.invoke(\"Tell me a joke about cats\")"
]
},
{
"cell_type": "markdown",
"id": "a82c2f55",
"metadata": {},
"source": [
"### JSON Mode\n",
"\n",
"We also support JSON mode. Note that we need to specify in the prompt the format that it should respond in."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "86574fb8",
"metadata": {},
"outputs": [],
"source": [
"model_with_structure = model.with_structured_output(Joke, method=\"json_mode\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "01dced9c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Joke(setup=\"Why don't cats play poker in the jungle?\", punchline='Too many cheetahs!')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_with_structure.invoke(\n",
" \"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys\"\n",
")"
]
}
],
"metadata": {

View File

@ -4,24 +4,31 @@ from __future__ import annotations
import os
import warnings
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
cast,
)
from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
@ -43,13 +50,28 @@ from langchain_core.messages import (
ToolMessage,
ToolMessageChunk,
)
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
class ChatGroq(BaseChatModel):
@ -390,6 +412,334 @@ class ChatGroq(BaseChatModel):
combined["system_fingerprint"] = system_fingerprint
return combined
def bind_functions(
self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind functions (and other objects) to this chat model.
Model is compatible with OpenAI function-calling API.
NOTE: Using bind_tools is recommended instead, as the `functions` and
`function_call` request parameters are officially deprecated.
Args:
functions: A list of function definitions to bind to this chat model.
Can be a dictionary, pydantic model, or callable. Pydantic
models and callables will be automatically converted to
their schema dictionary representation.
function_call: Which function to require the model to call.
Must be the name of the single provided function or
"auto" to automatically determine which function to call
(if any).
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
if function_call is not None:
function_call = (
{"name": function_call}
if isinstance(function_call, str)
and function_call not in ("auto", "none")
else function_call
)
if isinstance(function_call, dict) and len(formatted_functions) != 1:
raise ValueError(
"When specifying `function_call`, you must provide exactly one "
"function."
)
if (
isinstance(function_call, dict)
and formatted_functions[0]["name"] != function_call["name"]
):
raise ValueError(
f"Function call {function_call} was specified, but the only "
f"provided function was {formatted_functions[0]['name']}."
)
kwargs = {**kwargs, "function_call": function_call}
return super().bind(
functions=formatted_functions,
**kwargs,
)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "any", "none"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
tool_choice: Which tool to require the model to call.
Must be the name of the single provided function,
"auto" to automatically determine which function to call
with the option to not call any function, "any" to enforce that some
function is called, or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None and tool_choice:
if isinstance(tool_choice, str) and (
tool_choice not in ("auto", "any", "none")
):
tool_choice = {"type": "function", "function": {"name": tool_choice}}
if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
raise ValueError(
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(formatted_tools)} tools."
)
if isinstance(tool_choice, dict) and (
formatted_tools[0]["function"]["name"]
!= tool_choice["function"]["name"]
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tool was {formatted_tools[0]['function']['name']}."
)
if isinstance(tool_choice, bool):
if len(tools) > 1:
raise ValueError(
"tool_choice can only be True when there is one tool. Received "
f"{len(tools)} tools."
)
tool_name = formatted_tools[0]["function"]["name"]
tool_choice = {
"type": "function",
"function": {"name": tool_name},
}
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
@beta()
def with_structured_output(
self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be. If
`method` is "function_calling" and `schema` is a dict, then the dict
must match the OpenAI function-calling spec.
method: The method for steering model generation, either "function_calling"
or "json_mode". If "function_calling" then the schema will be converted
to a OpenAI function and the returned model will make use of the
function-calling API. If "json_mode" then Groq's JSON mode will be
used. Note that if using "json_mode" then you must include instructions
for formatting the output into the desired schema into the model call.
include_raw: If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes any ChatModel input and returns as output:
If include_raw is True then a dict with keys:
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
If include_raw is False then just _DictOrPydantic is returned,
where _DictOrPydantic depends on the schema:
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class.
If schema is a dict then _DictOrPydantic is a dict.
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
.. code-block:: python
from langchain_groq import ChatGroq
from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatGroq(temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> AnswerWithJustification(
# answer='A pound of bricks and a pound of feathers weigh the same.'
# justification="Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same."
# )
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
.. code-block:: python
from langchain_groq import ChatGroq
from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatGroq(temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_01htjn3cspevxbqc1d7nkk8wab', 'function': {'arguments': '{"answer": "A pound of bricks and a pound of feathers weigh the same.", "justification": "Both a pound of bricks and a pound of feathers have been defined to have the same weight. The \'pound\' is a unit of weight, so any two things that are described as weighing a pound will weigh the same.", "unit": "pounds"}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}, id='run-456beee6-65f6-4e80-88af-a6065480822c-0'),
# 'parsed': AnswerWithJustification(answer='A pound of bricks and a pound of feathers weigh the same.', justification="Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same."),
# 'parsing_error': None
# }
Example: Function-calling, dict schema (method="function_calling", include_raw=False):
.. code-block:: python
from langchain_groq import ChatGroq
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_tool
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
dict_schema = convert_to_openai_tool(AnswerWithJustification)
llm = ChatGroq(temperature=0)
structured_llm = llm.with_structured_output(dict_schema)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'answer': 'A pound of bricks and a pound of feathers weigh the same.',
# 'justification': "Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same.", 'unit': 'pounds'}
# }
Example: JSON mode, Pydantic schema (method="json_mode", include_raw=True):
.. code-block::
from langchain_groq import ChatGroq
from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel):
answer: str
justification: str
llm = ChatGroq(temperature=0)
structured_llm = llm.with_structured_output(
AnswerWithJustification,
method="json_mode",
include_raw=True
)
structured_llm.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\n "answer": "A pound of bricks is the same weight as a pound of feathers.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The material being weighed does not affect the weight, only the volume or number of items being weighed."\n}', id='run-e5453bc5-5025-4833-95f9-4967bf6d5c4f-0'),
# 'parsed': AnswerWithJustification(answer='A pound of bricks is the same weight as a pound of feathers.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The material being weighed does not affect the weight, only the volume or number of items being weighed.'),
# 'parsing_error': None
# }
Example: JSON mode, no schema (schema=None, method="json_mode", include_raw=True):
.. code-block::
from langchain_groq import ChatGroq
llm = ChatGroq(temperature=0)
structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)
structured_llm.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\n "answer": "A pound of bricks is the same weight as a pound of feathers.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The material doesn\'t change the weight, only the volume or space that the material takes up."\n}', id='run-a4abbdb6-c20e-456f-bfff-da906a7e76b5-0'),
# 'parsed': {
# 'answer': 'A pound of bricks is the same weight as a pound of feathers.',
# 'justification': "Both a pound of bricks and a pound of feathers weigh one pound. The material doesn't change the weight, only the volume or space that the material takes up."},
# 'parsing_error': None
# }
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
output_parser = (
PydanticOutputParser(pydantic_object=schema)
if is_pydantic_schema
else JsonOutputParser()
)
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_format'. Received: '{method}'"
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)
class _FunctionCall(TypedDict):
name: str
#
# Type conversion helpers
@ -480,17 +830,18 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
Returns:
The LangChain message.
"""
id_ = _dict.get("id")
role = _dict.get("role")
if role == "user":
return HumanMessage(content=_dict.get("content", ""))
elif role == "assistant":
content = _dict.get("content", "")
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)
return AIMessage(content=content, id=id_, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "function":

View File

@ -89,7 +89,9 @@ markers = [
]
filterwarnings = [
"error",
'ignore::ResourceWarning:',
'ignore:The function `with_structured_output` is in beta',
# Maintain support for pydantic 1.X
'default:The `dict` method is deprecated; use `model_dump` instead.*:DeprecationWarning',
'default:The `dict` method is deprecated; use `model_dump` instead:DeprecationWarning',
]
asyncio_mode = "auto"

View File

@ -1,15 +1,18 @@
"""Test ChatGroq chat model."""
import json
from typing import Any
import pytest
from langchain_core.messages import (
AIMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_groq import ChatGroq
from tests.unit_tests.fake.callbacks import (
@ -45,9 +48,9 @@ def test_invoke() -> None:
@pytest.mark.scheduled
async def test_ainvoke() -> None:
"""Test ainvoke tokens from ChatGroq."""
llm = ChatGroq(max_tokens=10)
chat = ChatGroq(max_tokens=10)
result = await llm.ainvoke("Welcome to the Groqetship!", config={"tags": ["foo"]})
result = await chat.ainvoke("Welcome to the Groqetship!", config={"tags": ["foo"]})
assert isinstance(result, BaseMessage)
assert isinstance(result.content, str)
@ -55,9 +58,9 @@ async def test_ainvoke() -> None:
@pytest.mark.scheduled
def test_batch() -> None:
"""Test batch tokens from ChatGroq."""
llm = ChatGroq(max_tokens=10)
chat = ChatGroq(max_tokens=10)
result = llm.batch(["Hello!", "Welcome to the Groqetship!"])
result = chat.batch(["Hello!", "Welcome to the Groqetship!"])
for token in result:
assert isinstance(token, BaseMessage)
assert isinstance(token.content, str)
@ -66,9 +69,9 @@ def test_batch() -> None:
@pytest.mark.scheduled
async def test_abatch() -> None:
"""Test abatch tokens from ChatGroq."""
llm = ChatGroq(max_tokens=10)
chat = ChatGroq(max_tokens=10)
result = await llm.abatch(["Hello!", "Welcome to the Groqetship!"])
result = await chat.abatch(["Hello!", "Welcome to the Groqetship!"])
for token in result:
assert isinstance(token, BaseMessage)
assert isinstance(token.content, str)
@ -77,9 +80,9 @@ async def test_abatch() -> None:
@pytest.mark.scheduled
async def test_stream() -> None:
"""Test streaming tokens from Groq."""
llm = ChatGroq(max_tokens=10)
chat = ChatGroq(max_tokens=10)
for token in llm.stream("Welcome to the Groqetship!"):
for token in chat.stream("Welcome to the Groqetship!"):
assert isinstance(token, BaseMessageChunk)
assert isinstance(token.content, str)
@ -87,9 +90,9 @@ async def test_stream() -> None:
@pytest.mark.scheduled
async def test_astream() -> None:
"""Test streaming tokens from Groq."""
llm = ChatGroq(max_tokens=10)
chat = ChatGroq(max_tokens=10)
async for token in llm.astream("Welcome to the Groqetship!"):
async for token in chat.astream("Welcome to the Groqetship!"):
assert isinstance(token, BaseMessageChunk)
assert isinstance(token.content, str)
@ -202,11 +205,11 @@ def test_streaming_generation_info() -> None:
temperature=0,
callbacks=[callback],
)
list(chat.stream("Respond with the single word Hello"))
list(chat.stream("Respond with the single word Hello", stop=["o"]))
generation = callback.saved_things["generation"]
# `Hello!` is two tokens, assert that that is what is returned
assert isinstance(generation, LLMResult)
assert generation.generations[0][0].text == "Hello"
assert generation.generations[0][0].text == "Hell"
def test_system_message() -> None:
@ -219,6 +222,75 @@ def test_system_message() -> None:
assert isinstance(response.content, str)
@pytest.mark.scheduled
def test_tool_choice() -> None:
"""Test that tool choice is respected."""
llm = ChatGroq()
class MyTool(BaseModel):
name: str
age: int
with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")
resp = with_tool.invoke("Who was the 27 year old named Erick?")
assert isinstance(resp, AIMessage)
assert resp.content == "" # should just be tool call
tool_calls = resp.additional_kwargs["tool_calls"]
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["function"]["name"] == "MyTool"
assert json.loads(tool_call["function"]["arguments"]) == {
"age": 27,
"name": "Erick",
}
assert tool_call["type"] == "function"
@pytest.mark.scheduled
def test_tool_choice_bool() -> None:
"""Test that tool choice is respected just passing in True."""
llm = ChatGroq()
class MyTool(BaseModel):
name: str
age: int
with_tool = llm.bind_tools([MyTool], tool_choice=True)
resp = with_tool.invoke("Who was the 27 year old named Erick?")
assert isinstance(resp, AIMessage)
assert resp.content == "" # should just be tool call
tool_calls = resp.additional_kwargs["tool_calls"]
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["function"]["name"] == "MyTool"
assert json.loads(tool_call["function"]["arguments"]) == {
"age": 27,
"name": "Erick",
}
assert tool_call["type"] == "function"
@pytest.mark.scheduled
def test_json_mode_structured_output() -> None:
"""Test with_structured_output with json"""
class Joke(BaseModel):
"""Joke to tell user."""
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")
chat = ChatGroq().with_structured_output(Joke, method="json_mode")
result = chat.invoke(
"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
)
assert type(result) == Joke
assert len(result.setup) != 0
assert len(result.punchline) != 0
# Groq does not currently support N > 1
# @pytest.mark.scheduled
# def test_chat_multiple_completions() -> None:

View File

@ -16,7 +16,8 @@ from langchain_core.messages import (
from langchain_groq.chat_models import ChatGroq, _convert_dict_to_message
os.environ["GROQ_API_KEY"] = "fake-key"
if "GROQ_API_KEY" not in os.environ:
os.environ["GROQ_API_KEY"] = "fake-key"
def test_groq_model_param() -> None: