Propagate context vars in all classes/methods (#15329)

- Any direct usage of ThreadPoolExecutor or the asyncio event loop's
run_in_executor needs manual handling of context vars, because neither
copies the caller's context into the worker thread

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change,
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
  - **Twitter handle:** we announce bigger features on Twitter. If your PR
    gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/15336/head
Nuno Campos 6 months ago committed by GitHub
commit 99000c612e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,12 +1,9 @@
"""ChatModel wrapper which returns user input as the response..""" """ChatModel wrapper which returns user input as the response.."""
import asyncio
from functools import partial
from io import StringIO from io import StringIO
from typing import Any, Callable, Dict, List, Mapping, Optional from typing import Any, Callable, Dict, List, Mapping, Optional
import yaml import yaml
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
@ -111,15 +108,3 @@ class HumanInputChatModel(BaseChatModel):
self.message_func(messages, **self.message_kwargs) self.message_func(messages, **self.message_kwargs)
user_input = self.input_func(messages, stop=stop, **self.input_kwargs) user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
return ChatResult(generations=[ChatGeneration(message=user_input)]) return ChatResult(generations=[ChatGeneration(message=user_input)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)

@ -1,11 +1,8 @@
import asyncio
import logging import logging
from functools import partial
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
@ -125,18 +122,6 @@ class ChatMlflow(BaseChatModel):
resp = self._client.predict(endpoint=self.endpoint, inputs=data) resp = self._client.predict(endpoint=self.endpoint, inputs=data)
return ChatMlflow._create_chat_result(resp) return ChatMlflow._create_chat_result(resp)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> Dict[str, Any]:
return self._default_params return self._default_params

@ -1,11 +1,8 @@
import asyncio
import logging import logging
import warnings import warnings
from functools import partial
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
@ -116,18 +113,6 @@ class ChatMLflowAIGateway(BaseChatModel):
resp = mlflow.gateway.query(self.route, data=data) resp = mlflow.gateway.query(self.route, data=data)
return ChatMLflowAIGateway._create_chat_result(resp) return ChatMLflowAIGateway._create_chat_result(resp)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> Dict[str, Any]:
return self._default_params return self._default_params

@ -1,7 +1,5 @@
import asyncio
import json import json
import logging import logging
from functools import partial
from typing import Any, AsyncIterator, Dict, List, Optional, cast from typing import Any, AsyncIterator, Dict, List, Optional, cast
import requests import requests
@ -300,25 +298,3 @@ class PaiEasChatEndpoint(BaseChatModel):
# break if stop sequence found # break if stop sequence found
if stop_seq_found: if stop_seq_found:
break break
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
if stream if stream is not None else self.streaming:
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
generation = chunk
assert generation is not None
return ChatResult(generations=[generation])
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)

@ -1,11 +1,11 @@
import asyncio import asyncio
import json import json
import os import os
from functools import partial
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.runnables.config import run_in_executor
class BedrockEmbeddings(BaseModel, Embeddings): class BedrockEmbeddings(BaseModel, Embeddings):
@ -181,9 +181,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
Embeddings for the text. Embeddings for the text.
""" """
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(None, self.embed_query, text)
None, partial(self.embed_query, text)
)
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous compute doc embeddings using a Bedrock model. """Asynchronous compute doc embeddings using a Bedrock model.

@ -1,12 +1,12 @@
import asyncio import asyncio
import logging import logging
import threading import threading
from functools import partial
from typing import Dict, List, Optional from typing import Dict, List, Optional
import requests import requests
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import run_in_executor
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -134,9 +134,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
List[float]: Embeddings for the text. List[float]: Embeddings for the text.
""" """
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(None, self.embed_query, text)
None, partial(self.embed_query, text)
)
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs. """Asynchronous Embed search docs.

@ -1,8 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Optional, Type from typing import TYPE_CHECKING, Optional, Type
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
@ -57,11 +55,3 @@ Note: SessionId must be received from previous Browser window creation."""
print(f"{e}, retrying...") print(f"{e}, retrying...")
except Exception as e: except Exception as e:
raise Exception(f"An error occurred: {e}") raise Exception(f"An error occurred: {e}")
async def _arun(
self,
sessionId: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> None:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._run, sessionId)

@ -1,8 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Optional, Type from typing import TYPE_CHECKING, Optional, Type
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
@ -67,14 +65,3 @@ class MultionCreateSession(BaseTool):
} }
except Exception as e: except Exception as e:
raise Exception(f"An error occurred: {e}") raise Exception(f"An error occurred: {e}")
async def _arun(
self,
query: str,
url: Optional[str] = "https://www.google.com/",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> dict:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self._run, query, url)
return result

@ -1,8 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Optional, Type from typing import TYPE_CHECKING, Optional, Type
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
@ -74,15 +72,3 @@ Note: sessionId must be received from previous Browser window creation."""
return {"error": f"{e}", "Response": "retrying..."} return {"error": f"{e}", "Response": "retrying..."}
except Exception as e: except Exception as e:
raise Exception(f"An error occurred: {e}") raise Exception(f"An error occurred: {e}")
async def _arun(
self,
sessionId: str,
query: str,
url: Optional[str] = "https://www.google.com/",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> dict:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self._run, sessionId, query, url)
return result

