Async Support for LLMChainExtractor (new) (#3780)

@vowelparrot @hwchase17 Here is a new implementation of
`acompress_documents` for `LLMChainExtractor` without changes to the
sync version, as you suggested in #3587 / [Async Support for
LLMChainExtractor](https://github.com/hwchase17/langchain/pull/3587).

I created a new PR to avoid cluttering the history with reverted
commits; I hope that is the right way to do it.
Happy for any improvements/suggestions.

(PS: I also tried an alternative implementation with a nested helper
function, like this:

```python
async def acompress_documents_old(
    self, documents: Sequence[Document], query: str
) -> Sequence[Document]:
    """Compress page content of raw documents."""

    async def _compress_concurrently(doc: Document) -> Document:
        _input = self.get_input(query, doc)
        output = await self.llm_chain.apredict_and_parse(**_input)
        return Document(page_content=output, metadata=doc.metadata)

    # Compress all documents concurrently, then drop empty results.
    outputs = await asyncio.gather(
        *[_compress_concurrently(doc) for doc in documents]
    )
    compressed_docs = list(filter(lambda x: len(x.page_content) > 0, outputs))
    return compressed_docs
```

But in the end I found the committed version more readable and more
"canonical" - hope you agree.)
This commit is contained in: commit fc3c2c4406 (parent 2cecc572f9),
authored by Jan Philipp Harries on 2023-05-02 06:23:13 +02:00 and
committed by GitHub.

```diff
@@ -1,6 +1,7 @@
 """DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
 from __future__ import annotations
 
+import asyncio
 from typing import Any, Callable, Dict, Optional, Sequence
 
 from langchain import LLMChain, PromptTemplate
@@ -62,7 +63,21 @@ class LLMChainExtractor(BaseDocumentCompressor):
     async def acompress_documents(
         self, documents: Sequence[Document], query: str
     ) -> Sequence[Document]:
-        raise NotImplementedError
+        """Compress page content of raw documents asynchronously."""
+        outputs = await asyncio.gather(
+            *[
+                self.llm_chain.apredict_and_parse(**self.get_input(query, doc))
+                for doc in documents
+            ]
+        )
+        compressed_docs = []
+        for i, doc in enumerate(documents):
+            if len(outputs[i]) == 0:
+                continue
+            compressed_docs.append(
+                Document(page_content=outputs[i], metadata=doc.metadata)
+            )
+        return compressed_docs
 
     @classmethod
     def from_llm(
```
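
For anyone who wants to try it, here is a minimal usage sketch. The
import paths and the `OpenAI` wrapper are assumptions based on the
package layout at the time; only `from_llm` and `acompress_documents`
come from the diff above:

```python
import asyncio

# Assumed import paths; adjust to your langchain version if needed.
from langchain.llms import OpenAI
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema import Document


async def main() -> None:
    # Build the extractor from any LLM usable in an LLMChain.
    compressor = LLMChainExtractor.from_llm(OpenAI(temperature=0))

    docs = [
        Document(page_content="LangChain added async support to many chains."),
        Document(page_content="A completely unrelated note about gardening."),
    ]

    # All documents are sent to the LLM concurrently via asyncio.gather;
    # documents whose extracted content comes back empty are dropped.
    compressed = await compressor.acompress_documents(docs, query="async support")
    for doc in compressed:
        print(doc.page_content)


asyncio.run(main())
```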