From 106608bc89c4b03c9356efd4b688b91160793949 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:40:35 -0400 Subject: [PATCH] add default async (#11141) --- libs/langchain/langchain/chains/base.py | 11 +++----- libs/langchain/langchain/chat_models/base.py | 5 +--- libs/langchain/langchain/llms/base.py | 27 +++++-------------- libs/langchain/langchain/schema/document.py | 6 ++++- libs/langchain/langchain/schema/embeddings.py | 9 +++++-- libs/langchain/langchain/schema/retriever.py | 10 +++---- .../langchain/langchain/schema/vectorstore.py | 8 ++++-- libs/langchain/langchain/text_splitter.py | 6 ++++- libs/langchain/langchain/tools/base.py | 4 --- 9 files changed, 38 insertions(+), 48 deletions(-) diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index fd54cfe6bc..bea11d135c 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -5,7 +5,6 @@ import json import logging import warnings from abc import ABC, abstractmethod -from functools import partial from pathlib import Path from typing import Any, Dict, List, Optional, Type, Union @@ -97,12 +96,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Dict[str, Any]: - if type(self)._acall == Chain._acall: - # If the chain does not implement async, fall back to default implementation - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, **kwargs) - ) - config = config or {} return await self.acall( input, @@ -246,7 +239,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): A dict of named outputs. Should contain all outputs specified in `Chain.output_keys`. """ - raise NotImplementedError("Async call not supported for this chain type.") + return await asyncio.get_running_loop().run_in_executor( + None, self._call, inputs, run_manager + ) def __call__( self, diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index 4a16393525..afb6926742 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -577,10 +577,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): ) -> ChatResult: """Top Level call""" return await asyncio.get_running_loop().run_in_executor( - None, - partial( - self._generate, messages, stop=stop, run_manager=run_manager, **kwargs - ), + None, partial(self._generate, **kwargs), messages, stop, run_manager ) def _stream( diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index e7d165d439..8a459434cf 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -248,12 +248,6 @@ class BaseLLM(BaseLanguageModel[str], ABC): stop: Optional[List[str]] = None, **kwargs: Any, ) -> str: - if type(self)._agenerate == BaseLLM._agenerate: - # model doesn't implement async invoke, so use default implementation - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, stop=stop, **kwargs) - ) - config = config or {} llm_result = await self.agenerate_prompt( [self._convert_input(input)], @@ -319,13 +313,6 @@ class BaseLLM(BaseLanguageModel[str], ABC): ) -> List[str]: if not inputs: return [] - - if type(self)._agenerate == BaseLLM._agenerate: - # model doesn't implement async batch, so use default implementation - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.batch, **kwargs), inputs, config - ) - config = get_config_list(config, len(inputs)) max_concurrency = config[0].get("max_concurrency") @@ -478,7 +465,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompts.""" - raise NotImplementedError() + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._generate, **kwargs), prompts, stop, run_manager + ) def _stream( self, @@ -1035,7 +1024,9 @@ class LLM(BaseLLM): **kwargs: Any, ) -> str: """Run the LLM on the given prompt and input.""" - raise NotImplementedError() + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._call, **kwargs), prompt, stop, run_manager + ) def _generate( self, @@ -1064,12 +1055,6 @@ class LLM(BaseLLM): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: - if type(self)._acall == LLM._acall: - # model doesn't implement async call, so use default implementation - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._generate, prompts, stop, run_manager, **kwargs) - ) - """Run the LLM on the given prompt and input.""" generations = [] new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") diff --git a/libs/langchain/langchain/schema/document.py b/libs/langchain/langchain/schema/document.py index d0a4f666f5..acb46dd1d2 100644 --- a/libs/langchain/langchain/schema/document.py +++ b/libs/langchain/langchain/schema/document.py @@ -1,6 +1,8 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod +from functools import partial from typing import Any, Sequence from langchain.load.serializable import Serializable @@ -72,7 +74,6 @@ class BaseDocumentTransformer(ABC): A list of transformed Documents. """ - @abstractmethod async def atransform_documents( self, documents: Sequence[Document], **kwargs: Any ) -> Sequence[Document]: @@ -84,3 +85,6 @@ class BaseDocumentTransformer(ABC): Returns: A list of transformed Documents. """ + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.transform_documents, **kwargs), documents + ) diff --git a/libs/langchain/langchain/schema/embeddings.py b/libs/langchain/langchain/schema/embeddings.py index 2cae8c1406..c08a279750 100644 --- a/libs/langchain/langchain/schema/embeddings.py +++ b/libs/langchain/langchain/schema/embeddings.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from typing import List @@ -15,8 +16,12 @@ class Embeddings(ABC): async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """Asynchronous Embed search docs.""" - raise NotImplementedError + return await asyncio.get_running_loop().run_in_executor( + None, self.embed_documents, texts + ) async def aembed_query(self, text: str) -> List[float]: """Asynchronous Embed query text.""" - raise NotImplementedError + return await asyncio.get_running_loop().run_in_executor( + None, self.embed_query, text + ) diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 25934eb3ed..180093a511 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -1,7 +1,9 @@ from __future__ import annotations +import asyncio import warnings from abc import ABC, abstractmethod +from functools import partial from inspect import signature from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -121,10 +123,6 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> List[Document]: - if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents: - # If the retriever doesn't implement async, use default implementation - return await super().ainvoke(input, config) - config = config or {} return await self.aget_relevant_documents( input, @@ -156,7 +154,9 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC): Returns: List of relevant documents """ - raise NotImplementedError() + return await asyncio.get_running_loop().run_in_executor( + None, partial(self._get_relevant_documents, run_manager=run_manager), query + ) def get_relevant_documents( self, diff --git a/libs/langchain/langchain/schema/vectorstore.py b/libs/langchain/langchain/schema/vectorstore.py index 68c7f94c7f..861c04ffd9 100644 --- a/libs/langchain/langchain/schema/vectorstore.py +++ b/libs/langchain/langchain/schema/vectorstore.py @@ -87,7 +87,9 @@ class VectorStore(ABC): **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore.""" - raise NotImplementedError + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.add_texts, **kwargs), texts, metadatas + ) def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: """Run more documents through the embeddings and add to the vectorstore. @@ -451,7 +453,9 @@ class VectorStore(ABC): **kwargs: Any, ) -> VST: """Return VectorStore initialized from texts and embeddings.""" - raise NotImplementedError + return await asyncio.get_running_loop().run_in_executor( + None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas + ) def _get_retriever_tags(self) -> List[str]: """Get tags for retriever.""" diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py index 012e03984c..a9cc3b5dfe 100644 --- a/libs/langchain/langchain/text_splitter.py +++ b/libs/langchain/langchain/text_splitter.py @@ -21,6 +21,7 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive from __future__ import annotations +import asyncio import copy import logging import pathlib @@ -28,6 +29,7 @@ import re from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum +from functools import partial from io import BytesIO, StringIO from typing import ( AbstractSet, @@ -284,7 +286,9 @@ class TextSplitter(BaseDocumentTransformer, ABC): self, documents: Sequence[Document], **kwargs: Any ) -> Sequence[Document]: """Asynchronously transform a sequence of documents by splitting them.""" - raise NotImplementedError + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.transform_documents, **kwargs), documents + ) class CharacterTextSplitter(TextSplitter): diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index e974070ece..4a574ecc3e 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -217,10 +217,6 @@ class ChildTool(BaseTool): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - if type(self)._arun == BaseTool._arun: - # If the tool does not implement async, fall back to default implementation - return await super().ainvoke(input, config, **kwargs) - config = config or {} return await self.arun( input,