From 9e1dbd4b490d423b8ed4fc699975d91cef2e7cfc Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 23 Aug 2023 22:51:49 -0400 Subject: [PATCH] x --- .../langchain/schema/runnable/_locals.py | 21 ++++++++++++++----- .../langchain/schema/runnable/config.py | 11 +++++++--- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py index 755a709fc9..5b2f8e758a 100644 --- a/libs/langchain/langchain/schema/runnable/_locals.py +++ b/libs/langchain/langchain/schema/runnable/_locals.py @@ -1,16 +1,27 @@ from __future__ import annotations -from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union - -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Iterator, + Mapping, + Optional, + Union, ) + from langchain.load.serializable import Serializable from langchain.schema.runnable.base import Input, Output, Runnable from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.passthrough import RunnablePassthrough +if TYPE_CHECKING: + from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + ) + class PutLocalVar(RunnablePassthrough): key: Union[str, Mapping[str, str]] diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index ce4e11861e..a431fb6358 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -3,10 +3,11 @@ from __future__ import annotations from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import contextmanager from copy import deepcopy -from typing import Any, Dict, Generator, List, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict -from langchain.callbacks.base import BaseCallbackManager, Callbacks -from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager +if TYPE_CHECKING: + from langchain.callbacks.base import BaseCallbackManager, Callbacks + from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager class RunnableConfig(TypedDict, total=False): @@ -87,6 +88,8 @@ def patch_config( def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: + from langchain.callbacks.manager import CallbackManager + return CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), inheritable_tags=config.get("tags"), @@ -97,6 +100,8 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: def get_async_callback_manager_for_config( config: RunnableConfig, ) -> AsyncCallbackManager: + from langchain.callbacks.manager import AsyncCallbackManager + return AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), inheritable_tags=config.get("tags"),