fireworks[patch]: Fix fireworks bind tools (#18352)

Co-authored-by: Erick Friis <erick@langchain.dev>
pull/18355/head
William FH 5 months ago committed by GitHub
parent eefb49680f
commit f481cbb32d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -625,7 +625,12 @@ class ChatFireworks(BaseChatModel):
"tool_choice can only be True when there is one tool. Received "
f"{len(tools)} tools."
)
tool_choice = formatted_tools[0]
tool_name = formatted_tools[0]["function"]["name"]
tool_choice = {
"type": "function",
"function": {"name": tool_name},
}
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)

@ -598,13 +598,13 @@ url = "../../core"
[[package]]
name = "langsmith"
version = "0.1.6"
version = "0.1.10"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langsmith-0.1.6-py3-none-any.whl", hash = "sha256:59b0905ee80a39cc385a5d2140dd699b1b246104eb5ee8735d4f5805400002bd"},
{file = "langsmith-0.1.6.tar.gz", hash = "sha256:9a31b02edce0b9a1607fbd20af3bac2785f44926ce4499b7de82d7ea9fb96b81"},
{file = "langsmith-0.1.10-py3-none-any.whl", hash = "sha256:2997a80aea60ed235d83502a7ccdc1f62ffb4dd6b3b7dd4218e8fa4de68a6725"},
{file = "langsmith-0.1.10.tar.gz", hash = "sha256:13e7e8b52e694aa4003370cefbb9e79cce3540c65dbf1517902bf7aa4dbbb653"},
]
[package.dependencies]
@ -774,13 +774,13 @@ files = [
[[package]]
name = "openai"
version = "1.12.0"
version = "1.13.3"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.7.1"
files = [
{file = "openai-1.12.0-py3-none-any.whl", hash = "sha256:a54002c814e05222e413664f651b5916714e4700d041d5cf5724d3ae1a3e3481"},
{file = "openai-1.12.0.tar.gz", hash = "sha256:99c5d257d09ea6533d689d1cc77caa0ac679fa21efef8893d8b0832a86877f1b"},
{file = "openai-1.13.3-py3-none-any.whl", hash = "sha256:5769b62abd02f350a8dd1a3a242d8972c947860654466171d60fb0972ae0a41c"},
{file = "openai-1.13.3.tar.gz", hash = "sha256:ff6c6b3bc7327e715e4b3592a923a5a1c7519ff5dd764a83d69f633d49e77a7b"},
]
[package.dependencies]
@ -967,13 +967,13 @@ testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pydantic"
version = "2.6.2"
version = "2.6.3"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.8"
files = [
{file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"},
{file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"},
{file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"},
{file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"},
]
[package.dependencies]
@ -1281,13 +1281,13 @@ files = [
[[package]]
name = "sniffio"
version = "1.3.0"
version = "1.3.1"
description = "Sniff out which async library your code is running under"
optional = false
python-versions = ">=3.7"
files = [
{file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
{file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
{file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
]
[[package]]
@ -1365,13 +1365,13 @@ urllib3 = ">=2"
[[package]]
name = "typing-extensions"
version = "4.9.0"
version = "4.10.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
{file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"},
{file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"},
{file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"},
{file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
]
[[package]]

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-fireworks"
version = "0.0.2"
version = "0.1.0"
description = "An integration package connecting Fireworks and LangChain"
authors = []
readme = "README.md"

@ -0,0 +1,76 @@
"""Test ChatFireworks API wrapper
You will need FIREWORKS_API_KEY set in your environment to run these tests.
"""
import json
from langchain_core.messages import AIMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_fireworks import ChatFireworks
def test_chat_fireworks_call() -> None:
"""Test valid call to fireworks."""
llm = ChatFireworks(
model="accounts/fireworks/models/firefunction-v1", temperature=0
)
resp = llm.invoke("Hello!")
assert isinstance(resp, AIMessage)
assert len(resp.content) > 0
def test_tool_choice() -> None:
"""Test that tool choice is respected."""
llm = ChatFireworks(
model="accounts/fireworks/models/firefunction-v1", temperature=0
)
class MyTool(BaseModel):
name: str
age: int
with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")
resp = with_tool.invoke("Who was the 27 year old named Erick?")
assert isinstance(resp, AIMessage)
assert resp.content == "" # should just be tool call
tool_calls = resp.additional_kwargs["tool_calls"]
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["function"]["name"] == "MyTool"
assert json.loads(tool_call["function"]["arguments"]) == {
"age": 27,
"name": "Erick",
}
assert tool_call["type"] == "function"
def test_tool_choice_bool() -> None:
"""Test that tool choice is respected just passing in True."""
llm = ChatFireworks(
model="accounts/fireworks/models/firefunction-v1", temperature=0
)
class MyTool(BaseModel):
name: str
age: int
with_tool = llm.bind_tools([MyTool], tool_choice=True)
resp = with_tool.invoke("Who was the 27 year old named Erick?")
assert isinstance(resp, AIMessage)
assert resp.content == "" # should just be tool call
tool_calls = resp.additional_kwargs["tool_calls"]
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["function"]["name"] == "MyTool"
assert json.loads(tool_call["function"]["arguments"]) == {
"age": 27,
"name": "Erick",
}
assert tool_call["type"] == "function"
Loading…
Cancel
Save