@ -1,12 +1,14 @@
import asyncio
from concurrent . futures import ThreadPoolExecutor
from typing import Any , AsyncIterator , Callable , Dict , Iterator , List , Optional , Union
from langchain . callbacks . manager import (
AsyncCallbackManagerForLLMRun ,
CallbackManagerForLLMRun ,
)
from langchain . llms . base import LLM, create_base_retry_decorator
from langchain . llms . base import Base LLM, create_base_retry_decorator
from langchain . pydantic_v1 import Field , root_validator
from langchain . schema . output import Generation Chunk
from langchain . schema . output import Generation , Generation Chunk, LLMResult
from langchain . utils . env import get_from_dict_or_env
@ -23,7 +25,7 @@ def _stream_response_to_generation_chunk(
)
class Fireworks ( LLM) :
class Fireworks ( Base LLM) :
""" Fireworks models. """
model : str = " accounts/fireworks/models/llama-v2-7b-chat "
@ -36,6 +38,8 @@ class Fireworks(LLM):
)
fireworks_api_key : Optional [ str ] = None
max_retries : int = 20
batch_size : int = 20
use_retry : bool = True
@property
def lc_secrets ( self ) - > Dict [ str , str ] :
@ -66,43 +70,92 @@ class Fireworks(LLM):
""" Return type of llm. """
return " fireworks "
def _generate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Call out to Fireworks endpoint with k unique prompts.

    The prompts are split into batches with ``self.get_batch_prompts`` and
    each batch is sent through ``completion_with_retry_batching``; the
    responses of all batches are collected in order and converted with
    ``self.create_llm_result``.

    Args:
        prompts: The prompts to pass into the model.
        stop: Optional list of stop words to use when generating.
        run_manager: Optional callback manager handed to the retry helper.
        **kwargs: Accepted for interface compatibility; not forwarded to
            the completion call by this method.

    Returns:
        The full LLM output.
    """
    # Request parameters shared by every prompt; the prompt itself is
    # supplied per batch below.
    params = {
        "model": self.model,
        **self.model_kwargs,
    }
    sub_prompts = self.get_batch_prompts(prompts)
    choices = []
    for _prompts in sub_prompts:
        # One response object per prompt in the batch.
        response = completion_with_retry_batching(
            self,
            self.use_retry,
            prompt=_prompts,
            run_manager=run_manager,
            stop=stop,
            **params,
        )
        choices.extend(response)
    return self.create_llm_result(choices, prompts)
async def _agenerate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Call out to Fireworks endpoint async with k unique prompts.

    Async counterpart of ``_generate``: prompts are batched with
    ``self.get_batch_prompts``, each batch is awaited through
    ``acompletion_with_retry_batching``, and the collected responses are
    converted with ``self.create_llm_result``.

    Args:
        prompts: The prompts to pass into the model.
        stop: Optional list of stop words to use when generating.
        run_manager: Optional async callback manager handed to the retry
            helper.
        **kwargs: Accepted for interface compatibility; not forwarded to
            the completion call by this method.

    Returns:
        The full LLM output.
    """
    # Request parameters shared by every prompt; the prompt itself is
    # supplied per batch below.
    params = {
        "model": self.model,
        **self.model_kwargs,
    }
    sub_prompts = self.get_batch_prompts(prompts)
    choices = []
    for _prompts in sub_prompts:
        # Batches are awaited one after another; the helper handles the
        # prompts inside a batch.
        response = await acompletion_with_retry_batching(
            self,
            self.use_retry,
            prompt=_prompts,
            run_manager=run_manager,
            stop=stop,
            **params,
        )
        choices.extend(response)
    return self.create_llm_result(choices, prompts)
