pull/9007/head
Bagatur 1 year ago
parent a9bf409a09
commit 4e7e6bfe0a

@ -63,10 +63,13 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
**kwargs: Any,
) -> Dict[str, Any]:
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return self(input, **config_kwargs, **kwargs)
return self(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
async def ainvoke(
self,
@ -79,11 +82,15 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return await self.acall(input, **config_kwargs, **kwargs)
return await self.acall(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.

@ -105,15 +105,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
**kwargs: Any,
) -> BaseMessageChunk:
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
).generations[0][0],
).message,
)
@ -133,11 +135,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message

@ -220,13 +220,18 @@ class BaseLLM(BaseLanguageModel[str], ABC):
**kwargs: Any,
) -> str:
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
result = self.generate_prompt(
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
return (
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
.generations[0][0]
.text
)
return result.generations[0][0].text
async def ainvoke(
self,
@ -243,11 +248,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
return llm_result.generations[0][0].text

@ -108,10 +108,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return self.get_relevant_documents(input, **config_kwargs)
return self.get_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
)
async def ainvoke(
self,
@ -124,10 +126,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
return await super().ainvoke(input, config)
config = config or {}
config_kwargs: Dict = {
k: config.get(k) for k in ("callbacks", "tags", "metadata")
}
return await self.aget_relevant_documents(input, **config_kwargs)
return await self.aget_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
)
@abstractmethod
def _get_relevant_documents(

Loading…
Cancel
Save