add kwargs to llm runnables (#8388)

This commit is contained in:
Harrison Chase 2023-07-28 01:13:11 -07:00 committed by GitHub
parent d5884017a9
commit 394b67ab92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 12 deletions

View File

@ -101,13 +101,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
).generations[0][0],
).message,
)
@ -118,15 +119,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
if type(self)._agenerate == BaseChatModel._agenerate:
# model doesn't implement async generation, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, stop=stop)
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message

View File

@ -213,10 +213,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
return (
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
)
.generations[0][0]
.text
@ -228,15 +229,16 @@ class BaseLLM(BaseLanguageModel[str], ABC):
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async invoke, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, stop=stop)
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
)
return llm_result.generations[0][0].text
@ -245,6 +247,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
config = self._get_config_list(config, len(inputs))
@ -254,6 +257,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
else:
@ -264,7 +268,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [
output
for batch in batches
for output in self.batch(batch, config=config)
for output in self.batch(batch, config=config, **kwargs)
]
async def abatch(
@ -272,6 +276,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation
@ -287,6 +292,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
else:
@ -297,7 +303,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [
output
for batch in batches
for output in await self.abatch(batch, config=config)
for output in await self.abatch(batch, config=config, **kwargs)
]
def stream(
@ -306,15 +312,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
if type(self)._stream == BaseLLM._stream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop)
yield self.invoke(input, config=config, stop=stop, **kwargs)
else:
prompt = self._convert_input(input).to_string()
config = config or {}
params = self.dict()
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
@ -330,7 +338,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
try:
generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompt, stop=stop, run_manager=run_manager):
for chunk in self._stream(
prompt, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.text
if generation is None:
generation = chunk
@ -349,15 +359,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
if type(self)._astream == BaseLLM._astream:
# model doesn't implement streaming, so use default implementation
yield await self.ainvoke(input, config=config, stop=stop)
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
else:
prompt = self._convert_input(input).to_string()
config = config or {}
params = self.dict()
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
@ -374,7 +386,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
try:
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(
prompt, stop=stop, run_manager=run_manager
prompt, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.text
if generation is None: