mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
fireworks[patch]: Fix fireworks bind tools (#18352)
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
eefb49680f
commit
f481cbb32d
@ -625,7 +625,12 @@ class ChatFireworks(BaseChatModel):
|
||||
"tool_choice can only be True when there is one tool. Received "
|
||||
f"{len(tools)} tools."
|
||||
)
|
||||
tool_choice = formatted_tools[0]
|
||||
tool_name = formatted_tools[0]["function"]["name"]
|
||||
tool_choice = {
|
||||
"type": "function",
|
||||
"function": {"name": tool_name},
|
||||
}
|
||||
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
|
30
libs/partners/fireworks/poetry.lock
generated
30
libs/partners/fireworks/poetry.lock
generated
@ -598,13 +598,13 @@ url = "../../core"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.1.6"
|
||||
version = "0.1.10"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = [
|
||||
{file = "langsmith-0.1.6-py3-none-any.whl", hash = "sha256:59b0905ee80a39cc385a5d2140dd699b1b246104eb5ee8735d4f5805400002bd"},
|
||||
{file = "langsmith-0.1.6.tar.gz", hash = "sha256:9a31b02edce0b9a1607fbd20af3bac2785f44926ce4499b7de82d7ea9fb96b81"},
|
||||
{file = "langsmith-0.1.10-py3-none-any.whl", hash = "sha256:2997a80aea60ed235d83502a7ccdc1f62ffb4dd6b3b7dd4218e8fa4de68a6725"},
|
||||
{file = "langsmith-0.1.10.tar.gz", hash = "sha256:13e7e8b52e694aa4003370cefbb9e79cce3540c65dbf1517902bf7aa4dbbb653"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -774,13 +774,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.12.0"
|
||||
version = "1.13.3"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
{file = "openai-1.12.0-py3-none-any.whl", hash = "sha256:a54002c814e05222e413664f651b5916714e4700d041d5cf5724d3ae1a3e3481"},
|
||||
{file = "openai-1.12.0.tar.gz", hash = "sha256:99c5d257d09ea6533d689d1cc77caa0ac679fa21efef8893d8b0832a86877f1b"},
|
||||
{file = "openai-1.13.3-py3-none-any.whl", hash = "sha256:5769b62abd02f350a8dd1a3a242d8972c947860654466171d60fb0972ae0a41c"},
|
||||
{file = "openai-1.13.3.tar.gz", hash = "sha256:ff6c6b3bc7327e715e4b3592a923a5a1c7519ff5dd764a83d69f633d49e77a7b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -967,13 +967,13 @@ testing = ["pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.6.2"
|
||||
version = "2.6.3"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"},
|
||||
{file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"},
|
||||
{file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"},
|
||||
{file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -1281,13 +1281,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "sniffio"
|
||||
version = "1.3.0"
|
||||
version = "1.3.1"
|
||||
description = "Sniff out which async library your code is running under"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
|
||||
{file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
|
||||
{file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
|
||||
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1365,13 +1365,13 @@ urllib3 = ">=2"
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.9.0"
|
||||
version = "4.10.0"
|
||||
description = "Backported and Experimental Type Hints for Python 3.8+"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"},
|
||||
{file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"},
|
||||
{file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"},
|
||||
{file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-fireworks"
|
||||
version = "0.0.2"
|
||||
version = "0.1.0"
|
||||
description = "An integration package connecting Fireworks and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
|
@ -0,0 +1,76 @@
|
||||
"""Test ChatFireworks API wrapper
|
||||
|
||||
You will need FIREWORKS_API_KEY set in your environment to run these tests.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_fireworks import ChatFireworks
|
||||
|
||||
|
||||
def test_chat_fireworks_call() -> None:
|
||||
"""Test valid call to fireworks."""
|
||||
llm = ChatFireworks(
|
||||
model="accounts/fireworks/models/firefunction-v1", temperature=0
|
||||
)
|
||||
|
||||
resp = llm.invoke("Hello!")
|
||||
assert isinstance(resp, AIMessage)
|
||||
|
||||
assert len(resp.content) > 0
|
||||
|
||||
|
||||
def test_tool_choice() -> None:
|
||||
"""Test that tool choice is respected."""
|
||||
llm = ChatFireworks(
|
||||
model="accounts/fireworks/models/firefunction-v1", temperature=0
|
||||
)
|
||||
|
||||
class MyTool(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")
|
||||
|
||||
resp = with_tool.invoke("Who was the 27 year old named Erick?")
|
||||
assert isinstance(resp, AIMessage)
|
||||
assert resp.content == "" # should just be tool call
|
||||
tool_calls = resp.additional_kwargs["tool_calls"]
|
||||
assert len(tool_calls) == 1
|
||||
tool_call = tool_calls[0]
|
||||
assert tool_call["function"]["name"] == "MyTool"
|
||||
assert json.loads(tool_call["function"]["arguments"]) == {
|
||||
"age": 27,
|
||||
"name": "Erick",
|
||||
}
|
||||
assert tool_call["type"] == "function"
|
||||
|
||||
|
||||
def test_tool_choice_bool() -> None:
|
||||
"""Test that tool choice is respected just passing in True."""
|
||||
|
||||
llm = ChatFireworks(
|
||||
model="accounts/fireworks/models/firefunction-v1", temperature=0
|
||||
)
|
||||
|
||||
class MyTool(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
with_tool = llm.bind_tools([MyTool], tool_choice=True)
|
||||
|
||||
resp = with_tool.invoke("Who was the 27 year old named Erick?")
|
||||
assert isinstance(resp, AIMessage)
|
||||
assert resp.content == "" # should just be tool call
|
||||
tool_calls = resp.additional_kwargs["tool_calls"]
|
||||
assert len(tool_calls) == 1
|
||||
tool_call = tool_calls[0]
|
||||
assert tool_call["function"]["name"] == "MyTool"
|
||||
assert json.loads(tool_call["function"]["arguments"]) == {
|
||||
"age": 27,
|
||||
"name": "Erick",
|
||||
}
|
||||
assert tool_call["type"] == "function"
|
Loading…
Reference in New Issue
Block a user