add kwargs to llm runnables (#8388)

Branch: pull/8414/head^2
Author: Harrison Chase (committed via GitHub)
Parent: d5884017a9
Commit: 394b67ab92
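The net effect of the diff below: extra keyword arguments passed to the Runnable entry points on chat models and LLMs (invoke, ainvoke, batch, abatch, stream, astream) are now forwarded to generate_prompt / generate / _stream instead of being dropped. A minimal sketch of the resulting call pattern, assuming an OpenAI-backed model; ChatOpenAI/OpenAI and the temperature kwarg are only stand-ins, and whether a given kwarg is honoured still depends on the concrete model class:

    from langchain.chat_models import ChatOpenAI
    from langchain.llms import OpenAI

    chat = ChatOpenAI()
    llm = OpenAI()

    # stop keeps its explicit keyword-only slot; anything else rides along in
    # **kwargs down to generate_prompt and, from there, toward the provider call.
    message = chat.invoke("Tell me a joke", stop=["\n\n"], temperature=0.2)

    # batch/abatch forward the same kwargs into generate()/agenerate().
    texts = llm.batch(["Tell me a joke", "Tell me another"], temperature=0.2)

    # stream/astream merge kwargs into the params handed to _stream/_astream
    # (or fall back to invoke with the same kwargs if streaming isn't implemented).
    for token in llm.stream("Tell me a joke", temperature=0.2):
        print(token, end="")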

@@ -101,13 +101,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessageChunk:
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **(config or {})
+                    [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
                 ).generations[0][0],
             ).message,
         )
@@ -118,15 +119,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessageChunk:
         if type(self)._agenerate == BaseChatModel._agenerate:
             # model doesn't implement async generation, so use default implementation
             return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.invoke, input, config, stop=stop)
+                None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {})
+            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
         )

@@ -213,10 +213,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> str:
         return (
             self.generate_prompt(
-                [self._convert_input(input)], stop=stop, **(config or {})
+                [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
             )
             .generations[0][0]
             .text
@@ -228,15 +229,16 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> str:
         if type(self)._agenerate == BaseLLM._agenerate:
             # model doesn't implement async invoke, so use default implementation
             return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.invoke, input, config, stop=stop)
+                None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {})
+            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
         )
         return llm_result.generations[0][0].text
@@ -245,6 +247,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         inputs: List[LanguageModelInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         max_concurrency: Optional[int] = None,
+        **kwargs: Any,
     ) -> List[str]:
         config = self._get_config_list(config, len(inputs))
@@ -254,6 +257,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
         else:
@@ -264,7 +268,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             return [
                 output
                 for batch in batches
-                for output in self.batch(batch, config=config)
+                for output in self.batch(batch, config=config, **kwargs)
             ]

     async def abatch(
@@ -272,6 +276,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         inputs: List[LanguageModelInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         max_concurrency: Optional[int] = None,
+        **kwargs: Any,
     ) -> List[str]:
         if type(self)._agenerate == BaseLLM._agenerate:
             # model doesn't implement async batch, so use default implementation
@@ -287,6 +292,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
         else:
@@ -297,7 +303,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             return [
                 output
                 for batch in batches
-                for output in await self.abatch(batch, config=config)
+                for output in await self.abatch(batch, config=config, **kwargs)
             ]

     def stream(
@@ -306,15 +312,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> Iterator[str]:
         if type(self)._stream == BaseLLM._stream:
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop)
+            yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             prompt = self._convert_input(input).to_string()
             config = config or {}
             params = self.dict()
             params["stop"] = stop
+            params = {**params, **kwargs}
             options = {"stop": stop}
             callback_manager = CallbackManager.configure(
                 config.get("callbacks"),
@@ -330,7 +338,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
             try:
                 generation: Optional[GenerationChunk] = None
-                for chunk in self._stream(prompt, stop=stop, run_manager=run_manager):
+                for chunk in self._stream(
+                    prompt, stop=stop, run_manager=run_manager, **kwargs
+                ):
                     yield chunk.text
                     if generation is None:
                         generation = chunk
@@ -349,15 +359,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> AsyncIterator[str]:
         if type(self)._astream == BaseLLM._astream:
             # model doesn't implement streaming, so use default implementation
-            yield await self.ainvoke(input, config=config, stop=stop)
+            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
         else:
             prompt = self._convert_input(input).to_string()
             config = config or {}
             params = self.dict()
             params["stop"] = stop
+            params = {**params, **kwargs}
             options = {"stop": stop}
             callback_manager = AsyncCallbackManager.configure(
                 config.get("callbacks"),
@@ -374,7 +386,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             try:
                 generation: Optional[GenerationChunk] = None
                 async for chunk in self._astream(
-                    prompt, stop=stop, run_manager=run_manager
+                    prompt, stop=stop, run_manager=run_manager, **kwargs
                 ):
                     yield chunk.text
                     if generation is None:
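For custom models, the practical consequence of the stream/astream hunks above is that whatever extra kwargs a caller supplies to stream() now reach the subclass's _stream/_astream hooks. A hedged sketch of a toy subclass; the EchoLLM class and its upper flag are invented for illustration, and exact import paths may differ slightly across langchain versions of this era:

    from typing import Any, Iterator, List, Optional

    from langchain.callbacks.manager import CallbackManagerForLLMRun
    from langchain.llms.base import LLM
    from langchain.schema.output import GenerationChunk


    class EchoLLM(LLM):
        """Toy model that repeats the prompt back, word by word."""

        @property
        def _llm_type(self) -> str:
            return "echo"

        def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> str:
            return prompt

        def _stream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> Iterator[GenerationChunk]:
            # After this change, EchoLLM().stream("hi there", upper=True) delivers
            # upper=True here via **kwargs; before it, stream() rejected extra kwargs.
            for word in prompt.split():
                text = word.upper() if kwargs.get("upper") else word
                yield GenerationChunk(text=text + " ")

Joining the chunks from EchoLLM().stream("hello runnable world", upper=True) would then produce the upper-cased text, confirming the kwarg made it through the new pass-through path.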
