From 6ce276e0993a1e076d3eaffd3323089f475e14f5 Mon Sep 17 00:00:00 2001
From: Cynthia Yang
Date: Thu, 26 Oct 2023 13:01:08 -0700
Subject: [PATCH] Support Fireworks batching (#8) (#12052)

Description
* Add `_generate` and `_agenerate` to support Fireworks batching.
* Add stop-word test cases.
* Add a `use_retry` flag so callers can opt out of the retry mechanism.

Issue - Not applicable
Dependencies - None
Tag maintainer - @baskaryan
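Usage sketch (illustrative, not part of the patch): the snippet below shows the two behaviors this change introduces. It assumes a Fireworks API key is already configured (e.g. via a FIREWORKS_API_KEY environment variable, which is an assumption here); the prompts and the batch_size value are placeholders taken from the test cases.

    # Illustrative only -- prompts, batch_size value, and the env-var name
    # are assumptions, not part of this patch.
    from langchain.llms.fireworks import Fireworks

    # Batching: _generate()/_agenerate() split the prompts into sub-batches
    # of `batch_size` and fan each sub-batch out over a thread pool.
    llm = Fireworks(batch_size=20)
    outputs = llm.batch(
        [
            "How is the weather in New York today?",
            "What is the weather in Redwood City, CA today",
        ],
        stop=[","],
    )

    # Opting out of retries: with use_retry=False the tenacity retry
    # decorator is skipped and API errors surface immediately.
    llm_no_retry = Fireworks(use_retry=False)
    print(llm_no_retry.invoke("How is the weather in New York today?", stop=[","]))

With `use_retry=True` (the default), transient RateLimitError, InternalServerError, BadGatewayError, and ServiceUnavailableError responses are retried up to `max_retries` times.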
---
 .../langchain/chat_models/fireworks.py        |  35 ++-
 libs/langchain/langchain/llms/fireworks.py    | 199 +++++++++++++++---
 .../chat_models/test_fireworks.py             |  82 +++++++-
 .../integration_tests/llms/test_fireworks.py  |  76 ++++++-
 4 files changed, 347 insertions(+), 45 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index 0da316a968..ce0ce83b59 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -89,6 +89,7 @@ class ChatFireworks(BaseChatModel):
     )
     fireworks_api_key: Optional[str] = None
     max_retries: int = 20
+    use_retry: bool = True

     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -134,7 +135,11 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         response = completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self,
+            self.use_retry,
+            run_manager=run_manager,
+            stop=stop,
+            **params,
         )
         return self._create_chat_result(response)

@@ -152,7 +157,7 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         response = await acompletion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         )
         return self._create_chat_result(response)

@@ -195,7 +200,7 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         for chunk in completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             choice = chunk.choices[0]
             chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
@@ -224,7 +229,7 @@ class ChatFireworks(BaseChatModel):
             **self.model_kwargs,
         }
         async for chunk in await acompletion_with_retry_streaming(
-            self, run_manager=run_manager, stop=stop, **params
+            self, self.use_retry, run_manager=run_manager, stop=stop, **params
         ):
             choice = chunk.choices[0]
             chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
@@ -238,8 +243,20 @@ class ChatFireworks(BaseChatModel):
                 await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)


+def conditional_decorator(
+    condition: bool, decorator: Callable[[Any], Any]
+) -> Callable[[Any], Any]:
+    def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+        if condition:
+            return decorator(func)
+        return func
+
+    return actual_decorator
+
+
 def completion_with_retry(
     llm: ChatFireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -249,7 +266,7 @@ def completion_with_retry(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.ChatCompletion.create(
             **kwargs,
@@ -260,6 +277,7 @@ def completion_with_retry(

 async def acompletion_with_retry(
     llm: ChatFireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -269,7 +287,7 @@ async def acompletion_with_retry(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return await fireworks.client.ChatCompletion.acreate(
             **kwargs,
@@ -280,6 +298,7 @@ async def acompletion_with_retry(

 async def acompletion_with_retry_streaming(
     llm: ChatFireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -289,7 +308,7 @@ async def acompletion_with_retry_streaming(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.ChatCompletion.acreate(
             **kwargs,
@@ -309,6 +328,8 @@ def _create_retry_decorator(

     errors = [
         fireworks.client.error.RateLimitError,
+        fireworks.client.error.InternalServerError,
+        fireworks.client.error.BadGatewayError,
         fireworks.client.error.ServiceUnavailableError,
     ]
     return create_base_retry_decorator(
diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py
index 51ebb88009..676c8813b0 100644
--- a/libs/langchain/langchain/llms/fireworks.py
+++ b/libs/langchain/langchain/llms/fireworks.py
@@ -1,12 +1,14 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import LLM, create_base_retry_decorator
+from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
-from langchain.schema.output import GenerationChunk
+from langchain.schema.output import Generation, GenerationChunk, LLMResult
 from langchain.utils.env import get_from_dict_or_env


@@ -23,7 +25,7 @@ def _stream_response_to_generation_chunk(
     )


-class Fireworks(LLM):
+class Fireworks(BaseLLM):
     """Fireworks models."""

     model: str = "accounts/fireworks/models/llama-v2-7b-chat"
@@ -36,6 +38,8 @@ class Fireworks(LLM):
     )
     fireworks_api_key: Optional[str] = None
     max_retries: int = 20
+    batch_size: int = 20
+    use_retry: bool = True

     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -66,43 +70,92 @@ class Fireworks(LLM):
         """Return type of llm."""
         return "fireworks"

-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Run the LLM on the given prompt and input."""
-        params: dict = {
+    ) -> LLMResult:
+        """Call out to Fireworks endpoint with k unique prompts.
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+        Returns:
+            The full LLM output.
+ """ + params = { "model": self.model, - "prompt": prompt, **self.model_kwargs, } - response = completion_with_retry( - self, run_manager=run_manager, stop=stop, **params - ) - - return response.choices[0].text - - async def _acall( + sub_prompts = self.get_batch_prompts(prompts) + choices = [] + for _prompts in sub_prompts: + response = completion_with_retry_batching( + self, + self.use_retry, + prompt=_prompts, + run_manager=run_manager, + stop=stop, + **params, + ) + choices.extend(response) + + return self.create_llm_result(choices, prompts) + + async def _agenerate( self, - prompt: str, + prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> str: - """Run the LLM on the given prompt and input.""" + ) -> LLMResult: + """Call out to Fireworks endpoint async with k unique prompts.""" params = { "model": self.model, - "prompt": prompt, **self.model_kwargs, } - response = await acompletion_with_retry( - self, run_manager=run_manager, stop=stop, **params - ) - - return response.choices[0].text + sub_prompts = self.get_batch_prompts(prompts) + choices = [] + for _prompts in sub_prompts: + response = await acompletion_with_retry_batching( + self, + self.use_retry, + prompt=_prompts, + run_manager=run_manager, + stop=stop, + **params, + ) + choices.extend(response) + + return self.create_llm_result(choices, prompts) + + def get_batch_prompts( + self, + prompts: List[str], + ) -> List[List[str]]: + """Get the sub prompts for llm call.""" + sub_prompts = [ + prompts[i : i + self.batch_size] + for i in range(0, len(prompts), self.batch_size) + ] + return sub_prompts + + def create_llm_result(self, choices: Any, prompts: List[str]) -> LLMResult: + """Create the LLMResult from the choices and prompts.""" + generations = [] + for i, _ in enumerate(prompts): + sub_choices = choices[i : (i + 1)] + generations.append( + [ + Generation( + text=choice.__dict__["choices"][0].text, + ) + for choice in sub_choices + ] + ) + llm_output = {"model": self.model} + return LLMResult(generations=generations, llm_output=llm_output) def _stream( self, @@ -118,7 +171,7 @@ class Fireworks(LLM): **self.model_kwargs, } for stream_resp in completion_with_retry( - self, run_manager=run_manager, stop=stop, **params + self, self.use_retry, run_manager=run_manager, stop=stop, **params ): chunk = _stream_response_to_generation_chunk(stream_resp) yield chunk @@ -139,7 +192,7 @@ class Fireworks(LLM): **self.model_kwargs, } async for stream_resp in await acompletion_with_retry_streaming( - self, run_manager=run_manager, stop=stop, **params + self, self.use_retry, run_manager=run_manager, stop=stop, **params ): chunk = _stream_response_to_generation_chunk(stream_resp) yield chunk @@ -147,8 +200,20 @@ class Fireworks(LLM): await run_manager.on_llm_new_token(chunk.text, chunk=chunk) +def conditional_decorator( + condition: bool, decorator: Callable[[Any], Any] +) -> Callable[[Any], Any]: + def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]: + if condition: + return decorator(func) + return func + + return actual_decorator + + def completion_with_retry( llm: Fireworks, + use_retry: bool, *, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, @@ -158,7 +223,7 @@ def completion_with_retry( retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) - @retry_decorator + @conditional_decorator(use_retry, retry_decorator) def _completion_with_retry(**kwargs: Any) -> Any: return 
         return fireworks.client.Completion.create(
             **kwargs,
@@ -169,6 +234,7 @@ def completion_with_retry(

 async def acompletion_with_retry(
     llm: Fireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -178,7 +244,7 @@ async def acompletion_with_retry(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return await fireworks.client.Completion.acreate(
             **kwargs,
@@ -187,8 +253,79 @@ async def acompletion_with_retry(
     return await _completion_with_retry(**kwargs)


+def completion_with_retry_batching(
+    llm: Fireworks,
+    use_retry: bool,
+    *,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    """Use tenacity to retry the completion call."""
+    import fireworks.client
+
+    prompt = kwargs["prompt"]
+    del kwargs["prompt"]
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @conditional_decorator(use_retry, retry_decorator)
+    def _completion_with_retry(prompt: str) -> Any:
+        return fireworks.client.Completion.create(**kwargs, prompt=prompt)
+
+    def batch_sync_run() -> List:
+        with ThreadPoolExecutor() as executor:
+            results = list(executor.map(_completion_with_retry, prompt))
+        return results
+
+    return batch_sync_run()
+
+
+async def acompletion_with_retry_batching(
+    llm: Fireworks,
+    use_retry: bool,
+    *,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    """Use tenacity to retry the completion call."""
+    import fireworks.client
+
+    prompt = kwargs["prompt"]
+    del kwargs["prompt"]
+
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+
+    @conditional_decorator(use_retry, retry_decorator)
+    async def _completion_with_retry(prompt: str) -> Any:
+        return await fireworks.client.Completion.acreate(**kwargs, prompt=prompt)
+
+    def run_coroutine_in_new_loop(
+        coroutine_func: Any, *args: Dict, **kwargs: Dict
+    ) -> Any:
+        new_loop = asyncio.new_event_loop()
+        try:
+            asyncio.set_event_loop(new_loop)
+            return new_loop.run_until_complete(coroutine_func(*args, **kwargs))
+        finally:
+            new_loop.close()
+
+    async def batch_sync_run() -> List:
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(
+                    run_coroutine_in_new_loop,
+                    [_completion_with_retry] * len(prompt),
+                    prompt,
+                )
+            )
+        return results
+
+    return await batch_sync_run()
+
+
 async def acompletion_with_retry_streaming(
     llm: Fireworks,
+    use_retry: bool,
     *,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
@@ -198,7 +335,7 @@ async def acompletion_with_retry_streaming(

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

-    @retry_decorator
+    @conditional_decorator(use_retry, retry_decorator)
     async def _completion_with_retry(**kwargs: Any) -> Any:
         return fireworks.client.Completion.acreate(
             **kwargs,
@@ -219,6 +356,8 @@ def _create_retry_decorator(

     errors = [
         fireworks.client.error.RateLimitError,
+        fireworks.client.error.InternalServerError,
+        fireworks.client.error.BadGatewayError,
         fireworks.client.error.ServiceUnavailableError,
     ]
     return create_base_retry_decorator(
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
index ec43bfd582..2bb4409ec4 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
@@ -3,11 +3,7 @@
 import pytest

 from langchain.chat_models.fireworks import ChatFireworks
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-    LLMResult,
-)
+from langchain.schema import ChatGeneration, ChatResult, LLMResult
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage


@@ -72,6 +68,64 @@ def test_chat_fireworks_llm_output_contains_model_id() -> None:
     assert llm_result.llm_output["model"] == chat.model


+def test_fireworks_invoke() -> None:
+    """Tests chat completion with invoke"""
+    chat = ChatFireworks()
+    result = chat.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests chat completion with invoke"""
+    chat = ChatFireworks()
+    result = await chat.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(result.content, str)
+    assert result.content[-1] == ","
+
+
+def test_fireworks_batch() -> None:
+    """Test batch tokens from ChatFireworks."""
+    chat = ChatFireworks()
+    result = chat.batch(
+        [
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+        ],
+        config={"max_concurrency": 5},
+        stop=[","],
+    )
+    for token in result:
+        assert isinstance(token.content, str)
+        assert token.content[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_abatch() -> None:
+    """Test batch tokens from ChatFireworks."""
+    chat = ChatFireworks()
+    result = await chat.abatch(
+        [
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+            "What is the weather in Redwood City, CA today",
+        ],
+        config={"max_concurrency": 5},
+        stop=[","],
+    )
+    for token in result:
+        assert isinstance(token.content, str)
+        assert token.content[-1] == ","
+
+
 def test_fireworks_streaming() -> None:
     """Test streaming tokens from Fireworks."""
     llm = ChatFireworks()
@@ -80,6 +134,17 @@ def test_fireworks_streaming() -> None:
         assert isinstance(token.content, str)


+def test_fireworks_streaming_stop_words() -> None:
+    """Test streaming tokens with stop words."""
+    llm = ChatFireworks()
+
+    last_token = ""
+    for token in llm.stream("I'm Pickle Rick", stop=[","]):
+        last_token = token.content
+        assert isinstance(token.content, str)
+    assert last_token[-1] == ","
+
+
 @pytest.mark.asyncio
 async def test_chat_fireworks_agenerate() -> None:
     """Test ChatFireworks wrapper with generate."""
@@ -101,5 +166,10 @@ async def test_fireworks_astream() -> None:
     """Test streaming tokens from Fireworks."""
     llm = ChatFireworks()

-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token.content
         assert isinstance(token.content, str)
+    assert last_token[-1] == ","
diff --git a/libs/langchain/tests/integration_tests/llms/test_fireworks.py b/libs/langchain/tests/integration_tests/llms/test_fireworks.py
index cbc7473665..c1068631ce 100644
--- a/libs/langchain/tests/integration_tests/llms/test_fireworks.py
+++ b/libs/langchain/tests/integration_tests/llms/test_fireworks.py
@@ -16,7 +16,7 @@ from langchain.schema import LLMResult
 def test_fireworks_call() -> None:
     """Test valid call to fireworks."""
     llm = Fireworks()
-    output = llm("Who's the best quarterback in the NFL?")
+    output = llm("How is the weather in New York today?")
     assert isinstance(output, str)


@@ -41,6 +41,60 @@ def test_fireworks_model_param() -> None:
     assert llm.model == "foo"


+def test_fireworks_invoke() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = llm.invoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_ainvoke() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = await llm.ainvoke("How is the weather in New York today?", stop=[","])
+    assert isinstance(output, str)
+    assert output[-1] == ","
+
+
+def test_fireworks_batch() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = llm.batch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
+@pytest.mark.asyncio
+async def test_fireworks_abatch() -> None:
+    """Tests completion with invoke"""
+    llm = Fireworks()
+    output = await llm.abatch(
+        [
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+            "How is the weather in New York today?",
+        ],
+        stop=[","],
+    )
+    for token in output:
+        assert isinstance(token, str)
+        assert token[-1] == ","
+
+
 def test_fireworks_multiple_prompts() -> None:
     """Test completion with multiple prompts."""
     llm = Fireworks()
@@ -60,13 +114,31 @@ def test_fireworks_streaming() -> None:
         assert isinstance(token, str)


+def test_fireworks_streaming_stop_words() -> None:
+    """Test stream completion with stop words."""
+    llm = Fireworks()
+    generator = llm.stream("Who's the best quarterback in the NFL?", stop=[","])
+    assert isinstance(generator, Generator)
+
+    last_token = ""
+    for token in generator:
+        last_token = token
+        assert isinstance(token, str)
+    assert last_token[-1] == ","
+
+
 @pytest.mark.asyncio
 async def test_fireworks_streaming_async() -> None:
     """Test stream completion."""
     llm = Fireworks()

-    async for token in llm.astream("Who's the best quarterback in the NFL?"):
+    last_token = ""
+    async for token in llm.astream(
+        "Who's the best quarterback in the NFL?", stop=[","]
+    ):
+        last_token = token
         assert isinstance(token, str)
+    assert last_token[-1] == ","


 @pytest.mark.asyncio