mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Do not share executors between parent and child tasks (#9701)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. These live is docs/extras directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17, @rlancemartin. -->
This commit is contained in:
commit
9666e752b1
@ -117,13 +117,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return [self.invoke(inputs[0], configs[0], **kwargs)]
|
||||
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
return list(
|
||||
executor.map(
|
||||
partial(self.invoke, **kwargs),
|
||||
inputs,
|
||||
(patch_config(c, executor=executor) for c in configs),
|
||||
)
|
||||
)
|
||||
return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
@ -852,18 +846,15 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
|
||||
# invoke
|
||||
try:
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
for step in self.steps:
|
||||
inputs = step.batch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(
|
||||
config, callbacks=rm.get_child(), executor=executor
|
||||
)
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
)
|
||||
for step in self.steps:
|
||||
inputs = step.batch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(config, callbacks=rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
)
|
||||
# finish the root runs
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
for rm in run_managers:
|
||||
@ -1152,7 +1143,6 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
config,
|
||||
deep_copy_locals=True,
|
||||
callbacks=run_manager.get_child(),
|
||||
executor=executor,
|
||||
),
|
||||
)
|
||||
for step in steps.values()
|
||||
@ -1219,9 +1209,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
name,
|
||||
step.transform(
|
||||
input_copies.pop(),
|
||||
patch_config(
|
||||
config, callbacks=run_manager.get_child(), executor=executor
|
||||
),
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
),
|
||||
)
|
||||
for name, step in steps.items()
|
||||
|
@ -42,12 +42,6 @@ class RunnableConfig(TypedDict, total=False):
|
||||
ThreadPoolExecutor's default. This is ignored if an executor is provided.
|
||||
"""
|
||||
|
||||
executor: Executor
|
||||
"""
|
||||
Externally-managed executor to use for parallel calls. If not provided, a new
|
||||
ThreadPoolExecutor will be created.
|
||||
"""
|
||||
|
||||
recursion_limit: int
|
||||
"""
|
||||
Maximum number of times a call can recurse. If not provided, defaults to 10.
|
||||
@ -72,7 +66,6 @@ def patch_config(
|
||||
*,
|
||||
deep_copy_locals: bool = False,
|
||||
callbacks: Optional[BaseCallbackManager] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
recursion_limit: Optional[int] = None,
|
||||
) -> RunnableConfig:
|
||||
config = ensure_config(config)
|
||||
@ -80,8 +73,6 @@ def patch_config(
|
||||
config["_locals"] = deepcopy(config["_locals"])
|
||||
if callbacks is not None:
|
||||
config["callbacks"] = callbacks
|
||||
if executor is not None:
|
||||
config["executor"] = executor
|
||||
if recursion_limit is not None:
|
||||
config["recursion_limit"] = recursion_limit
|
||||
return config
|
||||
@ -111,8 +102,5 @@ def get_async_callback_manager_for_config(
|
||||
|
||||
@contextmanager
|
||||
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
|
||||
if config.get("executor"):
|
||||
yield config["executor"]
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
|
||||
yield executor
|
||||
with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
|
||||
yield executor
|
||||
|
Loading…
Reference in New Issue
Block a user