@@ -96,7 +96,10 @@ class ChatFireworks(BaseChatModel):
         try:
             import fireworks.client
         except ImportError as e:
-            raise ImportError("") from e
+            raise ImportError(
+                "Could not import fireworks-ai python package. "
+                "Please install it with `pip install fireworks-ai`."
+            ) from e
         fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
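
For context, `get_from_dict_or_env` (used just after the guarded import) prefers an explicitly passed value and falls back to the named environment variable. A minimal illustrative reimplementation of that lookup order, not LangChain's exact code:

```python
import os

# Illustrative sketch of the get_from_dict_or_env lookup order used above:
# prefer an explicit value, fall back to the environment, fail loudly
# if neither is set. Not LangChain's actual implementation.
def get_from_dict_or_env(values: dict, key: str, env_key: str) -> str:
    if values.get(key):
        return values[key]
    if os.environ.get(env_key):
        return os.environ[env_key]
    raise ValueError(
        f"Did not find {key}; set the `{env_key}` environment variable "
        f"or pass `{key}` as a named parameter."
    )
```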
@@ -194,6 +197,8 @@ class ChatFireworks(BaseChatModel):
             )
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.content, chunk=chunk)
 
     async def _astream(
         self,
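
The two added lines wire each streamed token into LangChain's callback system. A short sketch of how a caller would observe them, using the stock `StreamingStdOutCallbackHandler`; the model name is an assumption for the example:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models.fireworks import ChatFireworks

# With the on_llm_new_token call added above, every chunk yielded by
# _stream now triggers the handler, which prints tokens as they arrive.
chat = ChatFireworks(
    model="accounts/fireworks/models/llama-v2-7b-chat",  # assumed model name
    callbacks=[StreamingStdOutCallbackHandler()],
)
for chunk in chat.stream("Say hello in one short sentence."):
    pass  # tokens are already printed by the callback handler
```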
@@ -221,6 +226,8 @@ class ChatFireworks(BaseChatModel):
             )
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            if run_manager:
+                await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
 
 
     def completion_with_retry(
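
The hunk is cut off at `completion_with_retry`. For context, LangChain's retry helpers of this vintage wrap the raw API call in tenacity-based exponential backoff; the sketch below shows that general pattern, with all names and parameters being illustrative assumptions rather than the function's actual signature:

```python
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

# Hedged sketch of the retry pattern such a helper typically applies;
# the tenacity parameters here are assumptions, not taken from the diff.
@retry(
    reraise=True,
    stop=stop_after_attempt(3),  # give up after three attempts
    wait=wait_exponential(multiplier=1, min=1, max=10),  # backoff capped at 10s
    retry=retry_if_exception_type(Exception),  # broad on purpose, for illustration
)
def call_with_retry(make_request, **kwargs):
    """Invoke the underlying completion call, retrying transient failures."""
    return make_request(**kwargs)
```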