Track ChatFireworks time to first_token (#11672)

pull/11689/head
Erick Friis 1 year ago committed by GitHub
parent 2c1e735403
commit 28ee6a7c12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -96,7 +96,10 @@ class ChatFireworks(BaseChatModel):
try:
import fireworks.client
except ImportError as e:
raise ImportError("") from e
raise ImportError(
"Could not import fireworks-ai python package. "
"Please install it with `pip install fireworks-ai`."
) from e
fireworks_api_key = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
@ -194,6 +197,8 @@ class ChatFireworks(BaseChatModel):
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=chunk)
async def _astream(
self,
@ -221,6 +226,8 @@ class ChatFireworks(BaseChatModel):
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
if run_manager:
await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
def completion_with_retry(

@ -45,7 +45,10 @@ class Fireworks(LLM):
try:
import fireworks.client
except ImportError as e:
raise ImportError("") from e
raise ImportError(
"Could not import fireworks-ai python package. "
"Please install it with `pip install fireworks-ai`."
) from e
fireworks_api_key = get_from_dict_or_env(
values, "fireworks_api_key", "FIREWORKS_API_KEY"
)
@ -113,6 +116,8 @@ class Fireworks(LLM):
):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
async def _astream(
self,
@ -132,6 +137,8 @@ class Fireworks(LLM):
):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
def stream(
self,

Loading…
Cancel
Save