pull/9675/head
Eugene Yurtsev 1 year ago
parent b88dfcb42a
commit 9e1dbd4b49

@ -1,16 +1,27 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union from typing import (
TYPE_CHECKING,
from langchain.callbacks.manager import ( Any,
AsyncCallbackManagerForChainRun, AsyncIterator,
CallbackManagerForChainRun, Dict,
Iterator,
Mapping,
Optional,
Union,
) )
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.passthrough import RunnablePassthrough
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
class PutLocalVar(RunnablePassthrough): class PutLocalVar(RunnablePassthrough):
key: Union[str, Mapping[str, str]] key: Union[str, Mapping[str, str]]

@ -3,10 +3,11 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, Generator, List, Optional, TypedDict from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict
from langchain.callbacks.base import BaseCallbackManager, Callbacks if TYPE_CHECKING:
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
class RunnableConfig(TypedDict, total=False): class RunnableConfig(TypedDict, total=False):
@ -87,6 +88,8 @@ def patch_config(
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
from langchain.callbacks.manager import CallbackManager
return CallbackManager.configure( return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"), inheritable_tags=config.get("tags"),
@ -97,6 +100,8 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
def get_async_callback_manager_for_config( def get_async_callback_manager_for_config(
config: RunnableConfig, config: RunnableConfig,
) -> AsyncCallbackManager: ) -> AsyncCallbackManager:
from langchain.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure( return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"), inheritable_tags=config.get("tags"),

Loading…
Cancel
Save