From 394b67ab92602fd47a75916833641776057dae94 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 28 Jul 2023 01:13:11 -0700 Subject: [PATCH] add kwargs to llm runnables (#8388) --- libs/langchain/langchain/chat_models/base.py | 8 ++++-- libs/langchain/langchain/llms/base.py | 30 ++++++++++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index fe9ed5c59f..b06b99f99d 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -101,13 +101,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): config: Optional[RunnableConfig] = None, *, stop: Optional[List[str]] = None, + **kwargs: Any, ) -> BaseMessageChunk: return cast( BaseMessageChunk, cast( ChatGeneration, self.generate_prompt( - [self._convert_input(input)], stop=stop, **(config or {}) + [self._convert_input(input)], stop=stop, **(config or {}), **kwargs ).generations[0][0], ).message, ) @@ -118,15 +119,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): config: Optional[RunnableConfig] = None, *, stop: Optional[List[str]] = None, + **kwargs: Any, ) -> BaseMessageChunk: if type(self)._agenerate == BaseChatModel._agenerate: # model doesn't implement async generation, so use default implementation return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, stop=stop) + None, partial(self.invoke, input, config, stop=stop, **kwargs) ) llm_result = await self.agenerate_prompt( - [self._convert_input(input)], stop=stop, **(config or {}) + [self._convert_input(input)], stop=stop, **(config or {}), **kwargs ) return cast( BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index a40b79de80..8ccaf0deb6 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -213,10 +213,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): config: Optional[RunnableConfig] = None, *, stop: Optional[List[str]] = None, + **kwargs: Any, ) -> str: return ( self.generate_prompt( - [self._convert_input(input)], stop=stop, **(config or {}) + [self._convert_input(input)], stop=stop, **(config or {}), **kwargs ) .generations[0][0] .text @@ -228,15 +229,16 @@ class BaseLLM(BaseLanguageModel[str], ABC): config: Optional[RunnableConfig] = None, *, stop: Optional[List[str]] = None, + **kwargs: Any, ) -> str: if type(self)._agenerate == BaseLLM._agenerate: # model doesn't implement async invoke, so use default implementation return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, stop=stop) + None, partial(self.invoke, input, config, stop=stop, **kwargs) ) llm_result = await self.agenerate_prompt( - [self._convert_input(input)], stop=stop, **(config or {}) + [self._convert_input(input)], stop=stop, **(config or {}), **kwargs ) return llm_result.generations[0][0].text @@ -245,6 +247,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): inputs: List[LanguageModelInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, max_concurrency: Optional[int] = None, + **kwargs: Any, ) -> List[str]: config = self._get_config_list(config, len(inputs)) @@ -254,6 +257,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): callbacks=[c.get("callbacks") for c in config], tags=[c.get("tags") for c in config], metadata=[c.get("metadata") for c in config], + **kwargs, ) return [g[0].text for g in llm_result.generations] else: @@ -264,7 +268,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): return [ output for batch in batches - for output in self.batch(batch, config=config) + for output in self.batch(batch, config=config, **kwargs) ] async def abatch( @@ -272,6 +276,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): inputs: List[LanguageModelInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, max_concurrency: Optional[int] = None, + **kwargs: Any, ) -> List[str]: if type(self)._agenerate == BaseLLM._agenerate: # model doesn't implement async batch, so use default implementation @@ -287,6 +292,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): callbacks=[c.get("callbacks") for c in config], tags=[c.get("tags") for c in config], metadata=[c.get("metadata") for c in config], + **kwargs, ) return [g[0].text for g in llm_result.generations] else: @@ -297,7 +303,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): return [ output for batch in batches - for output in await self.abatch(batch, config=config) + for output in await self.abatch(batch, config=config, **kwargs) ] def stream( @@ -306,15 +312,17 @@ class BaseLLM(BaseLanguageModel[str], ABC): config: Optional[RunnableConfig] = None, *, stop: Optional[List[str]] = None, + **kwargs: Any, ) -> Iterator[str]: if type(self)._stream == BaseLLM._stream: # model doesn't implement streaming, so use default implementation - yield self.invoke(input, config=config, stop=stop) + yield self.invoke(input, config=config, stop=stop, **kwargs) else: prompt = self._convert_input(input).to_string() config = config or {} params = self.dict() params["stop"] = stop + params = {**params, **kwargs} options = {"stop": stop} callback_manager = CallbackManager.configure( config.get("callbacks"), @@ -330,7 +338,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): ) try: generation: Optional[GenerationChunk] = None - for chunk in self._stream(prompt, stop=stop, run_manager=run_manager): + for chunk in self._stream( + prompt, stop=stop, run_manager=run_manager, **kwargs + ): yield chunk.text if generation is None: generation = chunk @@ -349,15 +359,17 @@ class BaseLLM(BaseLanguageModel[str], ABC): config: Optional[RunnableConfig] = None, *, stop: Optional[List[str]] = None, + **kwargs: Any, ) -> AsyncIterator[str]: if type(self)._astream == BaseLLM._astream: # model doesn't implement streaming, so use default implementation - yield await self.ainvoke(input, config=config, stop=stop) + yield await self.ainvoke(input, config=config, stop=stop, **kwargs) else: prompt = self._convert_input(input).to_string() config = config or {} params = self.dict() params["stop"] = stop + params = {**params, **kwargs} options = {"stop": stop} callback_manager = AsyncCallbackManager.configure( config.get("callbacks"), @@ -374,7 +386,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): try: generation: Optional[GenerationChunk] = None async for chunk in self._astream( - prompt, stop=stop, run_manager=run_manager + prompt, stop=stop, run_manager=run_manager, **kwargs ): yield chunk.text if generation is None: