diff --git a/libs/partners/fireworks/Makefile b/libs/partners/fireworks/Makefile
index c30170e04a..c474a01c03 100644
--- a/libs/partners/fireworks/Makefile
+++ b/libs/partners/fireworks/Makefile
@@ -5,11 +5,9 @@ all: help
 
 # Define a variable for the test file path.
 TEST_FILE ?= tests/unit_tests/
+integration_test integration_tests: TEST_FILE ?= tests/integration_tests/
 
-test:
-	poetry run pytest $(TEST_FILE)
-
-tests:
+test tests integration_test integration_tests:
 	poetry run pytest $(TEST_FILE)
 
 
diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py
index e30257d5bf..a29a718771 100644
--- a/libs/partners/fireworks/langchain_fireworks/chat_models.py
+++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -7,8 +7,10 @@ import os
 from operator import itemgetter
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Dict,
+    Iterator,
     List,
     Literal,
     Mapping,
@@ -24,11 +26,13 @@ from typing import (
 from fireworks.client import AsyncFireworks, Fireworks  # type: ignore
 from langchain_core._api import beta
 from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
+    agenerate_from_stream,
     generate_from_stream,
 )
 from langchain_core.messages import (
@@ -53,7 +57,7 @@ from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
@@ -344,6 +348,40 @@ class ChatFireworks(BaseChatModel):
         combined["system_fingerprint"] = system_fingerprint
         return combined
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        default_chunk_class = AIMessageChunk
+        for chunk in self.client.create(messages=message_dicts, **params):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
+            yield chunk
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -400,6 +438,66 @@ class ChatFireworks(BaseChatModel):
         }
         return ChatResult(generations=generations, llm_output=llm_output)
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        default_chunk_class = AIMessageChunk
+        async for chunk in self.async_client.acreate(messages=message_dicts, **params):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                )
+            yield chunk
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {
+            **params,
+            **({"stream": stream} if stream is not None else {}),
+            **kwargs,
+        }
+        response = await self.async_client.acreate(messages=message_dicts, **params)
+        return self._create_chat_result(response)
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
diff --git a/libs/partners/fireworks/langchain_fireworks/llms.py b/libs/partners/fireworks/langchain_fireworks/llms.py
index 548f0aceb6..1b6534fdbf 100644
--- a/libs/partners/fireworks/langchain_fireworks/llms.py
+++ b/libs/partners/fireworks/langchain_fireworks/llms.py
@@ -213,10 +213,5 @@ class Fireworks(LLM):
         )
         response_json = await response.json()
-
-        if response_json.get("status") != "finished":
-            err_msg = response_json.get("error", "Undefined Error")
-            raise Exception(err_msg)
-
         output = self._format_output(response_json)
         return output
 
diff --git a/libs/partners/fireworks/poetry.lock b/libs/partners/fireworks/poetry.lock
index 6a70c54761..555a170dab 100644
--- a/libs/partners/fireworks/poetry.lock
+++ b/libs/partners/fireworks/poetry.lock
@@ -341,13 +341,13 @@ test = ["pytest (>=6)"]
 
 [[package]]
 name = "fireworks-ai"
-version = "0.12.1"
+version = "0.13.0"
 description = "Python client library for the Fireworks.ai Generative AI Platform"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "fireworks-ai-0.12.1.tar.gz", hash = "sha256:77a3b3be243182548cb5f690f60528f09ca2a7223b871e47fc4e9d13a0df5c1b"},
-    {file = "fireworks_ai-0.12.1-py3-none-any.whl", hash = "sha256:f78dc61f46c534ba045ad111fc7eeed6ea7ff022e7dce446dd23f56ebad371e7"},
+    {file = "fireworks-ai-0.13.0.tar.gz", hash = "sha256:d6db1e60f65f237b6e87e3e9c028681be0ba77496df398db386ae2876dab54e0"},
+    {file = "fireworks_ai-0.13.0-py3-none-any.whl", hash = "sha256:900559d7eeea8a86dc5789f9034b3873684a685a2e96b56a63e6be3a04803eb6"},
 ]
 
 [package.dependencies]
@@ -1149,13 +1149,13 @@ watchdog = ">=2.0.0"
 
 [[package]]
 name = "python-dateutil"
-version = "2.8.2"
+version = "2.9.0"
 description = "Extensions to the standard Python datetime module"
 optional = false
 python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
 files = [
-    {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
-    {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
+    {file = "python-dateutil-2.9.0.tar.gz", hash = "sha256:78e73e19c63f5b20ffa567001531680d939dc042bf7850431877645523c66709"},
+    {file = "python_dateutil-2.9.0-py2.py3-none-any.whl", hash = "sha256:cbf2f1da5e6083ac2fbfd4da39a25f34312230110440f424a14c7558bb85d82e"},
 ]
 
 [package.dependencies]
@@ -1538,4 +1538,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "b9ee2bfb5053127cb29b8182baea395897744d7b5e8b985c42863133a26708ba"
+content-hash = "ab5538b63e5d347dadcad268e135a5ca9fb5bc2edd2436dcee99c55a7ee4b609"
diff --git a/libs/partners/fireworks/pyproject.toml b/libs/partners/fireworks/pyproject.toml
index c2899a9486..bb6de01512 100644
--- a/libs/partners/fireworks/pyproject.toml
+++ b/libs/partners/fireworks/pyproject.toml
@@ -13,7 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.1.27"
-fireworks-ai = ">=0.12.0,<1"
+fireworks-ai = ">=0.13.0"
 openai = "^1.10.0"
 requests = "^2"
 aiohttp = "^3.9.1"