mirror of https://github.com/hwchase17/langchain
add kwargs to llm runnables (#8388)
This commit is contained in:
parent d5884017a9 · commit 394b67ab92
@@ -101,13 +101,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessageChunk:
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **(config or {})
+                    [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
                 ).generations[0][0],
             ).message,
         )
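The new call spreads three keyword sources into one `generate_prompt` call: the explicit `stop`, the entries of the `RunnableConfig` mapping, and now the caller's extra `**kwargs`. A standalone sketch of that spread, using stub names rather than the real LangChain methods:

    from typing import Any, List, Optional

    def generate_prompt_stub(prompts: List[str], **kwargs: Any) -> dict:
        # Stand-in for generate_prompt: just report what arrived.
        return {"prompts": prompts, "kwargs": kwargs}

    def invoke_stub(
        input: str,
        config: Optional[dict] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> dict:
        # Same spread as the diff: stop, then config entries, then extra kwargs.
        return generate_prompt_stub([input], stop=stop, **(config or {}), **kwargs)

    print(invoke_stub("hi", {"tags": ["demo"]}, stop=["\n"], temperature=0.0))
    # {'prompts': ['hi'], 'kwargs': {'stop': ['\n'], 'tags': ['demo'], 'temperature': 0.0}}

One consequence of spreading at the call site: if `config` and `kwargs` ever shared a key, Python would raise `TypeError: got multiple values for keyword argument`, whereas the dict-literal merge used later in `stream` silently prefers the last value.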
@@ -118,15 +119,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessageChunk:
         if type(self)._agenerate == BaseChatModel._agenerate:
             # model doesn't implement async generation, so use default implementation
             return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.invoke, input, config, stop=stop)
+                None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )

         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {})
+            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
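Two patterns carry this fallback: comparing `type(self)._agenerate` against the base class attribute detects whether a subclass overrode the async path, and `functools.partial` binds the positional arguments plus `**kwargs` for `run_in_executor`, which accepts no keyword arguments of its own. A self-contained sketch with hypothetical class names:

    import asyncio
    from functools import partial
    from typing import Any

    class Base:
        def _generate(self, prompt: str, **kwargs: Any) -> str:
            return f"sync:{prompt}:{kwargs}"

        async def _agenerate(self, prompt: str, **kwargs: Any) -> str:
            raise NotImplementedError

        async def ainvoke(self, prompt: str, **kwargs: Any) -> str:
            if type(self)._agenerate == Base._agenerate:
                # Subclass did not override the async path: run the sync
                # method in a thread. run_in_executor takes no **kwargs,
                # so partial binds them up front.
                return await asyncio.get_running_loop().run_in_executor(
                    None, partial(self._generate, prompt, **kwargs)
                )
            return await self._agenerate(prompt, **kwargs)

    class SyncOnly(Base):
        pass

    print(asyncio.run(SyncOnly().ainvoke("hi", temperature=0.0)))
    # sync:hi:{'temperature': 0.0}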
@@ -213,10 +213,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> str:
         return (
             self.generate_prompt(
-                [self._convert_input(input)], stop=stop, **(config or {})
+                [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
             )
             .generations[0][0]
             .text
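The `.generations[0][0].text` chain indexes an `LLMResult`: the outer list holds one entry per input prompt and the inner list one entry per candidate generation, so a single-prompt `invoke` reads the first candidate of its only prompt. A toy illustration of the shape, with plain dicts standing in for the real `Generation` objects:

    # Shape of LLMResult.generations: one row per prompt, one entry per candidate.
    generations = [
        [{"text": "Paris"}, {"text": "paris."}],  # candidates for prompt 0
        [{"text": "Berlin"}],                     # candidates for prompt 1
    ]
    print(generations[0][0]["text"])  # Paris: first candidate of the first prompt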
@@ -228,15 +229,16 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> str:
         if type(self)._agenerate == BaseLLM._agenerate:
             # model doesn't implement async invoke, so use default implementation
             return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.invoke, input, config, stop=stop)
+                None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )

         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {})
+            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
         )
         return llm_result.generations[0][0].text
@@ -245,6 +247,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         inputs: List[LanguageModelInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         max_concurrency: Optional[int] = None,
+        **kwargs: Any,
     ) -> List[str]:
         config = self._get_config_list(config, len(inputs))
@@ -254,6 +257,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
         else:
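By this point `config` has been normalized to a list with one `RunnableConfig` per input (via `_get_config_list`), so each field is extracted column-wise with `.get`, and the shared `**kwargs` now ride alongside those per-input lists. A plain-dict sketch of the column extraction:

    configs = [
        {"callbacks": None, "tags": ["a"], "metadata": {}},
        {"tags": ["b"]},  # missing keys fall back to None via .get
    ]
    callbacks = [c.get("callbacks") for c in configs]
    tags = [c.get("tags") for c in configs]
    metadata = [c.get("metadata") for c in configs]
    print(callbacks, tags, metadata)
    # [None, None] [['a'], ['b']] [{}, None]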
@@ -264,7 +268,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             return [
                 output
                 for batch in batches
-                for output in self.batch(batch, config=config)
+                for output in self.batch(batch, config=config, **kwargs)
             ]

     async def abatch(
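The `else` branch is only partly visible here; the lines just above this hunk (not shown) slice `inputs` into chunks of at most `max_concurrency` and recurse, and the fix forwards the same `**kwargs` on each recursive call. A standalone sketch of that chunk-and-flatten shape, under the assumption that this is how `batches` is built:

    from typing import Any, Callable, List

    def chunked_batch(
        fn: Callable[..., List[str]],
        inputs: List[str],
        max_concurrency: int,
        **kwargs: Any,
    ) -> List[str]:
        # Slice inputs into chunks of at most max_concurrency, call fn per
        # chunk with the same forwarded kwargs, then flatten the outputs.
        batches = [
            inputs[i : i + max_concurrency]
            for i in range(0, len(inputs), max_concurrency)
        ]
        return [output for batch in batches for output in fn(batch, **kwargs)]

    print(chunked_batch(lambda xs, **kw: [x.upper() for x in xs], ["a", "b", "c"], 2))
    # ['A', 'B', 'C']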
@@ -272,6 +276,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         inputs: List[LanguageModelInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         max_concurrency: Optional[int] = None,
+        **kwargs: Any,
     ) -> List[str]:
         if type(self)._agenerate == BaseLLM._agenerate:
             # model doesn't implement async batch, so use default implementation
@@ -287,6 +292,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
         else:
@@ -297,7 +303,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             return [
                 output
                 for batch in batches
-                for output in await self.abatch(batch, config=config)
+                for output in await self.abatch(batch, config=config, **kwargs)
             ]

     def stream(
@@ -306,15 +312,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> Iterator[str]:
         if type(self)._stream == BaseLLM._stream:
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop)
+            yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             prompt = self._convert_input(input).to_string()
             config = config or {}
             params = self.dict()
             params["stop"] = stop
+            params = {**params, **kwargs}
             options = {"stop": stop}
             callback_manager = CallbackManager.configure(
                 config.get("callbacks"),
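The added `params = {**params, **kwargs}` merges caller kwargs into the serialized model parameters before they are reported to the run callbacks; in a dict literal the later unpacking silently wins on duplicate keys, so per-call kwargs override the model's defaults. Two lines demonstrate the precedence:

    params = {"model_name": "stub", "temperature": 0.7, "stop": None}
    kwargs = {"temperature": 0.0, "max_tokens": 16}
    params = {**params, **kwargs}  # later unpacking wins: temperature becomes 0.0
    print(params)
    # {'model_name': 'stub', 'temperature': 0.0, 'stop': None, 'max_tokens': 16}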
@@ -330,7 +338,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
             try:
                 generation: Optional[GenerationChunk] = None
-                for chunk in self._stream(prompt, stop=stop, run_manager=run_manager):
+                for chunk in self._stream(
+                    prompt, stop=stop, run_manager=run_manager, **kwargs
+                ):
                     yield chunk.text
                     if generation is None:
                         generation = chunk
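The loop streams each `chunk.text` to the caller while also folding chunks into `generation` (the first chunk is kept as-is; the `else: generation += chunk` arm sits just past this hunk), so a complete generation is available for the end-of-run callback. A minimal sketch of that accumulate-while-yielding pattern, with a toy chunk type standing in for `GenerationChunk`:

    from dataclasses import dataclass
    from typing import Iterable, Iterator, Optional

    @dataclass
    class Chunk:
        text: str

        def __add__(self, other: "Chunk") -> "Chunk":
            return Chunk(self.text + other.text)

    def stream_and_accumulate(chunks: Iterable[Chunk]) -> Iterator[str]:
        generation: Optional[Chunk] = None
        for chunk in chunks:
            yield chunk.text           # stream each piece to the caller
            if generation is None:
                generation = chunk     # keep the first chunk
            else:
                generation += chunk    # fold later chunks into it
        # `generation` now holds the full text for an end-of-run callback.

    print(list(stream_and_accumulate([Chunk("hel"), Chunk("lo")])))
    # ['hel', 'lo']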
@@ -349,15 +359,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> AsyncIterator[str]:
         if type(self)._astream == BaseLLM._astream:
             # model doesn't implement streaming, so use default implementation
-            yield await self.ainvoke(input, config=config, stop=stop)
+            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
         else:
             prompt = self._convert_input(input).to_string()
             config = config or {}
             params = self.dict()
             params["stop"] = stop
+            params = {**params, **kwargs}
             options = {"stop": stop}
             callback_manager = AsyncCallbackManager.configure(
                 config.get("callbacks"),
@@ -374,7 +386,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             try:
                 generation: Optional[GenerationChunk] = None
                 async for chunk in self._astream(
-                    prompt, stop=stop, run_manager=run_manager
+                    prompt, stop=stop, run_manager=run_manager, **kwargs
                 ):
                     yield chunk.text
                     if generation is None:
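The async path mirrors the sync one: `**kwargs` passes from `astream` straight into the subclass's `_astream` async generator. A compact sketch of kwargs forwarding through nested async generators, with toy names in place of the LangChain classes:

    import asyncio
    from typing import Any, AsyncIterator

    async def _astream_stub(prompt: str, **kwargs: Any) -> AsyncIterator[str]:
        # Stand-in for a subclass _astream that honors forwarded kwargs.
        for piece in (prompt[:2], prompt[2:], f"|{kwargs.get('suffix', '')}"):
            yield piece

    async def astream_stub(prompt: str, **kwargs: Any) -> AsyncIterator[str]:
        # The wrapper forwards caller kwargs straight into the inner stream.
        async for chunk in _astream_stub(prompt, **kwargs):
            yield chunk

    async def main() -> None:
        print([c async for c in astream_stub("hello", suffix="eos")])
        # ['he', 'llo', '|eos']

    asyncio.run(main())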