From f481cbb32d08980b7838e86cde71b692ea8fc135 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Thu, 29 Feb 2024 17:18:15 -0800 Subject: [PATCH] fireworks[patch]: Fix fireworks bind tools (#18352) Co-authored-by: Erick Friis --- .../langchain_fireworks/chat_models.py | 7 +- libs/partners/fireworks/poetry.lock | 30 ++++---- libs/partners/fireworks/pyproject.toml | 2 +- .../integration_tests/test_chat_models.py | 76 +++++++++++++++++++ 4 files changed, 98 insertions(+), 17 deletions(-) create mode 100644 libs/partners/fireworks/tests/integration_tests/test_chat_models.py diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 2b8414ef85..8844e379c3 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -625,7 +625,12 @@ class ChatFireworks(BaseChatModel): "tool_choice can only be True when there is one tool. Received " f"{len(tools)} tools." ) - tool_choice = formatted_tools[0] + tool_name = formatted_tools[0]["function"]["name"] + tool_choice = { + "type": "function", + "function": {"name": tool_name}, + } + kwargs["tool_choice"] = tool_choice return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/partners/fireworks/poetry.lock b/libs/partners/fireworks/poetry.lock index 8741feeb6e..50a1b7be1f 100644 --- a/libs/partners/fireworks/poetry.lock +++ b/libs/partners/fireworks/poetry.lock @@ -598,13 +598,13 @@ url = "../../core" [[package]] name = "langsmith" -version = "0.1.6" +version = "0.1.10" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langsmith-0.1.6-py3-none-any.whl", hash = "sha256:59b0905ee80a39cc385a5d2140dd699b1b246104eb5ee8735d4f5805400002bd"}, - {file = "langsmith-0.1.6.tar.gz", hash = "sha256:9a31b02edce0b9a1607fbd20af3bac2785f44926ce4499b7de82d7ea9fb96b81"}, + {file = "langsmith-0.1.10-py3-none-any.whl", hash = "sha256:2997a80aea60ed235d83502a7ccdc1f62ffb4dd6b3b7dd4218e8fa4de68a6725"}, + {file = "langsmith-0.1.10.tar.gz", hash = "sha256:13e7e8b52e694aa4003370cefbb9e79cce3540c65dbf1517902bf7aa4dbbb653"}, ] [package.dependencies] @@ -774,13 +774,13 @@ files = [ [[package]] name = "openai" -version = "1.12.0" +version = "1.13.3" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.12.0-py3-none-any.whl", hash = "sha256:a54002c814e05222e413664f651b5916714e4700d041d5cf5724d3ae1a3e3481"}, - {file = "openai-1.12.0.tar.gz", hash = "sha256:99c5d257d09ea6533d689d1cc77caa0ac679fa21efef8893d8b0832a86877f1b"}, + {file = "openai-1.13.3-py3-none-any.whl", hash = "sha256:5769b62abd02f350a8dd1a3a242d8972c947860654466171d60fb0972ae0a41c"}, + {file = "openai-1.13.3.tar.gz", hash = "sha256:ff6c6b3bc7327e715e4b3592a923a5a1c7519ff5dd764a83d69f633d49e77a7b"}, ] [package.dependencies] @@ -967,13 +967,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pydantic" -version = "2.6.2" +version = "2.6.3" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"}, - {file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"}, + {file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"}, + {file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"}, ] [package.dependencies] @@ -1281,13 +1281,13 @@ files = [ [[package]] name = "sniffio" -version = "1.3.0" +version = "1.3.1" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" files = [ - {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, - {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] [[package]] @@ -1365,13 +1365,13 @@ urllib3 = ">=2" [[package]] name = "typing-extensions" -version = "4.9.0" +version = "4.10.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"}, - {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, + {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, + {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] [[package]] diff --git a/libs/partners/fireworks/pyproject.toml b/libs/partners/fireworks/pyproject.toml index a46acaa1f7..9bc42aa748 100644 --- a/libs/partners/fireworks/pyproject.toml +++ b/libs/partners/fireworks/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-fireworks" -version = "0.0.2" +version = "0.1.0" description = "An integration package connecting Fireworks and LangChain" authors = [] readme = "README.md" diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py new file mode 100644 index 0000000000..1773173a2e --- /dev/null +++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py @@ -0,0 +1,76 @@ +"""Test ChatFireworks API wrapper + +You will need FIREWORKS_API_KEY set in your environment to run these tests. +""" + +import json + +from langchain_core.messages import AIMessage +from langchain_core.pydantic_v1 import BaseModel + +from langchain_fireworks import ChatFireworks + + +def test_chat_fireworks_call() -> None: + """Test valid call to fireworks.""" + llm = ChatFireworks( + model="accounts/fireworks/models/firefunction-v1", temperature=0 + ) + + resp = llm.invoke("Hello!") + assert isinstance(resp, AIMessage) + + assert len(resp.content) > 0 + + +def test_tool_choice() -> None: + """Test that tool choice is respected.""" + llm = ChatFireworks( + model="accounts/fireworks/models/firefunction-v1", temperature=0 + ) + + class MyTool(BaseModel): + name: str + age: int + + with_tool = llm.bind_tools([MyTool], tool_choice="MyTool") + + resp = with_tool.invoke("Who was the 27 year old named Erick?") + assert isinstance(resp, AIMessage) + assert resp.content == "" # should just be tool call + tool_calls = resp.additional_kwargs["tool_calls"] + assert len(tool_calls) == 1 + tool_call = tool_calls[0] + assert tool_call["function"]["name"] == "MyTool" + assert json.loads(tool_call["function"]["arguments"]) == { + "age": 27, + "name": "Erick", + } + assert tool_call["type"] == "function" + + +def test_tool_choice_bool() -> None: + """Test that tool choice is respected just passing in True.""" + + llm = ChatFireworks( + model="accounts/fireworks/models/firefunction-v1", temperature=0 + ) + + class MyTool(BaseModel): + name: str + age: int + + with_tool = llm.bind_tools([MyTool], tool_choice=True) + + resp = with_tool.invoke("Who was the 27 year old named Erick?") + assert isinstance(resp, AIMessage) + assert resp.content == "" # should just be tool call + tool_calls = resp.additional_kwargs["tool_calls"] + assert len(tool_calls) == 1 + tool_call = tool_calls[0] + assert tool_call["function"]["name"] == "MyTool" + assert json.loads(tool_call["function"]["arguments"]) == { + "age": 27, + "name": "Erick", + } + assert tool_call["type"] == "function"