From 945124094151cc128194e4070fe82b51572c83d1 Mon Sep 17 00:00:00 2001
From: Bagatur
Date: Wed, 27 Sep 2023 08:09:33 -0700
Subject: [PATCH] Fix fireworks chat linting issues

---
 .../langchain/chat_models/fireworks.py        | 66 +++++++++++++------
 libs/langchain/langchain/llms/fireworks.py    | 59 ++++++++++++-----
 .../chat_models/test_fireworks.py             |  1 -
 3 files changed, 88 insertions(+), 38 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index 8e7774a569..22d9384d58 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -1,8 +1,13 @@
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union
-
-import fireworks
-import fireworks.client
-from pydantic import root_validator
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Union,
+)
 
 from langchain.adapters.openai import convert_message_to_dict
 from langchain.callbacks.manager import (
@@ -11,6 +16,7 @@ from langchain.callbacks.manager import (
 )
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.base import create_base_retry_decorator
+from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -30,12 +36,12 @@ from langchain.utils.env import get_from_dict_or_env
 
 
 def _convert_delta_to_message_chunk(
-    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
+    _dict: Any, default_class: type[BaseMessageChunk]
 ) -> BaseMessageChunk:
     """Convert a delta response to a message chunk."""
     role = _dict.role
     content = _dict.content or ""
-    additional_kwargs = {}
+    additional_kwargs: Dict = {}
 
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
@@ -51,7 +57,7 @@ def _convert_delta_to_message_chunk(
         return default_class(content=content)
 
 
-def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+def convert_dict_to_message(_dict: Any) -> BaseMessage:
     """Convert a dict response to a message."""
     role = _dict.role
     content = _dict.content or ""
@@ -59,7 +65,7 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
         return HumanMessage(content=content)
     elif role == "assistant":
         content = _dict.content
-        additional_kwargs = {}
+        additional_kwargs: Dict = {}
         return AIMessage(content=content, additional_kwargs=additional_kwargs)
     elif role == "system":
         return SystemMessage(content=content)
@@ -73,13 +79,23 @@ class ChatFireworks(BaseChatModel):
     """Fireworks Chat models."""
 
     model: str = "accounts/fireworks/models/llama-v2-7b-chat"
-    model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1}
+    model_kwargs: dict = Field(
+        default_factory=lambda: {
+            "temperature": 0.7,
+            "max_tokens": 512,
+            "top_p": 1,
+        }.copy()
+    )
     fireworks_api_key: Optional[str] = None
     max_retries: int = 20
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key in environment."""
+        try:
+            import fireworks.client
+        except ImportError as e:
+            raise ImportError("") from e
         fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
@@ -105,14 +121,14 @@ class ChatFireworks(BaseChatModel):
             "messages": message_dicts,
             **self.model_kwargs,
         }
-        response = completion_with_retry(self, **params)
+        response = completion_with_retry(self, run_manager=run_manager, **params)
         return self._create_chat_result(response)
 
     async def _agenerate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
         message_dicts = self._create_message_dicts(messages, stop)
@@ -121,13 +137,15 @@ class ChatFireworks(BaseChatModel):
             "messages": message_dicts,
             **self.model_kwargs,
         }
-        response = await acompletion_with_retry(self, **params)
+        response = await acompletion_with_retry(self, run_manager=run_manager, **params)
         return self._create_chat_result(response)
 
     def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        if llm_outputs[0] is None:
+            return {}
         return llm_outputs[0]
 
-    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+    def _create_chat_result(self, response: Any) -> ChatResult:
         generations = []
         for res in response.choices:
             message = convert_dict_to_message(res.message)
@@ -141,7 +159,7 @@ class ChatFireworks(BaseChatModel):
 
     def _create_message_dicts(
         self, messages: List[BaseMessage], stop: Optional[List[str]]
-    ) -> Tuple[List[Dict[str, Any]]]:
+    ) -> List[Dict[str, Any]]:
         message_dicts = [convert_message_to_dict(m) for m in messages]
         return message_dicts
 
@@ -160,7 +178,7 @@ class ChatFireworks(BaseChatModel):
             "stream": True,
             **self.model_kwargs,
         }
-        for chunk in completion_with_retry(self, **params):
+        for chunk in completion_with_retry(self, run_manager=run_manager, **params):
             choice = chunk.choices[0]
             chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
             finish_reason = choice.finish_reason
@@ -174,9 +192,9 @@ class ChatFireworks(BaseChatModel):
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
+    ) -> AsyncIterator[ChatGenerationChunk]:
         message_dicts = self._create_message_dicts(messages, stop)
         default_chunk_class = AIMessageChunk
         params = {
@@ -185,7 +203,9 @@ class ChatFireworks(BaseChatModel):
             "stream": True,
             **self.model_kwargs,
         }
-        async for chunk in await acompletion_with_retry_streaming(self, **params):
+        async for chunk in await acompletion_with_retry_streaming(
+            self, run_manager=run_manager, **params
+        ):
             choice = chunk.choices[0]
             chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
             finish_reason = choice.finish_reason
@@ -202,6 +222,8 @@ def completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
+    import fireworks.client
+
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
@@ -219,6 +241,8 @@ async def acompletion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the async completion call."""
+    import fireworks.client
+
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
@@ -236,6 +260,8 @@ async def acompletion_with_retry_streaming(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call for streaming."""
+    import fireworks.client
+
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
@@ -254,6 +280,8 @@ def _create_retry_decorator(
     ] = None,
 ) -> Callable[[Any], Any]:
     """Define retry mechanism."""
+    import fireworks.client
+
     errors = [
         fireworks.client.error.RateLimitError,
         fireworks.client.error.ServiceUnavailableError,
diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py
index cf52b4b74b..47c7659750 100644
--- a/libs/langchain/langchain/llms/fireworks.py
+++ b/libs/langchain/langchain/llms/fireworks.py
@@ -1,14 +1,11 @@
-from typing import Any, Callable, Dict, Iterator, List, Optional, Union
-
-import fireworks
-import fireworks.client
-from pydantic import root_validator
+from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import LLM, create_base_retry_decorator
+from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema.language_model import LanguageModelInput
 from langchain.schema.output import GenerationChunk
 from langchain.schema.runnable.config import RunnableConfig
@@ -16,7 +13,7 @@ from langchain.utils.env import get_from_dict_or_env
 
 
 def _stream_response_to_generation_chunk(
-    stream_response: Dict[str, Any],
+    stream_response: Any,
 ) -> GenerationChunk:
     """Convert a stream response to a generation chunk."""
     return GenerationChunk(
@@ -32,13 +29,23 @@ class Fireworks(LLM):
     """Fireworks models."""
 
     model: str = "accounts/fireworks/models/llama-v2-7b-chat"
-    model_kwargs: Optional[dict] = {"temperature": 0.7, "max_tokens": 512, "top_p": 1}
+    model_kwargs: dict = Field(
+        default_factory=lambda: {
+            "temperature": 0.7,
+            "max_tokens": 512,
+            "top_p": 1,
+        }.copy()
+    )
     fireworks_api_key: Optional[str] = None
     max_retries: int = 20
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key in environment."""
+        try:
+            import fireworks.client
+        except ImportError as e:
+            raise ImportError("") from e
         fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
@@ -58,12 +65,12 @@ class Fireworks(LLM):
         **kwargs: Any,
     ) -> str:
         """Run the LLM on the given prompt and input."""
-        params = {
+        params: dict = {
             "model": self.model,
             "prompt": prompt,
             **self.model_kwargs,
         }
-        response = completion_with_retry(self, **params)
+        response = completion_with_retry(self, run_manager=run_manager, **params)
 
         return response.choices[0].text
 
@@ -71,7 +78,7 @@ class Fireworks(LLM):
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
         """Run the LLM on the given prompt and input."""
@@ -80,7 +87,7 @@ class Fireworks(LLM):
             "prompt": prompt,
             **self.model_kwargs,
         }
-        response = await acompletion_with_retry(self, **params)
+        response = await acompletion_with_retry(self, run_manager=run_manager, **params)
 
         return response.choices[0].text
 
@@ -97,7 +104,9 @@ class Fireworks(LLM):
             "stream": True,
             **self.model_kwargs,
         }
-        for stream_resp in completion_with_retry(self, **params):
+        for stream_resp in completion_with_retry(
+            self, run_manager=run_manager, **params
+        ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
 
@@ -105,16 +114,18 @@ class Fireworks(LLM):
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
+    ) -> AsyncIterator[GenerationChunk]:
         params = {
             "model": self.model,
             "prompt": prompt,
             "stream": True,
             **self.model_kwargs,
         }
-        async for stream_resp in await acompletion_with_retry_streaming(self, **params):
+        async for stream_resp in await acompletion_with_retry_streaming(
+            self, run_manager=run_manager, **params
+        ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
 
@@ -143,7 +154,7 @@ class Fireworks(LLM):
         *,
         stop: Optional[List[str]] = None,
         **kwargs: Any,
-    ) -> Iterator[str]:
+    ) -> AsyncIterator[str]:
         prompt = self._convert_input(input).to_string()
         generation: Optional[GenerationChunk] = None
         async for chunk in self._astream(prompt):
@@ -157,10 +168,13 @@
 
 def completion_with_retry(
     llm: Fireworks,
+    *,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
+    import fireworks.client
+
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
@@ -174,10 +188,13 @@ def completion_with_retry(
 
 async def acompletion_with_retry(
     llm: Fireworks,
-    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    *,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
+    import fireworks.client
+
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
@@ -191,10 +208,13 @@ async def acompletion_with_retry(
 
 async def acompletion_with_retry_streaming(
     llm: Fireworks,
-    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    *,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call for streaming."""
+    import fireworks.client
+
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
@@ -208,11 +228,14 @@ async def acompletion_with_retry_streaming(
 
 def _create_retry_decorator(
     llm: Fireworks,
+    *,
     run_manager: Optional[
         Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
     ] = None,
 ) -> Callable[[Any], Any]:
     """Define retry mechanism."""
+    import fireworks.client
+
     errors = [
         fireworks.client.error.RateLimitError,
         fireworks.client.error.ServiceUnavailableError,
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
index 9cc681b038..b6930a7b9e 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
@@ -2,7 +2,6 @@
 
 import pytest
 
-from langchain.callbacks.manager import CallbackManager
 from langchain.chat_models.fireworks import ChatFireworks
 from langchain.schema import (
     ChatGeneration,
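
Editor's note: the recurring pattern in this patch is moving `import fireworks.client` from module scope into `validate_environment` and the retry helpers, so that `langchain` itself imports cleanly without the optional fireworks dependency installed; the import only runs, and can only fail, when the Fireworks integration is actually used. Below is a minimal sketch of that guard as a standalone helper. The helper name and the error text are illustrative only: the patch inlines the try/except at each site and raises a bare `ImportError("")`, and `pip install fireworks-ai` is assumed to be the install command for the client package.

```python
from typing import Any


def _require_fireworks_client() -> Any:
    """Lazily import the optional fireworks client (hypothetical helper)."""
    try:
        import fireworks.client  # deferred: module import stays dependency-free
    except ImportError as e:
        # Illustrative message; the patch itself leaves the ImportError text empty.
        raise ImportError(
            "Could not import the fireworks client. "
            "Install the optional dependency, e.g. `pip install fireworks-ai`."
        ) from e
    return fireworks.client
```

The `*,` inserted before `run_manager` in the retry helpers serves a related defensive goal: it makes `run_manager` keyword-only, so the `**params` forwarded from the call sites can never bind to it positionally.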