def get_batch_prompts(
    self,
    prompts: List[str],
) -> List[List[str]]:
    """Split ``prompts`` into consecutive chunks of at most ``self.batch_size``.

    Args:
        prompts: The prompts to partition; their order is preserved.

    Returns:
        A list of sub-lists covering ``prompts`` from left to right. The
        last sub-list may be shorter than ``self.batch_size``. An empty
        input yields an empty list.

    Raises:
        ValueError: If ``self.batch_size`` is smaller than 1.
    """
    batch_size = self.batch_size
    if batch_size < 1:
        # A negative step makes ``range`` empty, which would silently drop
        # every prompt; a zero step already raises ValueError. Fail loudly
        # and uniformly in both cases.
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )
    return [
        prompts[start : start + batch_size]
        for start in range(0, len(prompts), batch_size)
    ]
def create_llm_result(self, choices: Any, prompts: List[str]) -> LLMResult:
    """Build an ``LLMResult`` pairing each prompt with its response.

    Args:
        choices: Sequence of response objects, positionally aligned with
            ``prompts``. Each object is read via
            ``__dict__["choices"][0].text``.
        prompts: The prompts that were sent; only their count is used.

    Returns:
        An ``LLMResult`` with one generation list per prompt and the model
        name recorded under ``llm_output["model"]``.
    """
    generations = [
        [
            Generation(text=response.__dict__["choices"][0].text)
            # One-element window at this prompt's position; it is empty
            # when there are fewer responses than prompts.
            for response in choices[position : position + 1]
        ]
        for position in range(len(prompts))
    ]
    return LLMResult(generations=generations, llm_output={"model": self.model})
