Resolve circular imports in runnables (#9675)

These are about to cause circular imports.
pull/9755/head
Nuno Campos 11 months ago committed by GitHub
commit 6283f3b63c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,16 +1,27 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union from typing import (
TYPE_CHECKING,
from langchain.callbacks.manager import ( Any,
AsyncCallbackManagerForChainRun, AsyncIterator,
CallbackManagerForChainRun, Dict,
Iterator,
Mapping,
Optional,
Union,
) )
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.passthrough import RunnablePassthrough
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
class PutLocalVar(RunnablePassthrough): class PutLocalVar(RunnablePassthrough):
key: Union[str, Mapping[str, str]] key: Union[str, Mapping[str, str]]

@ -3,10 +3,11 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, Generator, List, Optional, TypedDict from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict
from langchain.callbacks.base import BaseCallbackManager, Callbacks if TYPE_CHECKING:
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
class RunnableConfig(TypedDict, total=False): class RunnableConfig(TypedDict, total=False):
@ -87,6 +88,8 @@ def patch_config(
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
from langchain.callbacks.manager import CallbackManager
return CallbackManager.configure( return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"), inheritable_tags=config.get("tags"),
@ -97,6 +100,8 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
def get_async_callback_manager_for_config( def get_async_callback_manager_for_config(
config: RunnableConfig, config: RunnableConfig,
) -> AsyncCallbackManager: ) -> AsyncCallbackManager:
from langchain.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure( return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"), inheritable_tags=config.get("tags"),

Loading…
Cancel
Save