@ -1,10 +1,8 @@
import asyncio
import platform import platform
import warnings import warnings
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@ -77,13 +75,3 @@ class ShellTool(BaseTool):
) -> str: ) -> str:
"""Run commands and return final output.""" """Run commands and return final output."""
return self.process.run(commands) return self.process.run(commands)
async def _arun(
self,
commands: Union[str, List[str]],
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Run commands asynchronously and return final output."""
return await asyncio.get_event_loop().run_in_executor(
None, self.process.run, commands
)

@ -1,13 +1,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import operator import operator
import os import os
import pickle import pickle
import uuid import uuid
import warnings import warnings
from functools import partial
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
@ -24,6 +22,7 @@ from typing import (
import numpy as np import numpy as np
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from langchain_community.docstore.base import AddableMixin, Docstore from langchain_community.docstore.base import AddableMixin, Docstore
@ -359,7 +358,8 @@ class FAISS(VectorStore):
""" """
# This is a temporary workaround to make the similarity search asynchronous. # This is a temporary workaround to make the similarity search asynchronous.
func = partial( return await run_in_executor(
None,
self.similarity_search_with_score_by_vector, self.similarity_search_with_score_by_vector,
embedding, embedding,
k=k, k=k,
@ -367,7 +367,6 @@ class FAISS(VectorStore):
fetch_k=fetch_k, fetch_k=fetch_k,
**kwargs, **kwargs,
) )
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_with_score( def similarity_search_with_score(
self, self,
@ -640,7 +639,8 @@ class FAISS(VectorStore):
relevance and score for each. relevance and score for each.
""" """
# This is a temporary workaround to make the similarity search asynchronous. # This is a temporary workaround to make the similarity search asynchronous.
func = partial( return await run_in_executor(
None,
self.max_marginal_relevance_search_with_score_by_vector, self.max_marginal_relevance_search_with_score_by_vector,
embedding, embedding,
k=k, k=k,
@ -648,7 +648,6 @@ class FAISS(VectorStore):
lambda_mult=lambda_mult, lambda_mult=lambda_mult,
filter=filter, filter=filter,
) )
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
self, self,

@ -1,11 +1,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import contextlib import contextlib
import enum import enum
import logging import logging
import uuid import uuid
from functools import partial
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -31,6 +29,7 @@ except ImportError:
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
@ -941,7 +940,8 @@ class PGVector(VectorStore):
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
func = partial( return await run_in_executor(
None,
self.max_marginal_relevance_search_by_vector, self.max_marginal_relevance_search_by_vector,
embedding, embedding,
k=k, k=k,
@ -950,4 +950,3 @@ class PGVector(VectorStore):
filter=filter, filter=filter,
**kwargs, **kwargs,
) )
return await asyncio.get_event_loop().run_in_executor(None, func)

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import functools import functools
import uuid import uuid
import warnings import warnings
@ -25,6 +24,7 @@ from typing import (
import numpy as np import numpy as np
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance from langchain_community.vectorstores.utils import maximal_marginal_relevance
@ -58,10 +58,9 @@ def sync_call_fallback(method: Callable) -> Callable:
# by removing the first letter from the method name. For example, # by removing the first letter from the method name. For example,
# if the async method is called ``aaad_texts``, the synchronous method # if the async method is called ``aaad_texts``, the synchronous method
# will be called ``aad_texts``. # will be called ``aad_texts``.
sync_method = functools.partial( return await run_in_executor(
getattr(self, method.__name__[1:]), *args, **kwargs None, getattr(self, method.__name__[1:]), *args, **kwargs
) )
return await asyncio.get_event_loop().run_in_executor(None, sync_method)
return wrapper return wrapper

@ -23,7 +23,7 @@ from langchain_core.runnables.base import (
RunnableSerializable, RunnableSerializable,
coerce_to_runnable, coerce_to_runnable,
) )
from langchain_core.runnables.config import RunnableConfig, patch_config from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
T = TypeVar("T") T = TypeVar("T")
@ -186,7 +186,7 @@ class ContextGet(RunnableSerializable):
] ]
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
config = config or {} config = ensure_config(config)
configurable = config.get("configurable", {}) configurable = config.get("configurable", {})
if isinstance(self.key, list): if isinstance(self.key, list):
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)} return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
@ -196,7 +196,7 @@ class ContextGet(RunnableSerializable):
async def ainvoke( async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any: ) -> Any:
config = config or {} config = ensure_config(config)
configurable = config.get("configurable", {}) configurable = config.get("configurable", {})
if isinstance(self.key, list): if isinstance(self.key, list):
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids)) values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
@ -281,7 +281,7 @@ class ContextSet(RunnableSerializable):
] ]
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
config = config or {} config = ensure_config(config)
configurable = config.get("configurable", {}) configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()): for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None: if mapper is not None:
@ -293,7 +293,7 @@ class ContextSet(RunnableSerializable):
async def ainvoke( async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any: ) -> Any:
config = config or {} config = ensure_config(config)
configurable = config.get("configurable", {}) configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()): for id_, mapper in zip(self.ids, self.keys.values()):
if mapper is not None: if mapper is not None:

@ -4,13 +4,15 @@ import asyncio
import functools import functools
import logging import logging
import uuid import uuid
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
from contextvars import Context, copy_context from contextvars import copy_context
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncGenerator, AsyncGenerator,
Callable,
Coroutine, Coroutine,
Dict, Dict,
Generator, Generator,
@ -272,25 +274,14 @@ def handle_event(
# we end up in a deadlock, as we'd have gotten here from a # we end up in a deadlock, as we'd have gotten here from a
# running coroutine, which we cannot interrupt to run this one. # running coroutine, which we cannot interrupt to run this one.
# The solution is to create a new loop in a new thread. # The solution is to create a new loop in a new thread.
with _executor_w_context(1) as executor: with ThreadPoolExecutor(1) as executor:
executor.submit(_run_coros, coros).result() executor.submit(
cast(Callable, copy_context().run), _run_coros, coros
).result()
else: else:
_run_coros(coros) _run_coros(coros)
def _set_context(context: Context) -> None:
for var, value in context.items():
var.set(value)
def _executor_w_context(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
return ThreadPoolExecutor(
max_workers=max_workers,
initializer=_set_context,
initargs=(copy_context(),),
)
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
if hasattr(asyncio, "Runner"): if hasattr(asyncio, "Runner"):
# Python 3.11+ # Python 3.11+
@ -315,7 +306,6 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
async def _ahandle_event_for_handler( async def _ahandle_event_for_handler(
executor: ThreadPoolExecutor,
handler: BaseCallbackHandler, handler: BaseCallbackHandler,
event_name: str, event_name: str,
ignore_condition_name: Optional[str], ignore_condition_name: Optional[str],
@ -332,13 +322,18 @@ async def _ahandle_event_for_handler(
event(*args, **kwargs) event(*args, **kwargs)
else: else:
await asyncio.get_event_loop().run_in_executor( await asyncio.get_event_loop().run_in_executor(
executor, functools.partial(event, *args, **kwargs) None,
cast(
Callable,
functools.partial(
copy_context().run, event, *args, **kwargs
),
),
) )
except NotImplementedError as e: except NotImplementedError as e:
if event_name == "on_chat_model_start": if event_name == "on_chat_model_start":
message_strings = [get_buffer_string(m) for m in args[1]] message_strings = [get_buffer_string(m) for m in args[1]]
await _ahandle_event_for_handler( await _ahandle_event_for_handler(
executor,
handler, handler,
"on_llm_start", "on_llm_start",
"ignore_llm", "ignore_llm",
@ -380,25 +375,23 @@ async def ahandle_event(
*args: The arguments to pass to the event handler *args: The arguments to pass to the event handler
**kwargs: The keyword arguments to pass to the event handler **kwargs: The keyword arguments to pass to the event handler
""" """
with _executor_w_context() as executor: for handler in [h for h in handlers if h.run_inline]:
for handler in [h for h in handlers if h.run_inline]: await _ahandle_event_for_handler(
await _ahandle_event_for_handler( handler, event_name, ignore_condition_name, *args, **kwargs
executor, handler, event_name, ignore_condition_name, *args, **kwargs )
) await asyncio.gather(
await asyncio.gather( *(
*( _ahandle_event_for_handler(
_ahandle_event_for_handler( handler,
executor, event_name,
handler, ignore_condition_name,
event_name, *args,
ignore_condition_name, **kwargs,
*args,
**kwargs,
)
for handler in handlers
if not handler.run_inline
) )
for handler in handlers
if not handler.run_inline
) )
)
BRM = TypeVar("BRM", bound="BaseRunManager") BRM = TypeVar("BRM", bound="BaseRunManager")
@ -526,9 +519,17 @@ class ParentRunManager(RunManager):
return manager return manager
class AsyncRunManager(BaseRunManager): class AsyncRunManager(BaseRunManager, ABC):
"""Async Run Manager.""" """Async Run Manager."""
@abstractmethod
def get_sync(self) -> RunManager:
"""Get the equivalent sync RunManager.
Returns:
RunManager: The sync RunManager.
"""
async def on_text( async def on_text(
self, self,
text: str, text: str,
@ -664,6 +665,23 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"""Async callback manager for LLM run.""" """Async callback manager for LLM run."""
def get_sync(self) -> CallbackManagerForLLMRun:
"""Get the equivalent sync RunManager.
Returns:
CallbackManagerForLLMRun: The sync RunManager.
"""
return CallbackManagerForLLMRun(
run_id=self.run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_llm_new_token( async def on_llm_new_token(
self, self,
token: str, token: str,
@ -818,6 +836,23 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
"""Async callback manager for chain run.""" """Async callback manager for chain run."""
def get_sync(self) -> CallbackManagerForChainRun:
"""Get the equivalent sync RunManager.
Returns:
CallbackManagerForChainRun: The sync RunManager.
"""
return CallbackManagerForChainRun(
run_id=self.run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_chain_end( async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None: ) -> None:
@ -948,6 +983,23 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
"""Async callback manager for tool run.""" """Async callback manager for tool run."""
def get_sync(self) -> CallbackManagerForToolRun:
"""Get the equivalent sync RunManager.
Returns:
CallbackManagerForToolRun: The sync RunManager.
"""
return CallbackManagerForToolRun(
run_id=self.run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_tool_end(self, output: str, **kwargs: Any) -> None: async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running. """Run when tool ends running.
@ -1031,6 +1083,23 @@ class AsyncCallbackManagerForRetrieverRun(
): ):
"""Async callback manager for retriever run.""" """Async callback manager for retriever run."""
def get_sync(self) -> CallbackManagerForRetrieverRun:
"""Get the equivalent sync RunManager.
Returns:
CallbackManagerForRetrieverRun: The sync RunManager.
"""
return CallbackManagerForRetrieverRun(
run_id=self.run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_retriever_end( async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any self, documents: Sequence[Document], **kwargs: Any
) -> None: ) -> None:

