Resolve circular imports in runnables (#9675)

These are about to cause circular imports.
pull/9755/head
Nuno Campos 11 months ago committed by GitHub
commit 6283f3b63c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,16 +1,27 @@
from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
Mapping,
Optional,
Union,
)
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
class PutLocalVar(RunnablePassthrough):
key: Union[str, Mapping[str, str]]

@ -3,10 +3,11 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Dict, Generator, List, Optional, TypedDict
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
class RunnableConfig(TypedDict, total=False):
@ -87,6 +88,8 @@ def patch_config(
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
from langchain.callbacks.manager import CallbackManager
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
@ -97,6 +100,8 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
def get_async_callback_manager_for_config(
config: RunnableConfig,
) -> AsyncCallbackManager:
from langchain.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),

Loading…
Cancel
Save