parse output of combine docs

harrison/combine-docs-parse
Harrison Chase 1 year ago
parent c59c5f5164
commit 275e58eab8

langchain/chains/combine_documents/base.py
@@ -1,12 +1,13 @@
 """Base interface for chains combining documents."""
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from pydantic import BaseModel
 
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
+from langchain.prompts.base import BaseOutputParser
 
 
 class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
@@ -42,6 +43,21 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
     @abstractmethod
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Combine documents into a single string."""
+
+    @property
+    def output_parser(self) -> Optional[BaseOutputParser]:
+        """Output parser to use for results of combine_docs."""
+
+    def combine_and_parse(
+        self, docs: List[Document], **kwargs: Any
+    ) -> Union[str, List[str], Dict[str, str]]:
+        """Combine documents and parse the result."""
+        result, _ = self.combine_docs(docs, **kwargs)
+        if self.output_parser is not None:
+            return self.output_parser.parse(result)
+        else:
+            return result
+
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction

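The base-class hunk above is the heart of the change: chains that combine documents now expose an optional output_parser, and combine_and_parse runs combine_docs and then post-processes the raw string only when a parser is present. A minimal standalone sketch of that control flow (the Fake* names are illustrative stand-ins written for this note, not langchain classes):

from typing import Any, Dict, List, Optional, Tuple, Union


class FakeParser:
    """Stand-in for a BaseOutputParser: splits the combined text on commas."""

    def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
        return [part.strip() for part in text.split(",")]


class FakeCombineChain:
    """Stand-in for a BaseCombineDocumentsChain subclass."""

    def __init__(self, parser: Optional[FakeParser] = None) -> None:
        self._parser = parser

    @property
    def output_parser(self) -> Optional[FakeParser]:
        """Mirrors the new property: None unless the subclass supplies a parser."""
        return self._parser

    def combine_docs(self, docs: List[str], **kwargs: Any) -> Tuple[str, dict]:
        """Mirrors the abstract combine_docs: raw text plus extra outputs."""
        return ", ".join(docs), {}

    def combine_and_parse(
        self, docs: List[str], **kwargs: Any
    ) -> Union[str, List[str], Dict[str, str]]:
        """Same branching as the new base-class helper."""
        result, _ = self.combine_docs(docs, **kwargs)
        if self.output_parser is not None:
            return self.output_parser.parse(result)
        else:
            return result


print(FakeCombineChain().combine_and_parse(["foo", "bar"]))              # 'foo, bar'
print(FakeCombineChain(FakeParser()).combine_and_parse(["foo", "bar"]))  # ['foo', 'bar']
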
langchain/chains/combine_documents/map_reduce.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Extra, root_validator
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
+from langchain.prompts.base import BaseOutputParser
 
 
 def _split_list_of_docs(
@@ -113,6 +114,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         else:
             return self.combine_document_chain
 
+    @property
+    def output_parser(self) -> Optional[BaseOutputParser]:
+        """Output parser to use for results of combine_docs."""
+        return self.combine_document_chain.output_parser
+
     def combine_docs(
         self, docs: List[Document], token_max: int = 3000, **kwargs: Any
     ) -> Tuple[str, dict]:

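One subtlety in the map-reduce hunk: output_parser does not go straight to a prompt, it defers to the nested combine_document_chain, which is itself a BaseCombineDocumentsChain, so the lookup bottoms out at whatever prompt sits at the end of the reduce step. A sketch of that delegation (the StuffDocumentsChain link is the typical wiring, assumed here rather than enforced by this diff):

    MapReduceDocumentsChain.output_parser
        -> self.combine_document_chain.output_parser    (this hunk)
        -> StuffDocumentsChain.output_parser            (usual inner combine chain)
        -> self.llm_chain.prompt.output_parser          (stuff.py hunk below)
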
langchain/chains/combine_documents/refine.py
@@ -2,14 +2,14 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
+from langchain.prompts.base import BaseOutputParser, BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 
 
@@ -74,6 +74,11 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
                 )
         return values
 
+    @property
+    def output_parser(self) -> Optional[BaseOutputParser]:
+        """Output parser to use for results of combine_docs."""
+        return self.refine_llm_chain.prompt.output_parser
+
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Combine by mapping first chain over all, then stuffing into final chain."""
         base_info = {"page_content": docs[0].page_content}

langchain/chains/combine_documents/stuff.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.prompts.base import BasePromptTemplate
+from langchain.prompts.base import BaseOutputParser, BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 
 
@@ -78,6 +78,11 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         prompt = self.llm_chain.prompt.format(**inputs)
         return self.llm_chain.llm.get_num_tokens(prompt)
 
+    @property
+    def output_parser(self) -> Optional[BaseOutputParser]:
+        """Output parser to use for results of combine_docs."""
+        return self.llm_chain.prompt.output_parser
+
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Stuff all documents into one prompt and pass to LLM."""
         inputs = self._get_inputs(docs, **kwargs)

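Taken together, the three subclass hunks mean any parser attached to the final prompt now flows through to combine_and_parse. A hedged end-to-end sketch against the langchain API of this era: only combine_and_parse, output_parser, BaseOutputParser, and the imports shown in the diff come from this commit; the OpenAI llm, the prompt text, the document_variable_name wiring, and CommaListParser are illustrative assumptions, and running it needs an OpenAI API key plus a langchain checkout from around this commit.

from typing import Dict, List, Union

from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.prompts.base import BaseOutputParser


class CommaListParser(BaseOutputParser):
    """Hypothetical parser: splits the model's answer on commas."""

    def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
        return [part.strip() for part in text.split(",")]


# Attach the parser to the prompt; StuffDocumentsChain.output_parser picks it up.
prompt = PromptTemplate(
    input_variables=["context"],
    template="List the topics covered below, comma separated:\n\n{context}",
    output_parser=CommaListParser(),
)
chain = StuffDocumentsChain(
    llm_chain=LLMChain(llm=OpenAI(temperature=0), prompt=prompt),
    document_variable_name="context",
)
docs = [Document(page_content="LangChain ships stuff, map-reduce and refine chains.")]

raw, _ = chain.combine_docs(docs)       # unparsed LLM output, as before
topics = chain.combine_and_parse(docs)  # list of topics via the prompt's parser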