mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
fix map reduce chain (#550)
This commit is contained in:
parent
ba0cbb4a41
commit
330a5b42d4
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
@ -11,6 +11,13 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
class CombineDocsProtocol(Protocol):
|
||||
"""Interface for the combine_docs method."""
|
||||
|
||||
def __call__(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
"""Interface for the combine_docs method."""
|
||||
|
||||
|
||||
def _split_list_of_docs(
|
||||
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
|
||||
) -> List[List[Document]]:
|
||||
@ -38,10 +45,10 @@ def _split_list_of_docs(
|
||||
|
||||
def _collapse_docs(
|
||||
docs: List[Document],
|
||||
combine_document_func: Callable,
|
||||
combine_document_func: CombineDocsProtocol,
|
||||
**kwargs: Any,
|
||||
) -> Document:
|
||||
result = combine_document_func(docs, **kwargs)
|
||||
result, _ = combine_document_func(docs, **kwargs)
|
||||
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
|
||||
for doc in docs[1:]:
|
||||
for k, v in doc.metadata.items():
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Test functionality related to combining documents."""
|
||||
|
||||
from typing import List
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
@ -12,11 +12,11 @@ from langchain.docstore.document import Document
|
||||
|
||||
|
||||
def _fake_docs_len_func(docs: List[Document]) -> int:
|
||||
return len(_fake_combine_docs_func(docs))
|
||||
return len(_fake_combine_docs_func(docs)[0])
|
||||
|
||||
|
||||
def _fake_combine_docs_func(docs: List[Document]) -> str:
|
||||
return "".join([d.page_content for d in docs])
|
||||
def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||
return "".join([d.page_content for d in docs]), {}
|
||||
|
||||
|
||||
def test__split_list_long_single_doc() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user