diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index a130dc62b8..bdbd7fc699 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -117,13 +117,7 @@ class Runnable(Generic[Input, Output], ABC): return [self.invoke(inputs[0], configs[0], **kwargs)] with get_executor_for_config(configs[0]) as executor: - return list( - executor.map( - partial(self.invoke, **kwargs), - inputs, - (patch_config(c, executor=executor) for c in configs), - ) - ) + return list(executor.map(partial(self.invoke, **kwargs), inputs, configs)) async def abatch( self, @@ -852,18 +846,15 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # invoke try: - with get_executor_for_config(configs[0]) as executor: - for step in self.steps: - inputs = step.batch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config( - config, callbacks=rm.get_child(), executor=executor - ) - for rm, config in zip(run_managers, configs) - ], - ) + for step in self.steps: + inputs = step.batch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + ) # finish the root runs except (KeyboardInterrupt, Exception) as e: for rm in run_managers: @@ -1152,7 +1143,6 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): config, deep_copy_locals=True, callbacks=run_manager.get_child(), - executor=executor, ), ) for step in steps.values() @@ -1219,9 +1209,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): name, step.transform( input_copies.pop(), - patch_config( - config, callbacks=run_manager.get_child(), executor=executor - ), + patch_config(config, callbacks=run_manager.get_child()), ), ) for name, step in steps.items() diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index a431fb6358..b97d904414 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -42,12 +42,6 @@ class RunnableConfig(TypedDict, total=False): ThreadPoolExecutor's default. This is ignored if an executor is provided. """ - executor: Executor - """ - Externally-managed executor to use for parallel calls. If not provided, a new - ThreadPoolExecutor will be created. - """ - recursion_limit: int """ Maximum number of times a call can recurse. If not provided, defaults to 10. @@ -72,7 +66,6 @@ def patch_config( *, deep_copy_locals: bool = False, callbacks: Optional[BaseCallbackManager] = None, - executor: Optional[Executor] = None, recursion_limit: Optional[int] = None, ) -> RunnableConfig: config = ensure_config(config) @@ -80,8 +73,6 @@ def patch_config( config["_locals"] = deepcopy(config["_locals"]) if callbacks is not None: config["callbacks"] = callbacks - if executor is not None: - config["executor"] = executor if recursion_limit is not None: config["recursion_limit"] = recursion_limit return config @@ -111,8 +102,5 @@ def get_async_callback_manager_for_config( @contextmanager def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]: - if config.get("executor"): - yield config["executor"] - else: - with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor: - yield executor + with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor: + yield executor diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index c48d4edbd4..fcb621fe8c 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -2081,7 +2081,8 @@ "stop": [ "Thought:" ] - } + }, + "config": {} } }, "llm": {