Async Support for LLMChainExtractor (new) (#3780)

@vowelparrot @hwchase17 Here is a new implementation of
`acompress_documents` for `LLMChainExtractor` without changes to the
sync version, as you suggested in #3587 / [Async Support for
LLMChainExtractor](https://github.com/hwchase17/langchain/pull/3587).

I created a new PR to avoid cluttering the history with reverted commits;
I hope that is the right way.
Happy to hear any improvements/suggestions.

PS: I also tried an alternative implementation with a nested helper function,
like this:

``` python
  async def acompress_documents_old(
      self, documents: Sequence[Document], query: str
  ) -> Sequence[Document]:
      """Compress page content of raw documents."""

      async def _compress_concurrently(doc: Document) -> Document:
          # Run the extraction chain for a single document.
          _input = self.get_input(query, doc)
          output = await self.llm_chain.apredict_and_parse(**_input)
          return Document(page_content=output, metadata=doc.metadata)

      # Schedule all documents concurrently; gather preserves input order.
      outputs = await asyncio.gather(
          *[_compress_concurrently(doc) for doc in documents]
      )
      # Drop documents whose extracted content is empty.
      compressed_docs = [doc for doc in outputs if len(doc.page_content) > 0]
      return compressed_docs
```

But in the end I found the committed version to be more readable and
more "canonical"; I hope you agree.
Jan Philipp Harries 1 year ago committed by GitHub
parent 2cecc572f9
commit fc3c2c4406

@@ -1,6 +1,7 @@
 """DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
 from __future__ import annotations

+import asyncio
 from typing import Any, Callable, Dict, Optional, Sequence

 from langchain import LLMChain, PromptTemplate
@@ -62,7 +63,21 @@ class LLMChainExtractor(BaseDocumentCompressor):
     async def acompress_documents(
         self, documents: Sequence[Document], query: str
     ) -> Sequence[Document]:
-        raise NotImplementedError
+        """Compress page content of raw documents asynchronously."""
+        outputs = await asyncio.gather(
+            *[
+                self.llm_chain.apredict_and_parse(**self.get_input(query, doc))
+                for doc in documents
+            ]
+        )
+        compressed_docs = []
+        for i, doc in enumerate(documents):
+            if len(outputs[i]) == 0:
+                continue
+            compressed_docs.append(
+                Document(page_content=outputs[i], metadata=doc.metadata)
+            )
+        return compressed_docs

     @classmethod
     def from_llm(
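
The gather-then-filter pattern used by `acompress_documents` in the diff above can be tried without an LLM. The following is only a minimal sketch under stated assumptions: `Doc`, `fake_extract`, and `acompress` are hypothetical stand-ins (for `Document`, `llm_chain.apredict_and_parse`, and the method itself) and are not part of langchain. It illustrates that `asyncio.gather` returns results in input order, so each output can be paired with its source document before empty extractions are dropped.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Dict, List, Sequence


@dataclass
class Doc:
    """Hypothetical stand-in for langchain's Document."""

    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


async def fake_extract(query: str, doc: Doc) -> str:
    """Hypothetical stand-in for llm_chain.apredict_and_parse."""
    await asyncio.sleep(0.01)  # simulate the latency of an LLM call
    # Keep only the sentences that mention the query; may return "".
    hits = [s.strip() for s in doc.page_content.split(".") if query in s]
    return ". ".join(hits)


async def acompress(documents: Sequence[Doc], query: str) -> List[Doc]:
    # Run all extractions concurrently; results come back in input order.
    outputs = await asyncio.gather(
        *(fake_extract(query, doc) for doc in documents)
    )
    compressed = []
    for doc, output in zip(documents, outputs):
        if len(output) == 0:
            continue  # nothing relevant was extracted from this document
        compressed.append(Doc(page_content=output, metadata=doc.metadata))
    return compressed


docs = [
    Doc("Cats purr. Dogs bark", {"id": "a"}),
    Doc("Birds sing", {"id": "b"}),
    Doc("Dogs dig. Dogs run", {"id": "c"}),
]
result = asyncio.run(acompress(docs, "Dogs"))
```

Here the second document yields an empty extraction and is skipped, while the metadata of the remaining documents is carried over unchanged. Pairing with `zip` is equivalent to the index-based loop in the commit; either works because `gather` preserves ordering.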
