fix map reduce chain (#550)

This commit is contained in:
Harrison Chase 2023-01-06 07:15:57 -08:00 committed by GitHub
parent ba0cbb4a41
commit 330a5b42d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 7 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
from pydantic import BaseModel, Extra, root_validator
@ -11,6 +11,13 @@ from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
class CombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
def __call__(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Interface for the combine_docs method."""
def _split_list_of_docs(
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> List[List[Document]]:
@ -38,10 +45,10 @@ def _split_list_of_docs(
def _collapse_docs(
docs: List[Document],
combine_document_func: Callable,
combine_document_func: CombineDocsProtocol,
**kwargs: Any,
) -> Document:
result = combine_document_func(docs, **kwargs)
result, _ = combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():

View File

@ -1,6 +1,6 @@
"""Test functionality related to combining documents."""
from typing import List
from typing import Any, List, Tuple
import pytest
@ -12,11 +12,11 @@ from langchain.docstore.document import Document
def _fake_docs_len_func(docs: List[Document]) -> int:
return len(_fake_combine_docs_func(docs))
return len(_fake_combine_docs_func(docs)[0])
def _fake_combine_docs_func(docs: List[Document]) -> str:
return "".join([d.page_content for d in docs])
def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
return "".join([d.page_content for d in docs]), {}
def test__split_list_long_single_doc() -> None: