From 330a5b42d41f127221ed8762e04af1332bade986 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 6 Jan 2023 07:15:57 -0800 Subject: [PATCH] fix map reduce chain (#550) --- langchain/chains/combine_documents/map_reduce.py | 13 ++++++++++--- tests/unit_tests/chains/test_combine_documents.py | 8 ++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index 81e324174a..0addc9dcf4 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple from pydantic import BaseModel, Extra, root_validator @@ -11,6 +11,13 @@ from langchain.chains.llm import LLMChain from langchain.docstore.document import Document +class CombineDocsProtocol(Protocol): + """Interface for the combine_docs method.""" + + def __call__(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: + """Interface for the combine_docs method.""" + + def _split_list_of_docs( docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any ) -> List[List[Document]]: @@ -38,10 +45,10 @@ def _split_list_of_docs( def _collapse_docs( docs: List[Document], - combine_document_func: Callable, + combine_document_func: CombineDocsProtocol, **kwargs: Any, ) -> Document: - result = combine_document_func(docs, **kwargs) + result, _ = combine_document_func(docs, **kwargs) combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()} for doc in docs[1:]: for k, v in doc.metadata.items(): diff --git a/tests/unit_tests/chains/test_combine_documents.py b/tests/unit_tests/chains/test_combine_documents.py index 81468f9718..fca09f4ab4 100644 --- a/tests/unit_tests/chains/test_combine_documents.py +++ b/tests/unit_tests/chains/test_combine_documents.py @@ -1,6 +1,6 @@ """Test functionality related to combining documents.""" -from typing import List +from typing import Any, List, Tuple import pytest @@ -12,11 +12,11 @@ from langchain.docstore.document import Document def _fake_docs_len_func(docs: List[Document]) -> int: - return len(_fake_combine_docs_func(docs)) + return len(_fake_combine_docs_func(docs)[0]) -def _fake_combine_docs_func(docs: List[Document]) -> str: - return "".join([d.page_content for d in docs]) +def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: + return "".join([d.page_content for d in docs]), {} def test__split_list_long_single_doc() -> None: