Do not share executors between parent and child tasks (#9701)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live is docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->
pull/9706/head
Nuno Campos 1 year ago committed by GitHub
commit 9666e752b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -117,13 +117,7 @@ class Runnable(Generic[Input, Output], ABC):
return [self.invoke(inputs[0], configs[0], **kwargs)]
with get_executor_for_config(configs[0]) as executor:
return list(
executor.map(
partial(self.invoke, **kwargs),
inputs,
(patch_config(c, executor=executor) for c in configs),
)
)
return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
async def abatch(
self,
@ -852,18 +846,15 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke
try:
with get_executor_for_config(configs[0]) as executor:
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(
config, callbacks=rm.get_child(), executor=executor
)
for rm, config in zip(run_managers, configs)
],
)
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
for rm in run_managers:
@ -1152,7 +1143,6 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
config,
deep_copy_locals=True,
callbacks=run_manager.get_child(),
executor=executor,
),
)
for step in steps.values()
@ -1219,9 +1209,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
name,
step.transform(
input_copies.pop(),
patch_config(
config, callbacks=run_manager.get_child(), executor=executor
),
patch_config(config, callbacks=run_manager.get_child()),
),
)
for name, step in steps.items()

@ -42,12 +42,6 @@ class RunnableConfig(TypedDict, total=False):
ThreadPoolExecutor's default. This is ignored if an executor is provided.
"""
executor: Executor
"""
Externally-managed executor to use for parallel calls. If not provided, a new
ThreadPoolExecutor will be created.
"""
recursion_limit: int
"""
Maximum number of times a call can recurse. If not provided, defaults to 10.
@ -72,7 +66,6 @@ def patch_config(
*,
deep_copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None,
executor: Optional[Executor] = None,
recursion_limit: Optional[int] = None,
) -> RunnableConfig:
config = ensure_config(config)
@ -80,8 +73,6 @@ def patch_config(
config["_locals"] = deepcopy(config["_locals"])
if callbacks is not None:
config["callbacks"] = callbacks
if executor is not None:
config["executor"] = executor
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
return config
@ -111,8 +102,5 @@ def get_async_callback_manager_for_config(
@contextmanager
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
if config.get("executor"):
yield config["executor"]
else:
with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
yield executor
with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
yield executor

Loading…
Cancel
Save