def _stream (
self ,
@ -118,7 +171,7 @@ class Fireworks(LLM):
* * self . model_kwargs ,
}
for stream_resp in completion_with_retry (
self , run_manager = run_manager , stop = stop , * * params
self , self . use_retry , run_manager = run_manager , stop = stop , * * params
) :
chunk = _stream_response_to_generation_chunk ( stream_resp )
yield chunk
@ -139,7 +192,7 @@ class Fireworks(LLM):
* * self . model_kwargs ,
}
async for stream_resp in await acompletion_with_retry_streaming (
self , run_manager = run_manager , stop = stop , * * params
self , self . use_retry , run_manager = run_manager , stop = stop , * * params
) :
chunk = _stream_response_to_generation_chunk ( stream_resp )
yield chunk
@ -147,8 +200,20 @@ class Fireworks(LLM):
await run_manager . on_llm_new_token ( chunk . text , chunk = chunk )
def conditional_decorator(
    condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]:
    """Return a decorator that applies ``decorator`` only if ``condition`` holds.

    Args:
        condition: Whether ``decorator`` should be applied.
        decorator: The decorator to apply conditionally.

    Returns:
        A decorator that yields ``decorator(func)`` when ``condition`` is
        true and the untouched ``func`` otherwise.
    """

    def _apply(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
        return decorator(func) if condition else func

    return _apply
def completion_with_retry (
llm : Fireworks ,
use_retry : bool ,
* ,
run_manager : Optional [ CallbackManagerForLLMRun ] = None ,
* * kwargs : Any ,
@ -158,7 +223,7 @@ def completion_with_retry(
retry_decorator = _create_retry_decorator ( llm , run_manager = run_manager )
@ retry_decorator
@ conditional_decorator( use_retry , retry_decorator)
def _completion_with_retry ( * * kwargs : Any ) - > Any :
return fireworks . client . Completion . create (
* * kwargs ,
@ -169,6 +234,7 @@ def completion_with_retry(
async def acompletion_with_retry (
llm : Fireworks ,
use_retry : bool ,
* ,
run_manager : Optional [ AsyncCallbackManagerForLLMRun ] = None ,
* * kwargs : Any ,
@ -178,7 +244,7 @@ async def acompletion_with_retry(
retry_decorator = _create_retry_decorator ( llm , run_manager = run_manager )
@ retry_decorator
@ conditional_decorator( use_retry , retry_decorator)
async def _completion_with_retry ( * * kwargs : Any ) - > Any :
return await fireworks . client . Completion . acreate (
* * kwargs ,
@ -187,8 +253,79 @@ async def acompletion_with_retry(
return await _completion_with_retry ( * * kwargs )
def completion_with_retry_batching(
    llm: Fireworks,
    use_retry: bool,
    *,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Issue one completion request per prompt on a thread pool.

    Args:
        llm: The Fireworks wrapper, passed to ``_create_retry_decorator``.
        use_retry: When true, each per-prompt call is wrapped in the retry
            decorator; otherwise it is called once.
        run_manager: Optional callback manager passed to the retry
            decorator factory.
        **kwargs: Request parameters. Must contain ``prompt`` (a list of
            prompt strings, or a single string); every other entry is
            forwarded unchanged to ``fireworks.client.Completion.create``.

    Returns:
        A list with one response per prompt, in prompt order.

    Raises:
        KeyError: If ``prompt`` is missing from ``kwargs``.
    """
    import fireworks.client

    prompts = kwargs.pop("prompt")
    if isinstance(prompts, str):
        # Mapping over a bare string would send one request per character.
        prompts = [prompts]

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @conditional_decorator(use_retry, retry_decorator)
    def _completion_with_retry(prompt: str) -> Any:
        return fireworks.client.Completion.create(**kwargs, prompt=prompt)

    # ``Executor.map`` returns results in input order; ``list`` drains it
    # before the pool is shut down and re-raises any worker exception.
    with ThreadPoolExecutor() as executor:
        return list(executor.map(_completion_with_retry, prompts))
async def acompletion_with_retry_batching(
    llm: Fireworks,
    use_retry: bool,
    *,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Issue one async completion request per prompt, concurrently.

    The requests are scheduled on the caller's running event loop with
    ``asyncio.gather`` rather than on worker threads that each create their
    own loop, so awaiting this coroutine does not block the loop.

    Args:
        llm: The Fireworks wrapper, passed to ``_create_retry_decorator``.
        use_retry: When true, each per-prompt call is wrapped in the retry
            decorator; otherwise it is awaited once.
        run_manager: Optional async callback manager passed to the retry
            decorator factory.
        **kwargs: Request parameters. Must contain ``prompt`` (a list of
            prompt strings, or a single string); every other entry is
            forwarded unchanged to ``fireworks.client.Completion.acreate``.

    Returns:
        A list with one response per prompt, in prompt order.

    Raises:
        KeyError: If ``prompt`` is missing from ``kwargs``.
    """
    import fireworks.client

    prompts = kwargs.pop("prompt")
    if isinstance(prompts, str):
        # Iterating a bare string would send one request per character.
        prompts = [prompts]

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @conditional_decorator(use_retry, retry_decorator)
    async def _completion_with_retry(prompt: str) -> Any:
        return await fireworks.client.Completion.acreate(**kwargs, prompt=prompt)

    # ``gather`` preserves argument order in its result list and propagates
    # the first exception raised by any request.
    # NOTE(review): assumes ``Completion.acreate`` can be awaited
    # concurrently on a single loop - confirm against the client.
    return await asyncio.gather(
        *(_completion_with_retry(prompt) for prompt in prompts)
    )
async def acompletion_with_retry_streaming (
llm : Fireworks ,
use_retry : bool ,
* ,
run_manager : Optional [ AsyncCallbackManagerForLLMRun ] = None ,
* * kwargs : Any ,
@ -198,7 +335,7 @@ async def acompletion_with_retry_streaming(
retry_decorator = _create_retry_decorator ( llm , run_manager = run_manager )
@ retry_decorator
@ conditional_decorator( use_retry , retry_decorator)
async def _completion_with_retry ( * * kwargs : Any ) - > Any :
return fireworks . client . Completion . acreate (
* * kwargs ,
@ -219,6 +356,8 @@ def _create_retry_decorator(
errors = [
fireworks . client . error . RateLimitError ,
fireworks . client . error . InternalServerError ,
fireworks . client . error . BadGatewayError ,
fireworks . client . error . ServiceUnavailableError ,
]
return create_base_retry_decorator (