From 4e7e6bfe0a7bd15c4ccd72ed33fe1b35b47be3ef Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 21 Aug 2023 18:01:49 -0700 Subject: [PATCH] revert --- libs/langchain/langchain/chains/base.py | 23 +++++++++++------ libs/langchain/langchain/chat_models/base.py | 20 +++++++++------ libs/langchain/langchain/llms/base.py | 27 ++++++++++++-------- libs/langchain/langchain/schema/retriever.py | 20 +++++++++------ 4 files changed, 56 insertions(+), 34 deletions(-) diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 8a49784f7d..5a21dc6a66 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -63,10 +63,13 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): **kwargs: Any, ) -> Dict[str, Any]: config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - return self(input, **config_kwargs, **kwargs) + return self( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, + ) async def ainvoke( self, @@ -79,11 +82,15 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): return await asyncio.get_running_loop().run_in_executor( None, partial(self.invoke, input, config, **kwargs) ) + config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - return await self.acall(input, **config_kwargs, **kwargs) + return await self.acall( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, + ) memory: Optional[BaseMemory] = None """Optional memory object. Defaults to None. diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index d4c582c19e..09199e30dc 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -105,15 +105,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): **kwargs: Any, ) -> BaseMessageChunk: config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } return cast( BaseMessageChunk, cast( ChatGeneration, self.generate_prompt( - [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, ).generations[0][0], ).message, ) @@ -133,11 +135,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): ) config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } llm_result = await self.agenerate_prompt( - [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, ) return cast( BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 401fe61d06..a833487ffb 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -220,13 +220,18 @@ class BaseLLM(BaseLanguageModel[str], ABC): **kwargs: Any, ) -> str: config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - result = self.generate_prompt( - [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs + return ( + self.generate_prompt( + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, + ) + .generations[0][0] + .text ) - return result.generations[0][0].text async def ainvoke( self, @@ -243,11 +248,13 @@ class BaseLLM(BaseLanguageModel[str], ABC): ) config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } llm_result = await self.agenerate_prompt( - [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, ) return llm_result.generations[0][0].text diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 55a1acb086..5da50e1497 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -108,10 +108,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): self, input: str, config: Optional[RunnableConfig] = None ) -> List[Document]: config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - return self.get_relevant_documents(input, **config_kwargs) + return self.get_relevant_documents( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + ) async def ainvoke( self, @@ -124,10 +126,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): return await super().ainvoke(input, config) config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - return await self.aget_relevant_documents(input, **config_kwargs) + return await self.aget_relevant_documents( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + ) @abstractmethod def _get_relevant_documents(