From 28ee6a7c125f1eb209b6b6428d1a50040408ea9f Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Wed, 11 Oct 2023 13:37:03 -0700
Subject: [PATCH] Track ChatFireworks time to first_token (#11672)

---
 libs/langchain/langchain/chat_models/fireworks.py | 9 ++++++++-
 libs/langchain/langchain/llms/fireworks.py        | 9 ++++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index fd851c0eed..841ee8a0d8 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -96,7 +96,10 @@ class ChatFireworks(BaseChatModel):
         try:
             import fireworks.client
         except ImportError as e:
-            raise ImportError("") from e
+            raise ImportError(
+                "Could not import fireworks-ai python package. "
+                "Please install it with `pip install fireworks-ai`."
+            ) from e
         fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
@@ -194,6 +197,8 @@ class ChatFireworks(BaseChatModel):
             )
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.content, chunk=chunk)
 
     async def _astream(
         self,
@@ -221,6 +226,8 @@ class ChatFireworks(BaseChatModel):
             )
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            if run_manager:
+                await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
 
 
 def completion_with_retry(
diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py
index 6922b2a6e9..b99451515c 100644
--- a/libs/langchain/langchain/llms/fireworks.py
+++ b/libs/langchain/langchain/llms/fireworks.py
@@ -45,7 +45,10 @@ class Fireworks(LLM):
         try:
             import fireworks.client
         except ImportError as e:
-            raise ImportError("") from e
+            raise ImportError(
+                "Could not import fireworks-ai python package. "
+                "Please install it with `pip install fireworks-ai`."
+            ) from e
         fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
@@ -113,6 +116,8 @@ class Fireworks(LLM):
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
     async def _astream(
         self,
@@ -132,6 +137,8 @@ class Fireworks(LLM):
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
+            if run_manager:
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
     def stream(
         self,