Propagate context vars in all classes/methods

- Any direct usage of ThreadPoolExecutor or asyncio.run_in_executor needs manual handling of context vars
pull/15329/head
Nuno Campos 6 months ago
parent 70e5d05952
commit eb5e250188

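Editor's note: a minimal, self-contained sketch of the failure mode this commit fixes (illustrative names, not from the diff). A ContextVar set by the caller is invisible inside a ThreadPoolExecutor worker unless the current context is explicitly copied in, which is what the run_in_executor helper and executor initializers below wire through:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from contextvars import ContextVar, copy_context

    request_id: ContextVar[str] = ContextVar("request_id", default="<unset>")

    def read_var() -> str:
        return request_id.get()

    async def main() -> None:
        request_id.set("abc-123")
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(1) as pool:
            # Worker threads start from an empty context: prints "<unset>".
            print(await loop.run_in_executor(pool, read_var))
            # Run under a copy of the caller's context: prints "abc-123".
            print(await loop.run_in_executor(pool, copy_context().run, read_var))

    asyncio.run(main())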
@@ -1,12 +1,9 @@
"""ChatModel wrapper which returns user input as the response.."""
import asyncio
from functools import partial
from io import StringIO
from typing import Any, Callable, Dict, List, Mapping, Optional
import yaml
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
@@ -111,15 +108,3 @@ class HumanInputChatModel(BaseChatModel):
self.message_func(messages, **self.message_kwargs)
user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
return ChatResult(generations=[ChatGeneration(message=user_input)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)

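Editor's note: the _agenerate override deleted above (and the matching deletions in ChatMlflow, ChatMLflowAIGateway, and PaiEasChatEndpoint below) is now redundant. The BaseChatModel._agenerate default, updated later in this commit, provides the same executor fallback while propagating context vars and handing the synchronous _generate a sync run manager:

    # Inherited fallback, as added in the langchain_core diff further down.
    return await run_in_executor(
        None,
        self._generate,
        messages,
        stop,
        run_manager.get_sync() if run_manager else None,
        **kwargs,
    )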
@@ -1,11 +1,8 @@
import asyncio
import logging
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
@@ -125,18 +122,6 @@ class ChatMlflow(BaseChatModel):
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
return ChatMlflow._create_chat_result(resp)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
@property
def _identifying_params(self) -> Dict[str, Any]:
return self._default_params

@@ -1,11 +1,8 @@
import asyncio
import logging
import warnings
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
@@ -116,18 +113,6 @@ class ChatMLflowAIGateway(BaseChatModel):
resp = mlflow.gateway.query(self.route, data=data)
return ChatMLflowAIGateway._create_chat_result(resp)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
@property
def _identifying_params(self) -> Dict[str, Any]:
return self._default_params

@@ -1,7 +1,5 @@
import asyncio
import json
import logging
from functools import partial
from typing import Any, AsyncIterator, Dict, List, Optional, cast
import requests
@@ -300,25 +298,3 @@ class PaiEasChatEndpoint(BaseChatModel):
# break if stop sequence found
if stop_seq_found:
break
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
if stream if stream is not None else self.streaming:
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
generation = chunk
assert generation is not None
return ChatResult(generations=[generation])
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)

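Editor's note: the `if stream if stream is not None else self.streaming:` line in the deleted _agenerate above looks like a typo but is valid Python — a conditional expression used as the test — equivalent to:

    should_stream = stream if stream is not None else self.streaming
    if should_stream:
        ...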
@@ -1,11 +1,11 @@
import asyncio
import json
import os
from functools import partial
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.runnables.config import run_in_executor
class BedrockEmbeddings(BaseModel, Embeddings):
@@ -181,9 +181,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
Embeddings for the text.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.embed_query, text)
)
return await run_in_executor(None, self.embed_query, text)
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous compute doc embeddings using a Bedrock model.

@@ -1,12 +1,12 @@
import asyncio
import logging
import threading
from functools import partial
from typing import Dict, List, Optional
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import run_in_executor
from langchain_core.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
@@ -134,9 +134,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
List[float]: Embeddings for the text.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.embed_query, text)
)
return await run_in_executor(None, self.embed_query, text)
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.

@@ -1,8 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Optional, Type
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field
@@ -57,11 +55,3 @@ Note: SessionId must be received from previous Browser window creation."""
print(f"{e}, retrying...")
except Exception as e:
raise Exception(f"An error occurred: {e}")
async def _arun(
self,
sessionId: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> None:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._run, sessionId)

@@ -1,8 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Optional, Type
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field
@@ -67,14 +65,3 @@ class MultionCreateSession(BaseTool):
}
except Exception as e:
raise Exception(f"An error occurred: {e}")
async def _arun(
self,
query: str,
url: Optional[str] = "https://www.google.com/",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> dict:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self._run, query, url)
return result

@@ -1,8 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Optional, Type
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field
@@ -74,15 +72,3 @@ Note: sessionId must be received from previous Browser window creation."""
return {"error": f"{e}", "Response": "retrying..."}
except Exception as e:
raise Exception(f"An error occurred: {e}")
async def _arun(
self,
sessionId: str,
query: str,
url: Optional[str] = "https://www.google.com/",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> dict:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self._run, sessionId, query, url)
return result

@@ -1,10 +1,8 @@
import asyncio
import platform
import warnings
from typing import Any, List, Optional, Type, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@@ -77,13 +75,3 @@ class ShellTool(BaseTool):
) -> str:
"""Run commands and return final output."""
return self.process.run(commands)
async def _arun(
self,
commands: Union[str, List[str]],
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Run commands asynchronously and return final output."""
return await asyncio.get_event_loop().run_in_executor(
None, self.process.run, commands
)

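Editor's note: the four _arun overrides removed above (the MultiOn tools and ShellTool) fall through to BaseTool._arun, which the tools/base diff later in this commit reduces to a context-propagating one-liner:

    return await run_in_executor(None, self._run, *args, **kwargs)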
@@ -1,13 +1,11 @@
from __future__ import annotations
import asyncio
import logging
import operator
import os
import pickle
import uuid
import warnings
from functools import partial
from pathlib import Path
from typing import (
Any,
@@ -24,6 +22,7 @@ from typing import (
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore
from langchain_community.docstore.base import AddableMixin, Docstore
@@ -359,7 +358,8 @@ class FAISS(VectorStore):
"""
# This is a temporary workaround to make the similarity search asynchronous.
func = partial(
return await run_in_executor(
None,
self.similarity_search_with_score_by_vector,
embedding,
k=k,
@@ -367,7 +367,6 @@ class FAISS(VectorStore):
fetch_k=fetch_k,
**kwargs,
)
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_with_score(
self,
@@ -640,7 +639,8 @@ class FAISS(VectorStore):
relevance and score for each.
"""
# This is a temporary workaround to make the similarity search asynchronous.
func = partial(
return await run_in_executor(
None,
self.max_marginal_relevance_search_with_score_by_vector,
embedding,
k=k,
@@ -648,7 +648,6 @@ class FAISS(VectorStore):
lambda_mult=lambda_mult,
filter=filter,
)
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector(
self,

@@ -1,11 +1,9 @@
from __future__ import annotations
import asyncio
import contextlib
import enum
import logging
import uuid
from functools import partial
from typing import (
Any,
Callable,
@@ -31,6 +29,7 @@ except ImportError:
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore
@@ -941,7 +940,8 @@ class PGVector(VectorStore):
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(
return await run_in_executor(
None,
self.max_marginal_relevance_search_by_vector,
embedding,
k=k,
@@ -950,4 +950,3 @@ class PGVector(VectorStore):
filter=filter,
**kwargs,
)
return await asyncio.get_event_loop().run_in_executor(None, func)

@@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import functools
import uuid
import warnings
@@ -25,6 +24,7 @@ from typing import (
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
@@ -58,10 +58,9 @@ def sync_call_fallback(method: Callable) -> Callable:
# by removing the first letter from the method name. For example,
# if the async method is called ``aadd_texts``, the synchronous method
# will be called ``add_texts``.
sync_method = functools.partial(
getattr(self, method.__name__[1:]), *args, **kwargs
return await run_in_executor(
None, getattr(self, method.__name__[1:]), *args, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, sync_method)
return wrapper

@@ -23,7 +23,7 @@ from langchain_core.runnables.base import (
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import RunnableConfig, patch_config
from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
T = TypeVar("T")
@@ -186,7 +186,7 @@ class ContextGet(RunnableSerializable):
]
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
config = config or {}
config = ensure_config(config)
configurable = config.get("configurable", {})
if isinstance(self.key, list):
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
@@ -196,7 +196,7 @@ class ContextGet(RunnableSerializable):
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = config or {}
config = ensure_config(config)
configurable = config.get("configurable", {})
if isinstance(self.key, list):
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
@@ -281,7 +281,7 @@ class ContextSet(RunnableSerializable):
]
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
config = config or {}
config = ensure_config(config)
configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None:
@@ -293,7 +293,7 @@ class ContextSet(RunnableSerializable):
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = config or {}
config = ensure_config(config)
configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None:

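Editor's note: throughout this commit, `config = config or {}` becomes `config = ensure_config(config)`. A rough sketch of what the helper does — the real definition lives in langchain_core.runnables.config and may differ in detail — is to normalize an Optional[RunnableConfig] into a dict with every key populated, so downstream `config.get(...)` and key access are always safe:

    from typing import Optional

    from langchain_core.runnables import RunnableConfig

    def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
        # Start from a config with all keys present, then overlay the caller's values.
        empty = RunnableConfig(tags=[], metadata={}, callbacks=None, recursion_limit=25)
        if config is not None:
            empty.update({k: v for k, v in config.items() if v is not None})
        return empty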
@@ -4,13 +4,15 @@ import asyncio
import functools
import logging
import uuid
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import Context, copy_context
from contextvars import copy_context
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Coroutine,
Dict,
Generator,
@@ -272,25 +274,14 @@ def handle_event(
# we end up in a deadlock, as we'd have gotten here from a
# running coroutine, which we cannot interrupt to run this one.
# The solution is to create a new loop in a new thread.
with _executor_w_context(1) as executor:
executor.submit(_run_coros, coros).result()
with ThreadPoolExecutor(1) as executor:
executor.submit(
cast(Callable, copy_context().run), _run_coros, coros
).result()
else:
_run_coros(coros)
def _set_context(context: Context) -> None:
for var, value in context.items():
var.set(value)
def _executor_w_context(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
return ThreadPoolExecutor(
max_workers=max_workers,
initializer=_set_context,
initargs=(copy_context(),),
)
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
if hasattr(asyncio, "Runner"):
# Python 3.11+
@@ -315,7 +306,6 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
async def _ahandle_event_for_handler(
executor: ThreadPoolExecutor,
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
@@ -332,13 +322,18 @@ async def _ahandle_event_for_handler(
event(*args, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
executor, functools.partial(event, *args, **kwargs)
None,
cast(
Callable,
functools.partial(
copy_context().run, event, *args, **kwargs
),
),
)
except NotImplementedError as e:
if event_name == "on_chat_model_start":
message_strings = [get_buffer_string(m) for m in args[1]]
await _ahandle_event_for_handler(
executor,
handler,
"on_llm_start",
"ignore_llm",
@@ -380,25 +375,23 @@ async def ahandle_event(
*args: The arguments to pass to the event handler
**kwargs: The keyword arguments to pass to the event handler
"""
with _executor_w_context() as executor:
for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler(
executor, handler, event_name, ignore_condition_name, *args, **kwargs
)
await asyncio.gather(
*(
_ahandle_event_for_handler(
executor,
handler,
event_name,
ignore_condition_name,
*args,
**kwargs,
)
for handler in handlers
if not handler.run_inline
for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler(
handler, event_name, ignore_condition_name, *args, **kwargs
)
await asyncio.gather(
*(
_ahandle_event_for_handler(
handler,
event_name,
ignore_condition_name,
*args,
**kwargs,
)
for handler in handlers
if not handler.run_inline
)
)
BRM = TypeVar("BRM", bound="BaseRunManager")
@@ -526,9 +519,17 @@ class ParentRunManager(RunManager):
return manager
class AsyncRunManager(BaseRunManager):
class AsyncRunManager(BaseRunManager, ABC):
"""Async Run Manager."""
@abstractmethod
def get_sync(self) -> RunManager:
"""Get the equivalent sync RunManager.
Returns:
RunManager: The sync RunManager.
"""
async def on_text(
self,
text: str,
@@ -664,6 +665,23 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"""Async callback manager for LLM run."""
def get_sync(self) -> CallbackManagerForLLMRun:
"""Get the equivalent sync RunManager.
Returns:
CallbackManagerForLLMRun: The sync RunManager.
"""
return CallbackManagerForLLMRun(
run_id=self.run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_llm_new_token(
self,
token: str,
@@ -818,6 +836,23 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
"""Async callback manager for chain run."""
def get_sync(self) -> CallbackManagerForChainRun:
"""Get the equivalent sync RunManager.
Returns:
CallbackManagerForChainRun: The sync RunManager.
"""
return CallbackManagerForChainRun(
run_id=self.run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None:
@@ -948,6 +983,23 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
"""Async callback manager for tool run."""
def get_sync(self) -> CallbackManagerForToolRun:
"""Get the equivalent sync RunManager.
Returns:
CallbackManagerForToolRun: The sync RunManager.
"""
return CallbackManagerForToolRun(
run_id=self.run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running.
@@ -1031,6 +1083,23 @@ class AsyncCallbackManagerForRetrieverRun(
):
"""Async callback manager for retriever run."""
def get_sync(self) -> CallbackManagerForRetrieverRun:
"""Get the equivalent sync RunManager.
Returns:
CallbackManagerForRetrieverRun: The sync RunManager.
"""
return CallbackManagerForRetrieverRun(
run_id=self.run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any
) -> None:

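Editor's note: each Async*RunManager gains get_sync() because synchronous implementations invoked from async paths (via run_in_executor) expect a sync run manager. The returned manager shares run_id, handlers, tags, and metadata with its async counterpart, so callbacks emitted from the worker thread still attach to the same run. The retriever diff below shows the calling pattern:

    return await run_in_executor(
        None,
        self._get_relevant_documents,  # sync implementation
        query,
        run_manager=run_manager.get_sync(),
    )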
@@ -1,10 +1,10 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Any, Sequence
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
from langchain_core.documents import Document
@@ -69,6 +69,6 @@ class BaseDocumentTransformer(ABC):
Returns:
A list of transformed Documents.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
return await run_in_executor(
None, self.transform_documents, documents, **kwargs
)

@@ -1,7 +1,8 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List
from langchain_core.runnables.config import run_in_executor
class Embeddings(ABC):
"""Interface for embedding models."""
@@ -16,12 +17,8 @@ class Embeddings(ABC):
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_documents, texts
)
return await run_in_executor(None, self.embed_documents, texts)
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_query, text
)
return await run_in_executor(None, self.embed_query, text)

@@ -4,7 +4,6 @@ import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
@@ -45,6 +44,7 @@ from langchain_core.outputs import (
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables.config import ensure_config, run_in_executor
if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
@@ -158,7 +158,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = config or {}
config = ensure_config(config)
return cast(
ChatGeneration,
self.generate_prompt(
@@ -180,7 +180,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = config or {}
config = ensure_config(config)
llm_result = await self.agenerate_prompt(
[self._convert_input(input)],
stop=stop,
@@ -206,7 +206,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
config = ensure_config(config)
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
@@ -264,7 +264,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
await self.ainvoke(input, config=config, stop=stop, **kwargs),
)
else:
config = config or {}
config = ensure_config(config)
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
@@ -605,8 +605,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), messages, stop, run_manager
return await run_in_executor(
None,
self._generate,
messages,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
)
def _stream(
@@ -766,7 +771,11 @@ class SimpleChatModel(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
return await run_in_executor(
None,
self._generate,
messages,
stop=stop,
run_manager=run_manager.get_sync() if run_manager else None,
**kwargs,
)
return await asyncio.get_event_loop().run_in_executor(None, func)

@@ -8,7 +8,6 @@ import json
import logging
import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import (
Any,
@@ -52,7 +51,8 @@ from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator, validator
from langchain_core.runnables import RunnableConfig, get_config_list
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from langchain_core.runnables.config import run_in_executor
logger = logging.getLogger(__name__)
@@ -221,7 +221,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
config = config or {}
config = ensure_config(config)
return (
self.generate_prompt(
[self._convert_input(input)],
@@ -244,7 +244,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
config = config or {}
config = ensure_config(config)
llm_result = await self.agenerate_prompt(
[self._convert_input(input)],
stop=stop,
@@ -362,7 +362,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
yield self.invoke(input, config=config, stop=stop, **kwargs)
else:
prompt = self._convert_input(input).to_string()
config = config or {}
config = ensure_config(config)
params = self.dict()
params["stop"] = stop
params = {**params, **kwargs}
@@ -419,7 +419,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
else:
prompt = self._convert_input(input).to_string()
config = config or {}
config = ensure_config(config)
params = self.dict()
params["stop"] = stop
params = {**params, **kwargs}
@@ -483,8 +483,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompts."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), prompts, stop, run_manager
return await run_in_executor(
None,
self._generate,
prompts,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
)
def _stream(
@@ -1049,8 +1054,13 @@ class LLM(BaseLLM):
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._call, **kwargs), prompt, stop, run_manager
return await run_in_executor(
None,
self._call,
prompt,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
)
def _generate(

@@ -1,7 +1,5 @@
from __future__ import annotations
import asyncio
import functools
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
@@ -20,6 +18,7 @@ from typing_extensions import get_args
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
from langchain_core.prompt_values import PromptValue
@@ -54,9 +53,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
return await run_in_executor(None, self.parse_result, result)
class BaseGenerationOutputParser(
@@ -247,9 +244,7 @@ class BaseOutputParser(
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(
None, functools.partial(self.parse_result, partial=partial), result
)
return await run_in_executor(None, self.parse_result, result, partial=partial)
async def aparse(self, text: str) -> T:
"""Parse a single string model output into some structure.
@@ -260,7 +255,7 @@ class BaseOutputParser(
Returns:
Structured output.
"""
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)
return await run_in_executor(None, self.parse, text)
# TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:

@@ -1,15 +1,19 @@
from __future__ import annotations
import asyncio
import warnings
from abc import ABC, abstractmethod
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@@ -113,7 +117,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
config = config or {}
config = ensure_config(config)
return self.get_relevant_documents(
input,
callbacks=config.get("callbacks"),
@@ -128,7 +132,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> List[Document]:
config = config or {}
config = ensure_config(config)
return await self.aget_relevant_documents(
input,
callbacks=config.get("callbacks"),
@@ -159,8 +163,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
Returns:
List of relevant documents
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._get_relevant_documents, run_manager=run_manager), query
return await run_in_executor(
None,
self._get_relevant_documents,
query,
run_manager=run_manager.get_sync(),
)
def get_relevant_documents(

@@ -27,8 +27,10 @@ from langchain_core.runnables.base import (
from langchain_core.runnables.branch import RunnableBranch
from langchain_core.runnables.config import (
RunnableConfig,
ensure_config,
get_config_list,
patch_config,
run_in_executor,
)
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
from langchain_core.runnables.passthrough import (
@@ -42,6 +44,7 @@ from langchain_core.runnables.utils import (
ConfigurableField,
ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption,
ConfigurableFieldSpec,
aadd,
add,
)
@@ -51,6 +54,9 @@ __all__ = [
"ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",
"ConfigurableFieldSpec",
"ensure_config",
"run_in_executor",
"patch_config",
"RouterInput",
"RouterRunnable",

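Editor's note: ensure_config, run_in_executor, and ConfigurableFieldSpec become part of the public langchain_core.runnables namespace here, mirrored by the EXPECTED_ALL test update further down:

    from langchain_core.runnables import ConfigurableFieldSpec, ensure_config, run_in_executor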
@@ -6,7 +6,7 @@ import threading
from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait
from copy import deepcopy
from functools import partial, wraps
from functools import wraps
from itertools import groupby, tee
from operator import itemgetter
from typing import (
@@ -47,6 +47,7 @@ from langchain_core.runnables.config import (
get_executor_for_config,
merge_configs,
patch_config,
run_in_executor,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
@@ -472,10 +473,7 @@ class Runnable(Generic[Input, Output], ABC):
Subclasses should override this method if they can run asynchronously.
"""
with get_executor_for_config(config) as executor:
return await asyncio.get_running_loop().run_in_executor(
executor, partial(self.invoke, **kwargs), input, config
)
return await run_in_executor(config, self.invoke, input, config, **kwargs)
def batch(
self,
@@ -665,7 +663,7 @@
)
# Assign the stream handler to the config
config = config or {}
config = ensure_config(config)
callbacks = config.get("callbacks")
if callbacks is None:
config["callbacks"] = [stream]
@@ -2883,10 +2881,7 @@ class RunnableLambda(Runnable[Input, Output]):
@wraps(self.func)
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
with get_executor_for_config(config) as executor:
return await asyncio.get_running_loop().run_in_executor(
executor, partial(self.func, **kwargs), *args
)
return await run_in_executor(config, self.func, *args, **kwargs)
afunc = f
@@ -2913,7 +2908,7 @@
def _config(
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
) -> RunnableConfig:
config = config or {}
config = ensure_config(config)
if config.get("run_name") is None:
try:
@@ -3052,9 +3047,7 @@
@wraps(self.func)
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.func, **kwargs), *args
)
return await run_in_executor(config, self.func, *args, **kwargs)
afunc = f

@@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import Context, copy_context
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
@@ -12,6 +14,8 @@ from typing import (
Generator,
List,
Optional,
ParamSpec,
TypeVar,
Union,
cast,
)
@@ -412,3 +416,36 @@ def get_executor_for_config(
initargs=(copy_context(),),
) as executor:
yield executor
P = ParamSpec("P")
T = TypeVar("T")
async def run_in_executor(
executor_or_config: Optional[Union[Executor, RunnableConfig]],
func: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
"""Run a function in an executor.
Args:
executor_or_config (Optional[Union[Executor, RunnableConfig]]): The executor or config to run the function in.
func (Callable[P, Output]): The function.
*args (Any): The positional arguments to the function.
**kwargs (Any): The keyword arguments to the function.
Returns:
Output: The output of the function.
"""
if executor_or_config is None or isinstance(executor_or_config, dict):
# Use default executor with context copied from current context
return await asyncio.get_running_loop().run_in_executor(
None,
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs)),
)
return await asyncio.get_running_loop().run_in_executor(
executor_or_config, partial(func, **kwargs), *args
)
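Editor's note: two things worth flagging in the new helper. Only the default-executor branch wraps the call in copy_context().run; when an explicit Executor is passed, context propagation is expected to come from the executor itself (get_executor_for_config, at the top of this hunk, installs the copied context via its initializer). Also, ParamSpec was only added to typing in Python 3.10, so on older interpreters the import above would need to come from typing_extensions. A usage sketch, assuming the import path in this diff:

    import asyncio
    from contextvars import ContextVar

    from langchain_core.runnables.config import run_in_executor

    tenant: ContextVar[str] = ContextVar("tenant", default="none")

    def who() -> str:  # blocking, synchronous
        return tenant.get()

    async def main() -> None:
        tenant.set("acme")
        # The default-executor branch copies the caller's context: prints "acme".
        print(await run_in_executor(None, who))

    asyncio.run(main())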

@@ -23,6 +23,7 @@ from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
ensure_config,
get_config_list,
get_executor_for_config,
)
@@ -259,7 +260,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
config = config or {}
config = ensure_config(config)
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
configurable_fields = {
specs_by_id[k][0]: v
@@ -392,7 +393,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
config = config or {}
config = ensure_config(config)
which = config.get("configurable", {}).get(self.which.id, self.default_key)
# remap configurable keys for the chosen alternative
if self.prefix_keys:

@@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import inspect
from typing import (
TYPE_CHECKING,
@@ -18,6 +17,7 @@ from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load import load
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.config import run_in_executor
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
@@ -331,9 +331,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
async def _aenter_history(
self, input: Dict[str, Any], config: RunnableConfig
) -> List[BaseMessage]:
return await asyncio.get_running_loop().run_in_executor(
None, self._enter_history, input, config
)
return await run_in_executor(config, self._enter_history, input, config)
def _exit_history(self, run: Run, config: RunnableConfig) -> None:
hist = config["configurable"]["message_history"]

@@ -31,6 +31,7 @@ from langchain_core.runnables.config import (
RunnableConfig,
acall_func_with_variable_args,
call_func_with_variable_args,
ensure_config,
get_executor_for_config,
patch_config,
)
@@ -206,7 +207,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Other:
if self.func is not None:
call_func_with_variable_args(self.func, input, config or {}, **kwargs)
call_func_with_variable_args(
self.func, input, ensure_config(config), **kwargs
)
return self._call_with_config(identity, input, config)
async def ainvoke(
@@ -217,10 +220,12 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
) -> Other:
if self.afunc is not None:
await acall_func_with_variable_args(
self.afunc, input, config or {}, **kwargs
self.afunc, input, ensure_config(config), **kwargs
)
elif self.func is not None:
call_func_with_variable_args(self.func, input, config or {}, **kwargs)
call_func_with_variable_args(
self.func, input, ensure_config(config), **kwargs
)
return await self._acall_with_config(aidentity, input, config)
def transform(
@@ -243,7 +248,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
final = final + chunk
if final is not None:
call_func_with_variable_args(self.func, final, config or {}, **kwargs)
call_func_with_variable_args(
self.func, final, ensure_config(config), **kwargs
)
async def atransform(
self,
@@ -269,7 +276,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
final = final + chunk
if final is not None:
config = config or {}
config = ensure_config(config)
if self.afunc is not None:
await acall_func_with_variable_args(
self.afunc, final, config, **kwargs
@@ -458,7 +465,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
)
# get executor to start map output stream in background
with get_executor_for_config(config or {}) as executor:
with get_executor_for_config(config) as executor:
# start map output stream
first_map_chunk_future = executor.submit(
next,

@@ -1,11 +1,9 @@
"""Base implementation for tools or skills."""
from __future__ import annotations
import asyncio
import inspect
import warnings
from abc import abstractmethod
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
@@ -26,7 +24,13 @@ from langchain_core.pydantic_v1 import (
root_validator,
validate_arguments,
)
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
class SchemaAnnotationError(TypeError):
@@ -202,7 +206,7 @@ class ChildTool(BaseTool):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
config = config or {}
config = ensure_config(config)
return self.run(
input,
callbacks=config.get("callbacks"),
@@ -218,7 +222,7 @@ class ChildTool(BaseTool):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
config = config or {}
config = ensure_config(config)
return await self.arun(
input,
callbacks=config.get("callbacks"),
@@ -280,11 +284,7 @@ class ChildTool(BaseTool):
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
to child implementations to enable tracing,
"""
return await asyncio.get_running_loop().run_in_executor(
None,
partial(self._run, **kwargs),
*args,
)
return await run_in_executor(None, self._run, *args, **kwargs)
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string,
@@ -468,9 +468,7 @@ class Tool(BaseTool):
) -> Any:
if not self.coroutine:
# If the tool does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)
return await run_in_executor(config, self.invoke, input, config, **kwargs)
return await super().ainvoke(input, config, **kwargs)
@@ -538,8 +536,12 @@ class Tool(BaseTool):
else await self.coroutine(*args, **kwargs)
)
else:
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._run, run_manager=run_manager, **kwargs), *args
return await run_in_executor(
None,
self._run,
run_manager=run_manager.get_sync() if run_manager else None,
*args,
**kwargs,
)
# TODO: this is for backwards compatibility, remove in future
@@ -599,9 +601,7 @@ class StructuredTool(BaseTool):
) -> Any:
if not self.coroutine:
# If the tool does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)
return await run_in_executor(config, self.invoke, input, config, **kwargs)
return await super().ainvoke(input, config, **kwargs)
@@ -652,10 +652,12 @@ class StructuredTool(BaseTool):
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
return await asyncio.get_running_loop().run_in_executor(
return await run_in_executor(
None,
partial(self._run, run_manager=run_manager, **kwargs),
self._run,
run_manager=run_manager.get_sync() if run_manager else None,
*args,
**kwargs,
)
@classmethod

@@ -1,11 +1,9 @@
from __future__ import annotations
import asyncio
import logging
import math
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
@@ -24,6 +22,7 @@ from typing import (
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@@ -103,9 +102,7 @@ class VectorStore(ABC):
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.add_texts, **kwargs), texts, metadatas
)
return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore.
@@ -224,8 +221,9 @@ class VectorStore(ABC):
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search_with_score, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
return await run_in_executor(
None, self.similarity_search_with_score, *args, **kwargs
)
def _similarity_search_with_relevance_scores(
self,
@@ -383,8 +381,7 @@ class VectorStore(ABC):
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search, query, k=k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
return await run_in_executor(None, self.similarity_search, query, k=k, **kwargs)
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
@@ -408,8 +405,9 @@ class VectorStore(ABC):
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
return await run_in_executor(
None, self.similarity_search_by_vector, embedding, k=k, **kwargs
)
def max_marginal_relevance_search(
self,
@@ -450,7 +448,8 @@ class VectorStore(ABC):
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(
return await run_in_executor(
None,
self.max_marginal_relevance_search,
query,
k=k,
@@ -458,7 +457,6 @@ class VectorStore(ABC):
lambda_mult=lambda_mult,
**kwargs,
)
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector(
self,
@@ -541,8 +539,8 @@ class VectorStore(ABC):
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from texts and embeddings."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
return await run_in_executor(
None, cls.from_texts, texts, embedding, metadatas, **kwargs
)
def _get_retriever_tags(self) -> List[str]:

@@ -5,6 +5,9 @@ EXPECTED_ALL = [
"ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",
"ConfigurableFieldSpec",
"ensure_config",
"run_in_executor",
"patch_config",
"RouterInput",
"RouterRunnable",

@@ -1,7 +1,6 @@
"""A tool for running python code in a REPL."""
import ast
import asyncio
import re
import sys
from contextlib import redirect_stdout
@@ -14,6 +13,7 @@ from langchain.callbacks.manager import (
)
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.tools.base import BaseTool
from langchain_core.runnables.config import run_in_executor
from langchain_experimental.utilities.python import PythonREPL
@@ -72,10 +72,7 @@ class PythonREPLTool(BaseTool):
if self.sanitize_input:
query = sanitize_input(query)
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self.run, query)
return result
return await run_in_executor(None, self.run, query)
class PythonInputs(BaseModel):
@@ -144,7 +141,4 @@
) -> Any:
"""Use the tool asynchronously."""
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self._run, query)
return result
return await run_in_executor(None, self._run, query)

@@ -30,7 +30,7 @@ from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.utils import AddableDict
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
@@ -1437,7 +1437,7 @@ class AgentExecutor(Chain):
**kwargs: Any,
) -> Iterator[AddableDict]:
"""Enables streaming over steps taken to reach final output."""
config = config or {}
config = ensure_config(config)
iterator = AgentExecutorIterator(
self,
input,
@@ -1458,7 +1458,7 @@ class AgentExecutor(Chain):
**kwargs: Any,
) -> AsyncIterator[AddableDict]:
"""Enables streaming over steps taken to reach final output."""
config = config or {}
config = ensure_config(config)
iterator = AgentExecutorIterator(
self,
input,

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Un
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.load import dumpd
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
from langchain_core.tools import BaseTool
from langchain.callbacks.manager import CallbackManager
@@ -222,7 +222,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
"""
config = config or {}
config = ensure_config(config)
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),

@@ -1,4 +1,3 @@
import asyncio
import json
from json import JSONDecodeError
from typing import List, Union
@@ -85,12 +84,5 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
message = result[0].message
return self._parse_ai_message(message)
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
raise ValueError("Can only parse messages")

@@ -1,4 +1,3 @@
import asyncio
import json
from json import JSONDecodeError
from typing import List, Union
@@ -92,12 +91,5 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
message = result[0].message
return parse_ai_message_to_openai_tool_action(message)
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")

@@ -1,5 +1,4 @@
"""Base interface that all chains should implement."""
import asyncio
import inspect
import json
import logging
@@ -19,7 +18,12 @@ from langchain_core.pydantic_v1 import (
root_validator,
validator,
)
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
ensure_config,
run_in_executor,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
@@ -85,7 +89,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = config or {}
config = ensure_config(config)
return self(
input,
callbacks=config.get("callbacks"),
@@ -101,7 +105,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = config or {}
config = ensure_config(config)
return await self.acall(
input,
callbacks=config.get("callbacks"),
@@ -245,9 +249,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self._call, inputs, run_manager
)
return await run_in_executor(None, self._call, inputs, run_manager)
def __call__(
self,

@@ -1,16 +1,15 @@
"""Interfaces to be implemented by general evaluators."""
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
from typing import Any, Optional, Sequence, Tuple, Union
from warnings import warn
from langchain_core.agents import AgentAction
from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables.config import run_in_executor
from langchain.chains.base import Chain
@@ -189,15 +188,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
""" # noqa: E501
return await asyncio.get_running_loop().run_in_executor(
return await run_in_executor(
None,
partial(
self._evaluate_strings,
prediction=prediction,
reference=reference,
input=input,
**kwargs,
),
self._evaluate_strings,
prediction=prediction,
reference=reference,
input=input,
**kwargs,
)
def evaluate_strings(
@@ -292,16 +289,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501
return await asyncio.get_running_loop().run_in_executor(
return await run_in_executor(
None,
partial(
self._evaluate_string_pairs,
prediction=prediction,
prediction_b=prediction_b,
reference=reference,
input=input,
**kwargs,
),
self._evaluate_string_pairs,
prediction=prediction,
prediction_b=prediction_b,
reference=reference,
input=input,
**kwargs,
)
def evaluate_string_pairs(
@@ -415,16 +410,14 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation result.
"""
return await asyncio.get_running_loop().run_in_executor(
return await run_in_executor(
None,
partial(
self._evaluate_agent_trajectory,
prediction=prediction,
agent_trajectory=agent_trajectory,
reference=reference,
input=input,
**kwargs,
),
self._evaluate_agent_trajectory,
prediction=prediction,
agent_trajectory=agent_trajectory,
reference=reference,
input=input,
**kwargs,
)
def evaluate_agent_trajectory(

@@ -1,10 +1,10 @@
import asyncio
from abc import ABC, abstractmethod
from inspect import signature
from typing import List, Optional, Sequence, Union
from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.config import run_in_executor
from langchain.callbacks.manager import Callbacks
@@ -28,7 +28,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Compress retrieved documents given the query context."""
return await asyncio.get_running_loop().run_in_executor(
return await run_in_executor(
None, self.compress_documents, documents, query, callbacks
)

@@ -21,7 +21,6 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive
from __future__ import annotations
import asyncio
import copy
import logging
import pathlib
@@ -29,7 +28,6 @@ from abc import ABC, abstractmethod
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import partial
from io import BytesIO, StringIO
from typing import (
AbstractSet,
@@ -283,14 +281,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
"""Transform sequence of documents by splitting them."""
return self.split_documents(list(documents))
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a sequence of documents by splitting them."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)
class CharacterTextSplitter(TextSplitter):
"""Splitting text that looks at characters."""

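Editor's note: TextSplitter.atransform_documents can be deleted because BaseDocumentTransformer (see the langchain_core diff earlier in this commit) now supplies identical behavior:

    return await run_in_executor(None, self.transform_documents, documents, **kwargs)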