@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Any, Sequence from typing import TYPE_CHECKING, Any, Sequence
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.documents import Document from langchain_core.documents import Document
@ -69,6 +69,6 @@ class BaseDocumentTransformer(ABC):
Returns: Returns:
A list of transformed Documents. A list of transformed Documents.
""" """
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, partial(self.transform_documents, **kwargs), documents None, self.transform_documents, documents, **kwargs
) )

@ -1,7 +1,8 @@
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List from typing import List
from langchain_core.runnables.config import run_in_executor
class Embeddings(ABC): class Embeddings(ABC):
"""Interface for embedding models.""" """Interface for embedding models."""
@ -16,12 +17,8 @@ class Embeddings(ABC):
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.""" """Asynchronous Embed search docs."""
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(None, self.embed_documents, texts)
None, self.embed_documents, texts
)
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text.""" """Asynchronous Embed query text."""
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(None, self.embed_query, text)
None, self.embed_query, text
)

@ -4,7 +4,6 @@ import asyncio
import inspect import inspect
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -45,6 +44,7 @@ from langchain_core.outputs import (
) )
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables.config import ensure_config, run_in_executor
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
@ -158,7 +158,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> BaseMessage: ) -> BaseMessage:
config = config or {} config = ensure_config(config)
return cast( return cast(
ChatGeneration, ChatGeneration,
self.generate_prompt( self.generate_prompt(
@ -180,7 +180,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> BaseMessage: ) -> BaseMessage:
config = config or {} config = ensure_config(config)
llm_result = await self.agenerate_prompt( llm_result = await self.agenerate_prompt(
[self._convert_input(input)], [self._convert_input(input)],
stop=stop, stop=stop,
@ -206,7 +206,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs) BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
) )
else: else:
config = config or {} config = ensure_config(config)
messages = self._convert_input(input).to_messages() messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs} options = {"stop": stop, **kwargs}
@ -264,7 +264,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
await self.ainvoke(input, config=config, stop=stop, **kwargs), await self.ainvoke(input, config=config, stop=stop, **kwargs),
) )
else: else:
config = config or {} config = ensure_config(config)
messages = self._convert_input(input).to_messages() messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs} options = {"stop": stop, **kwargs}
@ -605,8 +605,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call"""
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, partial(self._generate, **kwargs), messages, stop, run_manager None,
self._generate,
messages,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
) )
def _stream( def _stream(
@ -766,7 +771,11 @@ class SimpleChatModel(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
func = partial( return await run_in_executor(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs None,
self._generate,
messages,
stop=stop,
run_manager=run_manager.get_sync() if run_manager else None,
**kwargs,
) )
return await asyncio.get_event_loop().run_in_executor(None, func)

@ -8,7 +8,6 @@ import json
import logging import logging
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
@ -52,7 +51,8 @@ from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator, validator from langchain_core.pydantic_v1 import Field, root_validator, validator
from langchain_core.runnables import RunnableConfig, get_config_list from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from langchain_core.runnables.config import run_in_executor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -221,7 +221,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
config = config or {} config = ensure_config(config)
return ( return (
self.generate_prompt( self.generate_prompt(
[self._convert_input(input)], [self._convert_input(input)],
@ -244,7 +244,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
config = config or {} config = ensure_config(config)
llm_result = await self.agenerate_prompt( llm_result = await self.agenerate_prompt(
[self._convert_input(input)], [self._convert_input(input)],
stop=stop, stop=stop,
@ -362,7 +362,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
yield self.invoke(input, config=config, stop=stop, **kwargs) yield self.invoke(input, config=config, stop=stop, **kwargs)
else: else:
prompt = self._convert_input(input).to_string() prompt = self._convert_input(input).to_string()
config = config or {} config = ensure_config(config)
params = self.dict() params = self.dict()
params["stop"] = stop params["stop"] = stop
params = {**params, **kwargs} params = {**params, **kwargs}
@ -419,7 +419,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
yield await self.ainvoke(input, config=config, stop=stop, **kwargs) yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
else: else:
prompt = self._convert_input(input).to_string() prompt = self._convert_input(input).to_string()
config = config or {} config = ensure_config(config)
params = self.dict() params = self.dict()
params["stop"] = stop params["stop"] = stop
params = {**params, **kwargs} params = {**params, **kwargs}
@ -483,8 +483,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompts.""" """Run the LLM on the given prompts."""
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, partial(self._generate, **kwargs), prompts, stop, run_manager None,
self._generate,
prompts,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
) )
def _stream( def _stream(
@ -1049,8 +1054,13 @@ class LLM(BaseLLM):
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, partial(self._call, **kwargs), prompt, stop, run_manager None,
self._call,
prompt,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
) )
def _generate( def _generate(

@ -1,7 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import functools
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -20,6 +18,7 @@ from typing_extensions import get_args
from langchain_core.messages import AnyMessage, BaseMessage from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.runnables import RunnableConfig, RunnableSerializable from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import PromptValue
@ -54,9 +53,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
Returns: Returns:
Structured output. Structured output.
""" """
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(None, self.parse_result, result)
None, self.parse_result, result
)
class BaseGenerationOutputParser( class BaseGenerationOutputParser(
@ -247,9 +244,7 @@ class BaseOutputParser(
Returns: Returns:
Structured output. Structured output.
""" """
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(None, self.parse_result, result, partial=partial)
None, functools.partial(self.parse_result, partial=partial), result
)
async def aparse(self, text: str) -> T: async def aparse(self, text: str) -> T:
"""Parse a single string model output into some structure. """Parse a single string model output into some structure.
@ -260,7 +255,7 @@ class BaseOutputParser(
Returns: Returns:
Structured output. Structured output.
""" """
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text) return await run_in_executor(None, self.parse, text)
# TODO: rename 'completion' -> 'text'. # TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any: def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:

@ -1,15 +1,19 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from inspect import signature from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.load.dump import dumpd from langchain_core.load.dump import dumpd
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.callbacks.manager import ( from langchain_core.callbacks.manager import (
@ -113,7 +117,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
def invoke( def invoke(
self, input: str, config: Optional[RunnableConfig] = None self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]: ) -> List[Document]:
config = config or {} config = ensure_config(config)
return self.get_relevant_documents( return self.get_relevant_documents(
input, input,
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
@ -128,7 +132,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Document]: ) -> List[Document]:
config = config or {} config = ensure_config(config)
return await self.aget_relevant_documents( return await self.aget_relevant_documents(
input, input,
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
@ -159,8 +163,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
Returns: Returns:
List of relevant documents List of relevant documents
""" """
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, partial(self._get_relevant_documents, run_manager=run_manager), query None,
self._get_relevant_documents,
query,
run_manager=run_manager.get_sync(),
) )
def get_relevant_documents( def get_relevant_documents(

@ -27,8 +27,10 @@ from langchain_core.runnables.base import (
from langchain_core.runnables.branch import RunnableBranch from langchain_core.runnables.branch import RunnableBranch
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
RunnableConfig, RunnableConfig,
ensure_config,
get_config_list, get_config_list,
patch_config, patch_config,
run_in_executor,
) )
from langchain_core.runnables.fallbacks import RunnableWithFallbacks from langchain_core.runnables.fallbacks import RunnableWithFallbacks
from langchain_core.runnables.passthrough import ( from langchain_core.runnables.passthrough import (
@ -42,6 +44,7 @@ from langchain_core.runnables.utils import (
ConfigurableField, ConfigurableField,
ConfigurableFieldMultiOption, ConfigurableFieldMultiOption,
ConfigurableFieldSingleOption, ConfigurableFieldSingleOption,
ConfigurableFieldSpec,
aadd, aadd,
add, add,
) )
@ -51,6 +54,9 @@ __all__ = [
"ConfigurableField", "ConfigurableField",
"ConfigurableFieldSingleOption", "ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption", "ConfigurableFieldMultiOption",
"ConfigurableFieldSpec",
"ensure_config",
"run_in_executor",
"patch_config", "patch_config",
"RouterInput", "RouterInput",
"RouterRunnable", "RouterRunnable",

@ -6,7 +6,7 @@ import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait from concurrent.futures import FIRST_COMPLETED, wait
from copy import deepcopy from copy import deepcopy
from functools import partial, wraps from functools import wraps
from itertools import groupby, tee from itertools import groupby, tee
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
@ -47,6 +47,7 @@ from langchain_core.runnables.config import (
get_executor_for_config, get_executor_for_config,
merge_configs, merge_configs,
patch_config, patch_config,
run_in_executor,
) )
from langchain_core.runnables.graph import Graph from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
@ -472,10 +473,7 @@ class Runnable(Generic[Input, Output], ABC):
Subclasses should override this method if they can run asynchronously. Subclasses should override this method if they can run asynchronously.
""" """
with get_executor_for_config(config) as executor: return await run_in_executor(config, self.invoke, input, config, **kwargs)
return await asyncio.get_running_loop().run_in_executor(
executor, partial(self.invoke, **kwargs), input, config
)
def batch( def batch(
self, self,
@ -665,7 +663,7 @@ class Runnable(Generic[Input, Output], ABC):
) )
# Assign the stream handler to the config # Assign the stream handler to the config
config = config or {} config = ensure_config(config)
callbacks = config.get("callbacks") callbacks = config.get("callbacks")
if callbacks is None: if callbacks is None:
config["callbacks"] = [stream] config["callbacks"] = [stream]
@ -2883,10 +2881,7 @@ class RunnableLambda(Runnable[Input, Output]):
@wraps(self.func) @wraps(self.func)
async def f(*args, **kwargs): # type: ignore[no-untyped-def] async def f(*args, **kwargs): # type: ignore[no-untyped-def]
with get_executor_for_config(config) as executor: return await run_in_executor(config, self.func, *args, **kwargs)
return await asyncio.get_running_loop().run_in_executor(
executor, partial(self.func, **kwargs), *args
)
afunc = f afunc = f
@ -2913,7 +2908,7 @@ class RunnableLambda(Runnable[Input, Output]):
def _config( def _config(
self, config: Optional[RunnableConfig], callable: Callable[..., Any] self, config: Optional[RunnableConfig], callable: Callable[..., Any]
) -> RunnableConfig: ) -> RunnableConfig:
config = config or {} config = ensure_config(config)
if config.get("run_name") is None: if config.get("run_name") is None:
try: try:
@ -3052,9 +3047,7 @@ class RunnableLambda(Runnable[Input, Output]):
@wraps(self.func) @wraps(self.func)
async def f(*args, **kwargs): # type: ignore[no-untyped-def] async def f(*args, **kwargs): # type: ignore[no-untyped-def]
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(config, self.func, *args, **kwargs)
None, partial(self.func, **kwargs), *args
)
afunc = f afunc = f

@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor import asyncio
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import Context, copy_context from contextvars import ContextVar, copy_context
from functools import partial
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -10,13 +12,16 @@ from typing import (
Callable, Callable,
Dict, Dict,
Generator, Generator,
Iterable,
Iterator,
List, List,
Optional, Optional,
TypeVar,
Union, Union,
cast, cast,
) )
from typing_extensions import TypedDict from typing_extensions import ParamSpec, TypedDict
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
Input, Input,
@ -91,6 +96,11 @@ class RunnableConfig(TypedDict, total=False):
""" """
var_child_runnable_config = ContextVar(
"child_runnable_config", default=RunnableConfig()
)
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
"""Ensure that a config is a dict with all keys present. """Ensure that a config is a dict with all keys present.
@ -107,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
callbacks=None, callbacks=None,
recursion_limit=25, recursion_limit=25,
) )
if var_config := var_child_runnable_config.get():
empty.update(
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
)
if config is not None: if config is not None:
empty.update( empty.update(
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}) cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
@ -388,9 +402,51 @@ def get_async_callback_manager_for_config(
) )
def _set_context(context: Context) -> None: P = ParamSpec("P")
for var, value in context.items(): T = TypeVar("T")
var.set(value)
class ContextThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that copies the context to the child thread.

    Every call handed to the pool runs inside a snapshot of the submitting
    thread's ``contextvars`` context, so context variables set by the caller
    are visible to the worker. The snapshot is taken when the call is
    submitted, not when a worker picks it up.
    """

    def submit(  # type: ignore[override]
        self,
        func: Callable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Future[T]:
        """Submit a function to the executor.

        Args:
            func (Callable[..., T]): The function to submit.
            *args (Any): The positional arguments to the function.
            **kwargs (Any): The keyword arguments to the function.

        Returns:
            Future[T]: The future for the function.
        """
        # copy_context() is evaluated here, in the submitting thread; the
        # worker only executes the resulting ``Context.run`` partial.
        return super().submit(
            cast(
                "Callable[..., T]",
                partial(copy_context().run, func, *args, **kwargs),
            )
        )

    def map(
        self,
        fn: Callable[..., T],
        *iterables: Iterable[Any],
        timeout: float | None = None,
        chunksize: int = 1,
    ) -> Iterator[T]:
        """Map a function over the iterables, one context snapshot per call.

        Args:
            fn (Callable[..., T]): The function to apply.
            *iterables (Iterable[Any]): Iterables zipped into positional
                arguments for ``fn``; the shortest one bounds the calls.
            timeout (float | None): Forwarded unchanged to
                ``ThreadPoolExecutor.map``.
            chunksize (int): Forwarded unchanged to
                ``ThreadPoolExecutor.map``.

        Returns:
            Iterator[T]: The results, in input order.
        """
        # Materialise the argument tuples instead of calling len() on the
        # first iterable: this also works for generators and for a call with
        # no iterables at all. A Context cannot be entered by two threads at
        # once, so every call gets its own copy.
        calls = [(copy_context(), call_args) for call_args in zip(*iterables)]

        def _run_in_context(call: Any) -> T:
            context, call_args = call
            return context.run(fn, *call_args)

        return super().map(
            _run_in_context,
            calls,
            timeout=timeout,
            chunksize=chunksize,
        )
@contextmanager @contextmanager
@ -406,9 +462,36 @@ def get_executor_for_config(
Generator[Executor, None, None]: The executor. Generator[Executor, None, None]: The executor.
""" """
config = config or {} config = config or {}
with ThreadPoolExecutor( with ContextThreadPoolExecutor(
max_workers=config.get("max_concurrency"), max_workers=config.get("max_concurrency")
initializer=_set_context,
initargs=(copy_context(),),
) as executor: ) as executor:
yield executor yield executor
async def run_in_executor(
    executor_or_config: Optional[Union[Executor, RunnableConfig]],
    func: Callable[P, T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Run a function in an executor.

    Args:
        executor_or_config (Optional[Union[Executor, RunnableConfig]]): The
            executor to run in. ``None`` or a config dict selects the event
            loop's default executor, with the current context copied into
            the call.
        func (Callable[P, Output]): The function.
        *args (Any): The positional arguments to the function.
        **kwargs (Any): The keyword arguments to the function.

    Returns:
        Output: The output of the function.

    Raises:
        RuntimeError: If ``func`` raises ``StopIteration``; the original
            exception is attached as ``__cause__``.
    """

    def wrapper() -> T:
        try:
            return func(*args, **kwargs)
        except StopIteration as exc:
            # asyncio refuses to set StopIteration on a Future, which would
            # leave the awaiting coroutine pending forever, so surface it as
            # a RuntimeError instead.
            raise RuntimeError from exc

    loop = asyncio.get_running_loop()
    if executor_or_config is None or isinstance(executor_or_config, dict):
        # Use default executor with context copied from current context
        return await loop.run_in_executor(
            None,
            cast("Callable[..., T]", partial(copy_context().run, wrapper)),
        )

    # An explicit executor is used as given; no context is copied here.
    return await loop.run_in_executor(executor_or_config, wrapper)

@ -23,6 +23,7 @@ from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
RunnableConfig, RunnableConfig,
ensure_config,
get_config_list, get_config_list,
get_executor_for_config, get_executor_for_config,
) )
@ -259,7 +260,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]: ) -> Tuple[Runnable[Input, Output], RunnableConfig]:
config = config or {} config = ensure_config(config)
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
configurable_fields = { configurable_fields = {
specs_by_id[k][0]: v specs_by_id[k][0]: v
@ -392,7 +393,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]: ) -> Tuple[Runnable[Input, Output], RunnableConfig]:
config = config or {} config = ensure_config(config)
which = config.get("configurable", {}).get(self.which.id, self.default_key) which = config.get("configurable", {}).get(self.which.id, self.default_key)
# remap configurable keys for the chosen alternative # remap configurable keys for the chosen alternative
if self.prefix_keys: if self.prefix_keys:

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import inspect import inspect
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -18,6 +17,7 @@ from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load import load from langchain_core.load import load
from langchain_core.pydantic_v1 import BaseModel, create_model from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.config import run_in_executor
from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
ConfigurableFieldSpec, ConfigurableFieldSpec,
@ -331,9 +331,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
async def _aenter_history( async def _aenter_history(
self, input: Dict[str, Any], config: RunnableConfig self, input: Dict[str, Any], config: RunnableConfig
) -> List[BaseMessage]: ) -> List[BaseMessage]:
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(config, self._enter_history, input, config)
None, self._enter_history, input, config
)
def _exit_history(self, run: Run, config: RunnableConfig) -> None: def _exit_history(self, run: Run, config: RunnableConfig) -> None:
hist = config["configurable"]["message_history"] hist = config["configurable"]["message_history"]

