diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py index 755a709fc9..5b2f8e758a 100644 --- a/libs/langchain/langchain/schema/runnable/_locals.py +++ b/libs/langchain/langchain/schema/runnable/_locals.py @@ -1,16 +1,27 @@ from __future__ import annotations -from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union - -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Iterator, + Mapping, + Optional, + Union, ) + from langchain.load.serializable import Serializable from langchain.schema.runnable.base import Input, Output, Runnable from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.passthrough import RunnablePassthrough +if TYPE_CHECKING: + from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + ) + class PutLocalVar(RunnablePassthrough): key: Union[str, Mapping[str, str]] diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index ce4e11861e..a431fb6358 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -3,10 +3,11 @@ from __future__ import annotations from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import contextmanager from copy import deepcopy -from typing import Any, Dict, Generator, List, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict -from langchain.callbacks.base import BaseCallbackManager, Callbacks -from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager +if TYPE_CHECKING: + from langchain.callbacks.base import BaseCallbackManager, Callbacks + from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager class RunnableConfig(TypedDict, total=False): @@ -87,6 +88,8 @@ def patch_config( def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: + from langchain.callbacks.manager import CallbackManager + return CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), inheritable_tags=config.get("tags"), @@ -97,6 +100,8 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: def get_async_callback_manager_for_config( config: RunnableConfig, ) -> AsyncCallbackManager: + from langchain.callbacks.manager import AsyncCallbackManager + return AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), inheritable_tags=config.get("tags"),