add default async (#11141)

pull/11381/head
Bagatur 1 year ago committed by GitHub
parent 88c5349196
commit 106608bc89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,7 +5,6 @@ import json
import logging
import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union
@ -97,12 +96,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if type(self)._acall == Chain._acall:
# If the chain does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)
config = config or {}
return await self.acall(
input,
@ -246,7 +239,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
raise NotImplementedError("Async call not supported for this chain type.")
return await asyncio.get_running_loop().run_in_executor(
None, self._call, inputs, run_manager
)
def __call__(
self,

@ -577,10 +577,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
) -> ChatResult:
"""Top Level call"""
return await asyncio.get_running_loop().run_in_executor(
None,
partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
),
None, partial(self._generate, **kwargs), messages, stop, run_manager
)
def _stream(

@ -248,12 +248,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async invoke, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)],
@ -319,13 +313,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
) -> List[str]:
if not inputs:
return []
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.batch, **kwargs), inputs, config
)
config = get_config_list(config, len(inputs))
max_concurrency = config[0].get("max_concurrency")
@ -478,7 +465,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompts."""
raise NotImplementedError()
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), prompts, stop, run_manager
)
def _stream(
self,
@ -1035,7 +1024,9 @@ class LLM(BaseLLM):
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
raise NotImplementedError()
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._call, **kwargs), prompt, stop, run_manager
)
def _generate(
self,
@ -1064,12 +1055,6 @@ class LLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
if type(self)._acall == LLM._acall:
# model doesn't implement async call, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, prompts, stop, run_manager, **kwargs)
)
"""Run the LLM on the given prompt and input."""
generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")

@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Sequence
from langchain.load.serializable import Serializable
@ -72,7 +74,6 @@ class BaseDocumentTransformer(ABC):
A list of transformed Documents.
"""
@abstractmethod
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
@ -84,3 +85,6 @@ class BaseDocumentTransformer(ABC):
Returns:
A list of transformed Documents.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)

@ -1,3 +1,4 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List
@ -15,8 +16,12 @@ class Embeddings(ABC):
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
    """Asynchronously embed search docs.

    Default implementation: delegate to the synchronous
    ``embed_documents`` on the event loop's default thread-pool
    executor, so subclasses that only implement the sync method
    still work for async callers.

    Args:
        texts: The list of document texts to embed.

    Returns:
        One embedding vector (list of floats) per input text.
    """
    # Removed the stray ``raise NotImplementedError`` that preceded this
    # return — it made the executor fallback below unreachable.
    return await asyncio.get_running_loop().run_in_executor(
        None, self.embed_documents, texts
    )
async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously embed query text.

    Default implementation: delegate to the synchronous
    ``embed_query`` on the event loop's default thread-pool
    executor, so subclasses that only implement the sync method
    still work for async callers.

    Args:
        text: The query text to embed.

    Returns:
        The embedding vector for the query as a list of floats.
    """
    # Removed the stray ``raise NotImplementedError`` that preceded this
    # return — it made the executor fallback below unreachable.
    return await asyncio.get_running_loop().run_in_executor(
        None, self.embed_query, text
    )

@ -1,7 +1,9 @@
from __future__ import annotations
import asyncio
import warnings
from abc import ABC, abstractmethod
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
@ -121,10 +123,6 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> List[Document]:
if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
# If the retriever doesn't implement async, use default implementation
return await super().ainvoke(input, config)
config = config or {}
return await self.aget_relevant_documents(
input,
@ -156,7 +154,9 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
Returns:
List of relevant documents
"""
raise NotImplementedError()
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._get_relevant_documents, run_manager=run_manager), query
)
def get_relevant_documents(
self,

@ -87,7 +87,9 @@ class VectorStore(ABC):
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore."""
raise NotImplementedError
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.add_texts, **kwargs), texts, metadatas
)
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore.
@ -451,7 +453,9 @@ class VectorStore(ABC):
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from texts and embeddings."""
raise NotImplementedError
return await asyncio.get_running_loop().run_in_executor(
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
)
def _get_retriever_tags(self) -> List[str]:
"""Get tags for retriever."""

@ -21,6 +21,7 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive
from __future__ import annotations
import asyncio
import copy
import logging
import pathlib
@ -28,6 +29,7 @@ import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import partial
from io import BytesIO, StringIO
from typing import (
AbstractSet,
@ -284,7 +286,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a sequence of documents by splitting them."""
raise NotImplementedError
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)
class CharacterTextSplitter(TextSplitter):

@ -217,10 +217,6 @@ class ChildTool(BaseTool):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
if type(self)._arun == BaseTool._arun:
# If the tool does not implement async, fall back to default implementation
return await super().ainvoke(input, config, **kwargs)
config = config or {}
return await self.arun(
input,

Loading…
Cancel
Save