@ -31,6 +31,7 @@ from langchain_core.runnables.config import (
RunnableConfig, RunnableConfig,
acall_func_with_variable_args, acall_func_with_variable_args,
call_func_with_variable_args, call_func_with_variable_args,
ensure_config,
get_executor_for_config, get_executor_for_config,
patch_config, patch_config,
) )
@ -206,7 +207,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Other: ) -> Other:
if self.func is not None: if self.func is not None:
call_func_with_variable_args(self.func, input, config or {}, **kwargs) call_func_with_variable_args(
self.func, input, ensure_config(config), **kwargs
)
return self._call_with_config(identity, input, config) return self._call_with_config(identity, input, config)
async def ainvoke( async def ainvoke(
@ -217,10 +220,12 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
) -> Other: ) -> Other:
if self.afunc is not None: if self.afunc is not None:
await acall_func_with_variable_args( await acall_func_with_variable_args(
self.afunc, input, config or {}, **kwargs self.afunc, input, ensure_config(config), **kwargs
) )
elif self.func is not None: elif self.func is not None:
call_func_with_variable_args(self.func, input, config or {}, **kwargs) call_func_with_variable_args(
self.func, input, ensure_config(config), **kwargs
)
return await self._acall_with_config(aidentity, input, config) return await self._acall_with_config(aidentity, input, config)
def transform( def transform(
@ -243,7 +248,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
final = final + chunk final = final + chunk
if final is not None: if final is not None:
call_func_with_variable_args(self.func, final, config or {}, **kwargs) call_func_with_variable_args(
self.func, final, ensure_config(config), **kwargs
)
async def atransform( async def atransform(
self, self,
@ -269,7 +276,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
final = final + chunk final = final + chunk
if final is not None: if final is not None:
config = config or {} config = ensure_config(config)
if self.afunc is not None: if self.afunc is not None:
await acall_func_with_variable_args( await acall_func_with_variable_args(
self.afunc, final, config, **kwargs self.afunc, final, config, **kwargs
@ -458,7 +465,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
) )
# get executor to start map output stream in background # get executor to start map output stream in background
with get_executor_for_config(config or {}) as executor: with get_executor_for_config(config) as executor:
# start map output stream # start map output stream
first_map_chunk_future = executor.submit( first_map_chunk_future = executor.submit(
next, next,

@ -1,11 +1,9 @@
"""Base implementation for tools or skills.""" """Base implementation for tools or skills."""
from __future__ import annotations from __future__ import annotations
import asyncio
import inspect import inspect
import warnings import warnings
from abc import abstractmethod from abc import abstractmethod
from functools import partial
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
@ -26,7 +24,13 @@ from langchain_core.pydantic_v1 import (
root_validator, root_validator,
validate_arguments, validate_arguments,
) )
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
class SchemaAnnotationError(TypeError): class SchemaAnnotationError(TypeError):
@ -202,7 +206,7 @@ class ChildTool(BaseTool):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
config = config or {} config = ensure_config(config)
return self.run( return self.run(
input, input,
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
@ -218,7 +222,7 @@ class ChildTool(BaseTool):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
config = config or {} config = ensure_config(config)
return await self.arun( return await self.arun(
input, input,
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
@ -280,11 +284,7 @@ class ChildTool(BaseTool):
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
to child implementations to enable tracing, to child implementations to enable tracing,
""" """
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(None, self._run, *args, **kwargs)
None,
partial(self._run, **kwargs),
*args,
)
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string, # For backwards compatibility, if run_input is a string,
@ -468,9 +468,7 @@ class Tool(BaseTool):
) -> Any: ) -> Any:
if not self.coroutine: if not self.coroutine:
# If the tool does not implement async, fall back to default implementation # If the tool does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(config, self.invoke, input, config, **kwargs)
None, partial(self.invoke, input, config, **kwargs)
)
return await super().ainvoke(input, config, **kwargs) return await super().ainvoke(input, config, **kwargs)
@ -538,8 +536,12 @@ class Tool(BaseTool):
else await self.coroutine(*args, **kwargs) else await self.coroutine(*args, **kwargs)
) )
else: else:
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, partial(self._run, run_manager=run_manager, **kwargs), *args None,
self._run,
run_manager=run_manager.get_sync() if run_manager else None,
*args,
**kwargs,
) )
# TODO: this is for backwards compatibility, remove in future # TODO: this is for backwards compatibility, remove in future
@ -599,9 +601,7 @@ class StructuredTool(BaseTool):
) -> Any: ) -> Any:
if not self.coroutine: if not self.coroutine:
# If the tool does not implement async, fall back to default implementation # If the tool does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(config, self.invoke, input, config, **kwargs)
None, partial(self.invoke, input, config, **kwargs)
)
return await super().ainvoke(input, config, **kwargs) return await super().ainvoke(input, config, **kwargs)
@ -652,10 +652,12 @@ class StructuredTool(BaseTool):
if new_argument_supported if new_argument_supported
else await self.coroutine(*args, **kwargs) else await self.coroutine(*args, **kwargs)
) )
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, None,
partial(self._run, run_manager=run_manager, **kwargs), self._run,
run_manager=run_manager.get_sync() if run_manager else None,
*args, *args,
**kwargs,
) )
@classmethod @classmethod

