Support Fireworks batching (#8) (#12052)

Description

* Add _generate and _agenerate to support Fireworks batching.
* Add stop-word test cases
* Add a use_retry flag to allow opting out of the retry mechanism

Issue - Not applicable
Dependencies - None
Tag maintainer - @baskaryan
Cynthia Yang 9 months ago committed by GitHub
parent 3fbb2f3e52
commit 6ce276e099
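For reference, a minimal usage sketch of the new surface (illustrative only; it assumes FIREWORKS_API_KEY is set in the environment and uses the same import paths as the integration tests below):

from langchain.chat_models.fireworks import ChatFireworks
from langchain.llms.fireworks import Fireworks

# batch_size controls how many prompts are grouped per round of calls;
# use_retry=False skips the tenacity retry decorator entirely.
llm = Fireworks(batch_size=20, use_retry=False)
result = llm.generate(
    ["How is the weather in New York today?"] * 3,
    stop=[","],
)
print([gen[0].text for gen in result.generations])

# The chat model gains the same use_retry flag; batching itself goes through
# the standard Runnable batch interface, as exercised in the tests below.
chat = ChatFireworks(use_retry=False)
messages = chat.batch(
    ["How is the weather in New York today?"] * 2,
    stop=[","],
)
print([m.content for m in messages])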

@ -89,6 +89,7 @@ class ChatFireworks(BaseChatModel):
)
fireworks_api_key: Optional[str] = None
max_retries: int = 20
use_retry: bool = True
@property
def lc_secrets(self) -> Dict[str, str]:
@ -134,7 +135,11 @@ class ChatFireworks(BaseChatModel):
**self.model_kwargs,
}
response = completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
self,
self.use_retry,
run_manager=run_manager,
stop=stop,
**params,
)
return self._create_chat_result(response)
@ -152,7 +157,7 @@ class ChatFireworks(BaseChatModel):
**self.model_kwargs,
}
response = await acompletion_with_retry(
self, run_manager=run_manager, stop=stop, **params
self, self.use_retry, run_manager=run_manager, stop=stop, **params
)
return self._create_chat_result(response)
@ -195,7 +200,7 @@ class ChatFireworks(BaseChatModel):
**self.model_kwargs,
}
for chunk in completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
self, self.use_retry, run_manager=run_manager, stop=stop, **params
):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
@ -224,7 +229,7 @@ class ChatFireworks(BaseChatModel):
**self.model_kwargs,
}
async for chunk in await acompletion_with_retry_streaming(
self, run_manager=run_manager, stop=stop, **params
self, self.use_retry, run_manager=run_manager, stop=stop, **params
):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
@ -238,8 +243,20 @@ class ChatFireworks(BaseChatModel):
await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
def conditional_decorator(
condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]:
def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
if condition:
return decorator(func)
return func
return actual_decorator
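A quick sketch of what conditional_decorator gives us here (the _shout/_greet toys are hypothetical, not part of this change): the wrapped decorator is applied only when the condition is true, which is how use_retry=False bypasses tenacity without a second code path.

def _shout(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs).upper()
    return wrapper

@conditional_decorator(False, _shout)
def _greet():
    return "hello"

assert _greet() == "hello"                                       # decorator skipped
assert conditional_decorator(True, _shout)(_greet)() == "HELLO"  # decorator applied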
def completion_with_retry(
llm: ChatFireworks,
use_retry: bool,
*,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
@ -249,7 +266,7 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
@conditional_decorator(use_retry, retry_decorator)
def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.create(
**kwargs,
@ -260,6 +277,7 @@ def completion_with_retry(
async def acompletion_with_retry(
llm: ChatFireworks,
use_retry: bool,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
@ -269,7 +287,7 @@ async def acompletion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
@conditional_decorator(use_retry, retry_decorator)
async def _completion_with_retry(**kwargs: Any) -> Any:
return await fireworks.client.ChatCompletion.acreate(
**kwargs,
@ -280,6 +298,7 @@ async def acompletion_with_retry(
async def acompletion_with_retry_streaming(
llm: ChatFireworks,
use_retry: bool,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
@ -289,7 +308,7 @@ async def acompletion_with_retry_streaming(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
@conditional_decorator(use_retry, retry_decorator)
async def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.acreate(
**kwargs,
@ -309,6 +328,8 @@ def _create_retry_decorator(
errors = [
fireworks.client.error.RateLimitError,
fireworks.client.error.InternalServerError,
fireworks.client.error.BadGatewayError,
fireworks.client.error.ServiceUnavailableError,
]
return create_base_retry_decorator(

@ -1,12 +1,14 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.output import GenerationChunk
from langchain.schema.output import Generation, GenerationChunk, LLMResult
from langchain.utils.env import get_from_dict_or_env
@ -23,7 +25,7 @@ def _stream_response_to_generation_chunk(
)
class Fireworks(LLM):
class Fireworks(BaseLLM):
"""Fireworks models."""
model: str = "accounts/fireworks/models/llama-v2-7b-chat"
@ -36,6 +38,8 @@ class Fireworks(LLM):
)
fireworks_api_key: Optional[str] = None
max_retries: int = 20
batch_size: int = 20
use_retry: bool = True
@property
def lc_secrets(self) -> Dict[str, str]:
@ -66,43 +70,92 @@ class Fireworks(LLM):
"""Return type of llm."""
return "fireworks"
def _call(
def _generate(
self,
prompt: str,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
params: dict = {
) -> LLMResult:
"""Call out to Fireworks endpoint with k unique prompts.
Args:
prompts: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The full LLM output.
"""
params = {
"model": self.model,
"prompt": prompt,
**self.model_kwargs,
}
response = completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)
return response.choices[0].text
async def _acall(
sub_prompts = self.get_batch_prompts(prompts)
choices = []
for _prompts in sub_prompts:
response = completion_with_retry_batching(
self,
self.use_retry,
prompt=_prompts,
run_manager=run_manager,
stop=stop,
**params,
)
choices.extend(response)
return self.create_llm_result(choices, prompts)
async def _agenerate(
self,
prompt: str,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
) -> LLMResult:
"""Call out to Fireworks endpoint async with k unique prompts."""
params = {
"model": self.model,
"prompt": prompt,
**self.model_kwargs,
}
response = await acompletion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)
return response.choices[0].text
sub_prompts = self.get_batch_prompts(prompts)
choices = []
for _prompts in sub_prompts:
response = await acompletion_with_retry_batching(
self,
self.use_retry,
prompt=_prompts,
run_manager=run_manager,
stop=stop,
**params,
)
choices.extend(response)
return self.create_llm_result(choices, prompts)
def get_batch_prompts(
self,
prompts: List[str],
) -> List[List[str]]:
"""Get the sub prompts for llm call."""
sub_prompts = [
prompts[i : i + self.batch_size]
for i in range(0, len(prompts), self.batch_size)
]
return sub_prompts
def create_llm_result(self, choices: Any, prompts: List[str]) -> LLMResult:
"""Create the LLMResult from the choices and prompts."""
generations = []
for i, _ in enumerate(prompts):
sub_choices = choices[i : (i + 1)]
generations.append(
[
Generation(
text=choice.__dict__["choices"][0].text,
)
for choice in sub_choices
]
)
llm_output = {"model": self.model}
return LLMResult(generations=generations, llm_output=llm_output)
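To make the regrouping concrete, a small sketch of how get_batch_prompts and create_llm_result fit together (numbers are illustrative, assuming the default batch_size of 20):

prompts = ["prompt %d" % i for i in range(45)]
batch_size = 20
sub_prompts = [
    prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)
]
# -> sub-batches of sizes 20, 20 and 5; each sub-batch is sent through one
#    completion_with_retry_batching call, the per-prompt responses are appended
#    to `choices`, and create_llm_result maps choice i back to prompt i as a
#    single-element Generation list.
assert [len(b) for b in sub_prompts] == [20, 20, 5]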
def _stream(
self,
@ -118,7 +171,7 @@ class Fireworks(LLM):
**self.model_kwargs,
}
for stream_resp in completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
self, self.use_retry, run_manager=run_manager, stop=stop, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
@ -139,7 +192,7 @@ class Fireworks(LLM):
**self.model_kwargs,
}
async for stream_resp in await acompletion_with_retry_streaming(
self, run_manager=run_manager, stop=stop, **params
self, self.use_retry, run_manager=run_manager, stop=stop, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
@ -147,8 +200,20 @@ class Fireworks(LLM):
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
def conditional_decorator(
condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]:
def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
if condition:
return decorator(func)
return func
return actual_decorator
def completion_with_retry(
llm: Fireworks,
use_retry: bool,
*,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
@ -158,7 +223,7 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
@conditional_decorator(use_retry, retry_decorator)
def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.Completion.create(
**kwargs,
@ -169,6 +234,7 @@ def completion_with_retry(
async def acompletion_with_retry(
llm: Fireworks,
use_retry: bool,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
@ -178,7 +244,7 @@ async def acompletion_with_retry(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
@conditional_decorator(use_retry, retry_decorator)
async def _completion_with_retry(**kwargs: Any) -> Any:
return await fireworks.client.Completion.acreate(
**kwargs,
@ -187,8 +253,79 @@ async def acompletion_with_retry(
return await _completion_with_retry(**kwargs)
def completion_with_retry_batching(
llm: Fireworks,
use_retry: bool,
*,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
import fireworks.client
prompt = kwargs["prompt"]
del kwargs["prompt"]
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@conditional_decorator(use_retry, retry_decorator)
def _completion_with_retry(prompt: str) -> Any:
return fireworks.client.Completion.create(**kwargs, prompt=prompt)
def batch_sync_run() -> List:
with ThreadPoolExecutor() as executor:
results = list(executor.map(_completion_with_retry, prompt))
return results
return batch_sync_run()
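The synchronous batching helper fans the per-prompt calls out over a thread pool; executor.map preserves input order, so the i-th result still corresponds to the i-th prompt. The same pattern in isolation (fake_call is a stand-in for the real API call):

from concurrent.futures import ThreadPoolExecutor

def fake_call(prompt: str) -> str:
    return prompt.upper()

with ThreadPoolExecutor() as executor:
    results = list(executor.map(fake_call, ["a", "b", "c"]))
assert results == ["A", "B", "C"]  # order matches the input prompts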
async def acompletion_with_retry_batching(
llm: Fireworks,
use_retry: bool,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
import fireworks.client
prompt = kwargs["prompt"]
del kwargs["prompt"]
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@conditional_decorator(use_retry, retry_decorator)
async def _completion_with_retry(prompt: str) -> Any:
return await fireworks.client.Completion.acreate(**kwargs, prompt=prompt)
def run_coroutine_in_new_loop(
coroutine_func: Any, *args: Dict, **kwargs: Dict
) -> Any:
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
return new_loop.run_until_complete(coroutine_func(*args, **kwargs))
finally:
new_loop.close()
async def batch_sync_run() -> List:
with ThreadPoolExecutor() as executor:
results = list(
executor.map(
run_coroutine_in_new_loop,
[_completion_with_retry] * len(prompt),
prompt,
)
)
return results
return await batch_sync_run()
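The async variant runs each coroutine on its own fresh event loop inside a worker thread, since an asyncio event loop is not safe to share across threads. An equivalent per-thread pattern using asyncio.run (Python 3.7+), shown only to illustrate the idea; fake_acall is a stand-in for the real async API call:

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def fake_acall(prompt: str) -> str:
    return prompt.lower()

with ThreadPoolExecutor() as executor:
    results = list(
        executor.map(lambda p: asyncio.run(fake_acall(p)), ["X", "Y"])
    )
assert results == ["x", "y"]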
async def acompletion_with_retry_streaming(
llm: Fireworks,
use_retry: bool,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
@ -198,7 +335,7 @@ async def acompletion_with_retry_streaming(
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
@conditional_decorator(use_retry, retry_decorator)
async def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.Completion.acreate(
**kwargs,
@ -219,6 +356,8 @@ def _create_retry_decorator(
errors = [
fireworks.client.error.RateLimitError,
fireworks.client.error.InternalServerError,
fireworks.client.error.BadGatewayError,
fireworks.client.error.ServiceUnavailableError,
]
return create_base_retry_decorator(

@ -3,11 +3,7 @@
import pytest
from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import (
ChatGeneration,
ChatResult,
LLMResult,
)
from langchain.schema import ChatGeneration, ChatResult, LLMResult
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
@ -72,6 +68,64 @@ def test_chat_fireworks_llm_output_contains_model_id() -> None:
assert llm_result.llm_output["model"] == chat.model
def test_fireworks_invoke() -> None:
"""Tests chat completion with invoke"""
chat = ChatFireworks()
result = chat.invoke("How is the weather in New York today?", stop=[","])
assert isinstance(result.content, str)
assert result.content[-1] == ","
@pytest.mark.asyncio
async def test_fireworks_ainvoke() -> None:
"""Tests chat completion with invoke"""
chat = ChatFireworks()
result = await chat.ainvoke("How is the weather in New York today?", stop=[","])
assert isinstance(result.content, str)
assert result.content[-1] == ","
def test_fireworks_batch() -> None:
"""Test batch tokens from ChatFireworks."""
chat = ChatFireworks()
result = chat.batch(
[
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
],
config={"max_concurrency": 5},
stop=[","],
)
for token in result:
assert isinstance(token.content, str)
assert token.content[-1] == ","
@pytest.mark.asyncio
async def test_fireworks_abatch() -> None:
"""Test batch tokens from ChatFireworks."""
chat = ChatFireworks()
result = await chat.abatch(
[
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
"What is the weather in Redwood City, CA today",
],
config={"max_concurrency": 5},
stop=[","],
)
for token in result:
assert isinstance(token.content, str)
assert token.content[-1] == ","
def test_fireworks_streaming() -> None:
"""Test streaming tokens from Fireworks."""
llm = ChatFireworks()
@ -80,6 +134,17 @@ def test_fireworks_streaming() -> None:
assert isinstance(token.content, str)
def test_fireworks_streaming_stop_words() -> None:
"""Test streaming tokens with stop words."""
llm = ChatFireworks()
last_token = ""
for token in llm.stream("I'm Pickle Rick", stop=[","]):
last_token = token.content
assert isinstance(token.content, str)
assert last_token[-1] == ","
@pytest.mark.asyncio
async def test_chat_fireworks_agenerate() -> None:
"""Test ChatFireworks wrapper with generate."""
@ -101,5 +166,10 @@ async def test_fireworks_astream() -> None:
"""Test streaming tokens from Fireworks."""
llm = ChatFireworks()
async for token in llm.astream("Who's the best quarterback in the NFL?"):
last_token = ""
async for token in llm.astream(
"Who's the best quarterback in the NFL?", stop=[","]
):
last_token = token.content
assert isinstance(token.content, str)
assert last_token[-1] == ","

@ -16,7 +16,7 @@ from langchain.schema import LLMResult
def test_fireworks_call() -> None:
"""Test valid call to fireworks."""
llm = Fireworks()
output = llm("Who's the best quarterback in the NFL?")
output = llm("How is the weather in New York today?")
assert isinstance(output, str)
@ -41,6 +41,60 @@ def test_fireworks_model_param() -> None:
assert llm.model == "foo"
def test_fireworks_invoke() -> None:
"""Tests completion with invoke"""
llm = Fireworks()
output = llm.invoke("How is the weather in New York today?", stop=[","])
assert isinstance(output, str)
assert output[-1] == ","
@pytest.mark.asyncio
async def test_fireworks_ainvoke() -> None:
"""Tests completion with invoke"""
llm = Fireworks()
output = await llm.ainvoke("How is the weather in New York today?", stop=[","])
assert isinstance(output, str)
assert output[-1] == ","
def test_fireworks_batch() -> None:
"""Tests completion with invoke"""
llm = Fireworks()
output = llm.batch(
[
"How is the weather in New York today?",
"How is the weather in New York today?",
"How is the weather in New York today?",
"How is the weather in New York today?",
"How is the weather in New York today?",
],
stop=[","],
)
for token in output:
assert isinstance(token, str)
assert token[-1] == ","
@pytest.mark.asyncio
async def test_fireworks_abatch() -> None:
"""Tests completion with invoke"""
llm = Fireworks()
output = await llm.abatch(
[
"How is the weather in New York today?",
"How is the weather in New York today?",
"How is the weather in New York today?",
"How is the weather in New York today?",
"How is the weather in New York today?",
],
stop=[","],
)
for token in output:
assert isinstance(token, str)
assert token[-1] == ","
def test_fireworks_multiple_prompts() -> None:
"""Test completion with multiple prompts."""
llm = Fireworks()
@ -60,13 +114,31 @@ def test_fireworks_streaming() -> None:
assert isinstance(token, str)
def test_fireworks_streaming_stop_words() -> None:
"""Test stream completion with stop words."""
llm = Fireworks()
generator = llm.stream("Who's the best quarterback in the NFL?", stop=[","])
assert isinstance(generator, Generator)
last_token = ""
for token in generator:
last_token = token
assert isinstance(token, str)
assert last_token[-1] == ","
@pytest.mark.asyncio
async def test_fireworks_streaming_async() -> None:
"""Test stream completion."""
llm = Fireworks()
async for token in llm.astream("Who's the best quarterback in the NFL?"):
last_token = ""
async for token in llm.astream(
"Who's the best quarterback in the NFL?", stop=[","]
):
last_token = token
assert isinstance(token, str)
assert last_token[-1] == ","
@pytest.mark.asyncio