@ -1,11 +1,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import math import math
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -24,6 +22,7 @@ from typing import (
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.callbacks.manager import ( from langchain_core.callbacks.manager import (
@ -103,9 +102,7 @@ class VectorStore(ABC):
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.""" """Run more texts through the embeddings and add to the vectorstore."""
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)
None, partial(self.add_texts, **kwargs), texts, metadatas
)
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore. """Run more documents through the embeddings and add to the vectorstore.
@ -224,8 +221,9 @@ class VectorStore(ABC):
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
func = partial(self.similarity_search_with_score, *args, **kwargs) return await run_in_executor(
return await asyncio.get_event_loop().run_in_executor(None, func) None, self.similarity_search_with_score, *args, **kwargs
)
def _similarity_search_with_relevance_scores( def _similarity_search_with_relevance_scores(
self, self,
@ -383,8 +381,7 @@ class VectorStore(ABC):
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
func = partial(self.similarity_search, query, k=k, **kwargs) return await run_in_executor(None, self.similarity_search, query, k=k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_by_vector( def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any self, embedding: List[float], k: int = 4, **kwargs: Any
@ -408,8 +405,9 @@ class VectorStore(ABC):
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs) return await run_in_executor(
return await asyncio.get_event_loop().run_in_executor(None, func) None, self.similarity_search_by_vector, embedding, k=k, **kwargs
)
def max_marginal_relevance_search( def max_marginal_relevance_search(
self, self,
@ -450,7 +448,8 @@ class VectorStore(ABC):
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
func = partial( return await run_in_executor(
None,
self.max_marginal_relevance_search, self.max_marginal_relevance_search,
query, query,
k=k, k=k,
@ -458,7 +457,6 @@ class VectorStore(ABC):
lambda_mult=lambda_mult, lambda_mult=lambda_mult,
**kwargs, **kwargs,
) )
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
self, self,
@ -541,8 +539,8 @@ class VectorStore(ABC):
**kwargs: Any, **kwargs: Any,
) -> VST: ) -> VST:
"""Return VectorStore initialized from texts and embeddings.""" """Return VectorStore initialized from texts and embeddings."""
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas None, cls.from_texts, texts, embedding, metadatas, **kwargs
) )
def _get_retriever_tags(self) -> List[str]: def _get_retriever_tags(self) -> List[str]:

@ -5,6 +5,9 @@ EXPECTED_ALL = [
"ConfigurableField", "ConfigurableField",
"ConfigurableFieldSingleOption", "ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption", "ConfigurableFieldMultiOption",
"ConfigurableFieldSpec",
"ensure_config",
"run_in_executor",
"patch_config", "patch_config",
"RouterInput", "RouterInput",
"RouterRunnable", "RouterRunnable",

@ -1,7 +1,6 @@
"""A tool for running python code in a REPL.""" """A tool for running python code in a REPL."""
import ast import ast
import asyncio
import re import re
import sys import sys
from contextlib import redirect_stdout from contextlib import redirect_stdout
@ -14,6 +13,7 @@ from langchain.callbacks.manager import (
) )
from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain_core.runnables.config import run_in_executor
from langchain_experimental.utilities.python import PythonREPL from langchain_experimental.utilities.python import PythonREPL
@ -72,10 +72,7 @@ class PythonREPLTool(BaseTool):
if self.sanitize_input: if self.sanitize_input:
query = sanitize_input(query) query = sanitize_input(query)
loop = asyncio.get_running_loop() return await run_in_executor(None, self.run, query)
result = await loop.run_in_executor(None, self.run, query)
return result
class PythonInputs(BaseModel): class PythonInputs(BaseModel):
@ -144,7 +141,4 @@ class PythonAstREPLTool(BaseTool):
) -> Any: ) -> Any:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
loop = asyncio.get_running_loop() return await run_in_executor(None, self._run, query)
result = await loop.run_in_executor(None, self._run, query)
return result

@ -30,7 +30,7 @@ from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.utils import AddableDict from langchain_core.runnables.utils import AddableDict
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping from langchain_core.utils.input import get_color_mapping
@ -1437,7 +1437,7 @@ class AgentExecutor(Chain):
**kwargs: Any, **kwargs: Any,
) -> Iterator[AddableDict]: ) -> Iterator[AddableDict]:
"""Enables streaming over steps taken to reach final output.""" """Enables streaming over steps taken to reach final output."""
config = config or {} config = ensure_config(config)
iterator = AgentExecutorIterator( iterator = AgentExecutorIterator(
self, self,
input, input,
@ -1458,7 +1458,7 @@ class AgentExecutor(Chain):
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[AddableDict]: ) -> AsyncIterator[AddableDict]:
"""Enables streaming over steps taken to reach final output.""" """Enables streaming over steps taken to reach final output."""
config = config or {} config = ensure_config(config)
iterator = AgentExecutorIterator( iterator = AgentExecutorIterator(
self, self,
input, input,

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Un
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.load import dumpd from langchain_core.load import dumpd
from langchain_core.pydantic_v1 import Field from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import RunnableConfig, RunnableSerializable from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
@ -222,7 +222,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]]. Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
""" """
config = config or {} config = ensure_config(config)
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"), inheritable_tags=config.get("tags"),

@ -1,4 +1,3 @@
import asyncio
import json import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import List, Union from typing import List, Union
@ -85,12 +84,5 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
message = result[0].message message = result[0].message
return self._parse_ai_message(message) return self._parse_ai_message(message)
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
def parse(self, text: str) -> Union[AgentAction, AgentFinish]: def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
raise ValueError("Can only parse messages") raise ValueError("Can only parse messages")

@ -1,4 +1,3 @@
import asyncio
import json import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import List, Union from typing import List, Union
@ -92,12 +91,5 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
message = result[0].message message = result[0].message
return parse_ai_message_to_openai_tool_action(message) return parse_ai_message_to_openai_tool_action(message)
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages") raise ValueError("Can only parse messages")

@ -1,5 +1,4 @@
"""Base interface that all chains should implement.""" """Base interface that all chains should implement."""
import asyncio
import inspect import inspect
import json import json
import logging import logging
@ -19,7 +18,12 @@ from langchain_core.pydantic_v1 import (
root_validator, root_validator,
validator, validator,
) )
from langchain_core.runnables import RunnableConfig, RunnableSerializable from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
ensure_config,
run_in_executor,
)
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
@ -85,7 +89,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
config = config or {} config = ensure_config(config)
return self( return self(
input, input,
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
@ -101,7 +105,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
config = config or {} config = ensure_config(config)
return await self.acall( return await self.acall(
input, input,
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
@ -245,8 +249,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
A dict of named outputs. Should contain all outputs specified in A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`. `Chain.output_keys`.
""" """
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, self._call, inputs, run_manager None, self._call, inputs, run_manager.get_sync() if run_manager else None
) )
def __call__( def __call__(

@ -1,16 +1,15 @@
"""Interfaces to be implemented by general evaluators.""" """Interfaces to be implemented by general evaluators."""
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from functools import partial
from typing import Any, Optional, Sequence, Tuple, Union from typing import Any, Optional, Sequence, Tuple, Union
from warnings import warn from warnings import warn
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables.config import run_in_executor
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -189,15 +188,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
- value: the string value of the evaluation, if applicable. - value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable. - reasoning: the reasoning for the evaluation, if applicable.
""" # noqa: E501 """ # noqa: E501
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, None,
partial( self._evaluate_strings,
self._evaluate_strings, prediction=prediction,
prediction=prediction, reference=reference,
reference=reference, input=input,
input=input, **kwargs,
**kwargs,
),
) )
def evaluate_strings( def evaluate_strings(
@ -292,16 +289,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
Returns: Returns:
dict: A dictionary containing the preference, scores, and/or other information. dict: A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501 """ # noqa: E501
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, None,
partial( self._evaluate_string_pairs,
self._evaluate_string_pairs, prediction=prediction,
prediction=prediction, prediction_b=prediction_b,
prediction_b=prediction_b, reference=reference,
reference=reference, input=input,
input=input, **kwargs,
**kwargs,
),
) )
def evaluate_string_pairs( def evaluate_string_pairs(
@ -415,16 +410,14 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
Returns: Returns:
dict: The evaluation result. dict: The evaluation result.
""" """
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, None,
partial( self._evaluate_agent_trajectory,
self._evaluate_agent_trajectory, prediction=prediction,
prediction=prediction, agent_trajectory=agent_trajectory,
agent_trajectory=agent_trajectory, reference=reference,
reference=reference, input=input,
input=input, **kwargs,
**kwargs,
),
) )
def evaluate_agent_trajectory( def evaluate_agent_trajectory(

@ -1,10 +1,10 @@
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import List, Optional, Sequence, Union from typing import List, Optional, Sequence, Union
from langchain_core.documents import BaseDocumentTransformer, Document from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.config import run_in_executor
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
@ -28,7 +28,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
callbacks: Optional[Callbacks] = None, callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]: ) -> Sequence[Document]:
"""Compress retrieved documents given the query context.""" """Compress retrieved documents given the query context."""
return await asyncio.get_running_loop().run_in_executor( return await run_in_executor(
None, self.compress_documents, documents, query, callbacks None, self.compress_documents, documents, query, callbacks
) )

@ -21,7 +21,6 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive
from __future__ import annotations from __future__ import annotations
import asyncio
import copy import copy
import logging import logging
import pathlib import pathlib
@ -29,7 +28,6 @@ import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import partial
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import ( from typing import (
AbstractSet, AbstractSet,
@ -283,14 +281,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
"""Transform sequence of documents by splitting them.""" """Transform sequence of documents by splitting them."""
return self.split_documents(list(documents)) return self.split_documents(list(documents))
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a sequence of documents by splitting them."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)
class CharacterTextSplitter(TextSplitter): class CharacterTextSplitter(TextSplitter):
"""Splitting text that looks at characters.""" """Splitting text that looks at characters."""

Loading…
Cancel
Save