Mirror of https://github.com/hwchase17/langchain
commit 4a810756f8 (parent f2ef3ff54a)
@@ -11,6 +11,7 @@ from langchain.schema import (


 def import_context() -> Any:
+    """Import the `getcontext` package."""
     try:
         import getcontext  # noqa: F401
         from getcontext.generated.models import (
@@ -30,7 +31,9 @@ def import_context() -> Any:


 class ContextCallbackHandler(BaseCallbackHandler):
-    """Callback Handler that records transcripts to Context (https://getcontext.ai).
+    """Callback Handler that records transcripts to the Context service.
+
+    (https://getcontext.ai).

     Keyword Args:
         token (optional): The token with which to authenticate requests to Context.
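For context, a minimal usage sketch of the handler this hunk documents. The token value is a placeholder, the `getcontext` package must be installed, and the import path assumes the handler is exported from `langchain.callbacks`:

.. code-block:: python

    from langchain.callbacks import ContextCallbackHandler
    from langchain.chat_models import ChatOpenAI
    from langchain.schema import HumanMessage

    # Placeholder token, not a real credential.
    handler = ContextCallbackHandler(token="<your-context-api-token>")
    # Transcripts of this model's calls are recorded to Context.
    chat = ChatOpenAI(callbacks=[handler])
    chat([HumanMessage(content="Hello!")])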
@@ -1,4 +1,17 @@
-"""Chains are easily reusable components which can be linked together."""
+"""Chains are easily reusable components which can be linked together.
+
+Chains should be used to encode a sequence of calls to components like
+models, document retrievers, other chains, etc., and provide a simple interface
+to this sequence.
+
+The Chain interface makes it easy to create apps that are:
+- Stateful: add Memory to any Chain to give it state,
+- Observable: pass Callbacks to a Chain to execute additional functionality,
+    like logging, outside the main sequence of component calls,
+- Composable: the Chain API is flexible enough that it is easy to combine
+    Chains with other components, including other Chains.
+"""
+
 from langchain.chains.api.base import APIChain
 from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
 from langchain.chains.combine_documents.base import AnalyzeDocumentChain
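The three properties named in the new module docstring map directly onto constructor arguments. A minimal sketch, assuming an OPENAI_API_KEY is configured:

.. code-block:: python

    from langchain.callbacks import StdOutCallbackHandler
    from langchain.chains import LLMChain
    from langchain.llms import OpenAI
    from langchain.memory import ConversationBufferMemory
    from langchain.prompts import PromptTemplate

    prompt = PromptTemplate.from_template(
        "Conversation so far:\n{history}\nHuman: {question}\nAI:"
    )
    chain = LLMChain(
        llm=OpenAI(temperature=0),
        prompt=prompt,
        memory=ConversationBufferMemory(),    # stateful: history persists
        callbacks=[StdOutCallbackHandler()],  # observable: logs each call
    )
    # Composable: this chain can itself be a step in a SequentialChain.
    chain.run("What is a chain?")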
@@ -72,13 +72,13 @@ class Chain(Serializable, ABC):
     """Whether or not run in verbose mode. In verbose mode, some intermediate logs
     will be printed to the console. Defaults to `langchain.verbose` value."""
     tags: Optional[List[str]] = None
-    """Optional list of tags associated with the chain. Defaults to None
+    """Optional list of tags associated with the chain. Defaults to None.
     These tags will be associated with each call to this chain,
     and passed as arguments to the handlers defined in `callbacks`.
     You can use these to eg identify a specific instance of a chain with its use case.
     """
     metadata: Optional[Dict[str, Any]] = None
-    """Optional metadata associated with the chain. Defaults to None
+    """Optional metadata associated with the chain. Defaults to None.
     This metadata will be associated with each call to this chain,
     and passed as arguments to the handlers defined in `callbacks`.
     You can use these to eg identify a specific instance of a chain with its use case.
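A sketch of how the tags and metadata documented above reach callback handlers; the values are illustrative:

.. code-block:: python

    from langchain.chains import LLMChain
    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate

    chain = LLMChain(
        llm=OpenAI(temperature=0),
        prompt=PromptTemplate.from_template("Summarize: {text}"),
        tags=["summarizer", "v1"],              # identify this chain instance
        metadata={"use_case": "docs-summary"},  # forwarded to handlers
    )
    # Every call now passes the tags and metadata to configured handlers.
    chain.run(text="Chains are reusable components that can be linked together.")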
@@ -118,12 +118,12 @@ class Chain(Serializable, ABC):
     @property
     @abstractmethod
     def input_keys(self) -> List[str]:
-        """Return the keys expected to be in the chain input."""
+        """Keys expected to be in the chain input."""

     @property
     @abstractmethod
     def output_keys(self) -> List[str]:
-        """Return the keys expected to be in the chain output."""
+        """Keys expected to be in the chain output."""

     def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
         """Check that all inputs are present."""
@@ -391,7 +391,7 @@ class Chain(Serializable, ABC):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> str:
-        """Convenience method for executing chain when there's a single string output.
+        """Execute chain when there's a single string output.

         The main difference between this method and `Chain.__call__` is that this method
         can only be used for chains that return a single string output. If a Chain
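The contrast this docstring draws, as a short sketch:

.. code-block:: python

    from langchain.chains import LLMChain
    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate

    chain = LLMChain(
        llm=OpenAI(temperature=0),
        prompt=PromptTemplate.from_template("Summarize: {text}"),
    )
    summary = chain.run(text="...")   # -> str, the single string output
    outputs = chain({"text": "..."})  # -> dict of inputs plus all output keys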
@@ -465,7 +465,7 @@ class Chain(Serializable, ABC):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> str:
-        """Convenience method for executing chain when there's a single string output.
+        """Execute chain when there's a single string output.

         The main difference between this method and `Chain.__call__` is that this method
         can only be used for chains that return a single string output. If a Chain
@@ -532,7 +532,7 @@ class Chain(Serializable, ABC):
         )

     def dict(self, **kwargs: Any) -> Dict:
-        """Return dictionary representation of chain.
+        """Dictionary representation of chain.

         Expects `Chain._chain_type` property to be implemented and for memory to be
         null.
@@ -22,7 +22,7 @@ class AsyncCombineDocsProtocol(Protocol):
     """Interface for the combine_docs method."""

     async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
-        """Async nterface for the combine_docs method."""
+        """Async interface for the combine_docs method."""


 def _split_list_of_docs(
@@ -78,7 +78,7 @@ async def _acollapse_docs(


 class ReduceDocumentsChain(BaseCombineDocumentsChain):
-    """Combining documents by recursively reducing them.
+    """Combine documents by recursively reducing them.

     This involves
@@ -206,7 +206,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
-        """Combine multiple documents recursively.
+        """Async combine multiple documents recursively.

         Args:
             docs: List of documents to combine, assumed that each one is less than
@@ -1,4 +1,4 @@
-"""Combining documents by doing a first pass and then refining on more documents."""
+"""Combine documents by doing a first pass and then refining on more documents."""

 from __future__ import annotations
@@ -161,7 +161,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine by mapping first chain over all, then stuffing into final chain.
+        """Async combine by mapping a first chain over all, then stuffing
+        into a final chain.

         Args:
             docs: List of documents to combine
@@ -167,7 +167,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Stuff all documents into one prompt and pass to LLM.
+        """Async stuff all documents into one prompt and pass to LLM.

         Args:
             docs: List of documents to join together into one variable
@@ -80,12 +80,12 @@ class ConstitutionalChain(Chain):

     @property
     def input_keys(self) -> List[str]:
-        """Defines the input keys."""
+        """Input keys."""
         return self.chain.input_keys

     @property
     def output_keys(self) -> List[str]:
-        """Defines the output keys."""
+        """Output keys."""
         if self.return_intermediate_steps:
             return ["output", "critiques_and_revisions", "initial_output"]
         return ["output"]
@@ -23,6 +23,8 @@ from langchain.schema.language_model import BaseLanguageModel


 class _ResponseChain(LLMChain):
+    """Base class for chains that generate responses."""
+
     prompt: BasePromptTemplate = PROMPT

     @property
@@ -46,6 +48,8 @@ class _ResponseChain(LLMChain):


 class _OpenAIResponseChain(_ResponseChain):
+    """Chain that generates responses from user input and context."""
+
     llm: OpenAI = Field(
         default_factory=lambda: OpenAI(
             max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
@@ -66,10 +70,14 @@ class _OpenAIResponseChain(_ResponseChain):


 class QuestionGeneratorChain(LLMChain):
+    """Chain that generates questions from uncertain spans."""
+
     prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
+    """Prompt template for the chain."""

     @property
     def input_keys(self) -> List[str]:
+        """Input keys for the chain."""
         return ["user_input", "context", "response"]
@@ -95,22 +103,36 @@ def _low_confidence_spans(


 class FlareChain(Chain):
+    """Chain that combines a retriever, a question generator,
+    and a response generator."""
+
     question_generator_chain: QuestionGeneratorChain
+    """Chain that generates questions from uncertain spans."""
     response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
+    """Chain that generates responses from user input and context."""
     output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
+    """Parser that determines whether the chain is finished."""
     retriever: BaseRetriever
+    """Retriever that retrieves relevant documents from a user input."""
     min_prob: float = 0.2
+    """Minimum probability for a token to be considered low confidence."""
     min_token_gap: int = 5
+    """Minimum number of tokens between two low confidence spans."""
     num_pad_tokens: int = 2
+    """Number of tokens to pad around a low confidence span."""
     max_iter: int = 10
+    """Maximum number of iterations."""
     start_with_retrieval: bool = True
+    """Whether to start with retrieval."""

     @property
     def input_keys(self) -> List[str]:
+        """Input keys for the chain."""
         return ["user_input"]

     @property
     def output_keys(self) -> List[str]:
+        """Output keys for the chain."""
         return ["response"]

     def _do_generation(
@@ -213,6 +235,16 @@ class FlareChain(Chain):
     def from_llm(
         cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
     ) -> FlareChain:
+        """Creates a FlareChain from a language model.
+
+        Args:
+            llm: Language model to use.
+            max_generation_len: Maximum length of the generated response.
+            **kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            FlareChain class with the given language model.
+        """
         question_gen_chain = QuestionGeneratorChain(llm=llm)
         response_llm = OpenAI(
             max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
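A construction sketch for the method documented above. The TF-IDF retriever is a stand-in for any `BaseRetriever` (it requires scikit-learn), and an OPENAI_API_KEY is assumed:

.. code-block:: python

    from langchain.chains import FlareChain
    from langchain.llms import OpenAI
    from langchain.retrievers import TFIDFRetriever

    retriever = TFIDFRetriever.from_texts(
        [
            "FLARE interleaves retrieval with generation.",
            "Low-confidence spans trigger follow-up questions.",
        ]
    )
    flare = FlareChain.from_llm(
        OpenAI(temperature=0),
        retriever=retriever,
        max_generation_len=64,  # forwarded to the response LLM
    )
    answer = flare.run("How does FLARE decide when to retrieve?")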
@@ -5,7 +5,10 @@ from langchain.schema import BaseOutputParser


 class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):
+    """Output parser that checks if the output is finished."""
+
     finished_value: str = "FINISHED"
+    """Value that indicates the output is finished."""

     def parse(self, text: str) -> Tuple[str, bool]:
         cleaned = text.strip()
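The parser's contract, sketched: it reports whether the sentinel appeared and returns the text with the sentinel removed.

.. code-block:: python

    from langchain.chains.flare.prompts import FinishedOutputParser

    parser = FinishedOutputParser()
    text, finished = parser.parse("Paris is the capital of France. FINISHED")
    # finished is True; text is the response with "FINISHED" stripped out.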
@@ -25,7 +25,7 @@ class GraphQAChain(Chain):

     @property
     def input_keys(self) -> List[str]:
-        """Return the input keys.
+        """Input keys.

         :meta private:
         """
@@ -33,7 +33,7 @@ class GraphQAChain(Chain):

     @property
     def output_keys(self) -> List[str]:
-        """Return the output keys.
+        """Output keys.

         :meta private:
         """
@@ -18,8 +18,8 @@ INTERMEDIATE_STEPS_KEY = "intermediate_steps"


 def extract_cypher(text: str) -> str:
-    """
-    Extract Cypher code from a text.
+    """Extract Cypher code from a text.

     Args:
         text: Text to extract Cypher code from.
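One plausible implementation of the behavior this docstring describes, as a sketch rather than the exact code in this file: prefer the first triple-backtick block, otherwise return the text unchanged.

.. code-block:: python

    import re

    def extract_cypher_sketch(text: str) -> str:
        # Take the first fenced code block if the model produced one.
        matches = re.findall(r"```(.*?)```", text, re.DOTALL)
        return matches[0] if matches else text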
@@ -28,7 +28,7 @@ class HugeGraphQAChain(Chain):

     @property
     def input_keys(self) -> List[str]:
-        """Return the input keys.
+        """Input keys.

         :meta private:
         """
@@ -36,7 +36,7 @@ class HugeGraphQAChain(Chain):

     @property
     def output_keys(self) -> List[str]:
-        """Return the output keys.
+        """Output keys.

         :meta private:
         """
@@ -1,4 +1,4 @@
-"""Chain that interprets a prompt and executes bash code to perform bash operations."""
+"""Chain that interprets a prompt and executes bash operations."""
 from __future__ import annotations

 import logging
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)


 class LLMBashChain(Chain):
-    """Chain that interprets a prompt and executes bash code to perform bash operations.
+    """Chain that interprets a prompt and executes bash operations.

     Example:
         .. code-block:: python
@@ -16,7 +16,7 @@ DEFAULT_HEADERS = {


 class LLMRequestsChain(Chain):
-    """Chain that hits a URL and then uses an LLM to parse results."""
+    """Chain that requests a URL and then uses an LLM to parse results."""

     llm_chain: LLMChain
     requests_wrapper: TextRequestsWrapper = Field(
@@ -1,4 +1,4 @@
-"""Chain that interprets a prompt and executes python code to do math."""
+"""Chain that interprets a prompt and executes python code to do symbolic math."""
 from __future__ import annotations

 import re
@@ -18,7 +18,7 @@ from langchain.prompts.base import BasePromptTemplate


 class LLMSymbolicMathChain(Chain):
-    """Chain that interprets a prompt and executes python code to do math.
+    """Chain that interprets a prompt and executes python code to do symbolic math.

     Example:
         .. code-block:: python
@@ -13,7 +13,7 @@ from langchain.schema.messages import HumanMessage, SystemMessage


 class FactWithEvidence(BaseModel):
-    """Class representing single statement.
+    """Class representing a single statement.

     Each fact has a body and a list of sources.
     If there are multiple facts make sure to break them apart
@@ -188,9 +188,14 @@ def openapi_spec_to_openai_fn(


 class SimpleRequestChain(Chain):
+    """Chain for making a simple request to an API endpoint."""
+
     request_method: Callable
+    """Method to use for making the request."""
     output_key: str = "response"
+    """Key to use for the output of the request."""
     input_key: str = "function"
+    """Key to use for the input of the request."""

     @property
     def input_keys(self) -> List[str]:
@@ -16,7 +16,7 @@ from langchain.schema.messages import HumanMessage, SystemMessage


 class AnswerWithSources(BaseModel):
-    """An answer to the question being asked, with sources."""
+    """An answer to the question, with sources."""

     answer: str = Field(..., description="Answer to the question that was asked")
     sources: List[str] = Field(
@@ -30,7 +30,8 @@ def create_qa_with_structure_chain(
     output_parser: str = "base",
     prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
 ) -> LLMChain:
-    """Create a question answering chain that returns an answer with sources.
+    """Create a question answering chain that returns an answer with sources
+    based on schema.

     Args:
         llm: Language model to use for the chain.
@@ -34,7 +34,8 @@ def create_tagging_chain(
     prompt: Optional[ChatPromptTemplate] = None,
     **kwargs: Any
 ) -> Chain:
-    """Creates a chain that extracts information from a passage.
+    """Creates a chain that extracts information from a passage
+    based on a schema.

     Args:
         schema: The schema of the entities to extract.
@@ -63,7 +64,8 @@ def create_tagging_chain_pydantic(
     prompt: Optional[ChatPromptTemplate] = None,
     **kwargs: Any
 ) -> Chain:
-    """Creates a chain that extracts information from a passage.
+    """Creates a chain that extracts information from a passage
+    based on a pydantic schema.

     Args:
         pydantic_schema: The pydantic schema of the entities to extract.
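A usage sketch for the dict-schema variant above, assuming an OpenAI functions-capable chat model and a configured OPENAI_API_KEY:

.. code-block:: python

    from langchain.chains import create_tagging_chain
    from langchain.chat_models import ChatOpenAI

    schema = {
        "properties": {
            "sentiment": {"type": "string", "enum": ["pos", "neg", "neutral"]},
            "language": {"type": "string"},
        }
    }
    chain = create_tagging_chain(schema, ChatOpenAI(temperature=0))
    chain.run("Estoy muy contento con la biblioteca.")
    # e.g. {"sentiment": "pos", "language": "Spanish"}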
@@ -10,6 +10,8 @@ from langchain.schema.language_model import BaseLanguageModel


 class BasePromptSelector(BaseModel, ABC):
+    """Base class for prompt selectors."""
+
     @abstractmethod
     def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
         """Get default prompt for a language model."""
@@ -19,11 +21,21 @@ class ConditionalPromptSelector(BasePromptSelector):
     """Prompt collection that goes through conditionals."""

     default_prompt: BasePromptTemplate
+    """Default prompt to use if no conditionals match."""
     conditionals: List[
         Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
     ] = Field(default_factory=list)
+    """List of conditionals and prompts to use if the conditionals match."""

     def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
+        """Get default prompt for a language model.
+
+        Args:
+            llm: Language model to get prompt for.
+
+        Returns:
+            Prompt to use for the language model.
+        """
         for condition, prompt in self.conditionals:
             if condition(llm):
                 return prompt
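How the selector documented above is typically wired up, as a sketch; `is_chat_model` is the predicate helper defined in the same module:

.. code-block:: python

    from langchain.chains.prompt_selector import (
        ConditionalPromptSelector,
        is_chat_model,
    )
    from langchain.prompts import ChatPromptTemplate, PromptTemplate

    PLAIN_PROMPT = PromptTemplate.from_template("Answer briefly: {question}")
    CHAT_PROMPT = ChatPromptTemplate.from_template("Answer briefly: {question}")

    selector = ConditionalPromptSelector(
        default_prompt=PLAIN_PROMPT,
        conditionals=[(is_chat_model, CHAT_PROMPT)],
    )
    # get_prompt returns CHAT_PROMPT for chat models, PLAIN_PROMPT otherwise.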
@@ -15,13 +15,20 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


 class QAGenerationChain(Chain):
+    """Base class for question-answer generation chains."""
+
     llm_chain: LLMChain
+    """LLM Chain that generates responses from user input and context."""
     text_splitter: TextSplitter = Field(
         default=RecursiveCharacterTextSplitter(chunk_overlap=500)
     )
+    """Text splitter that splits the input into chunks."""
     input_key: str = "text"
+    """Key of the input to the chain."""
     output_key: str = "questions"
+    """Key of the output of the chain."""
     k: Optional[int] = None
+    """Number of questions to generate."""

     @classmethod
     def from_llm(
@@ -30,6 +37,17 @@ class QAGenerationChain(Chain):
         prompt: Optional[BasePromptTemplate] = None,
         **kwargs: Any,
     ) -> QAGenerationChain:
+        """
+        Create a QAGenerationChain from a language model.
+
+        Args:
+            llm: a language model
+            prompt: a prompt template
+            **kwargs: additional arguments
+
+        Returns:
+            a QAGenerationChain class
+        """
         _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
         chain = LLMChain(llm=llm, prompt=_prompt)
         return cls(llm_chain=chain, **kwargs)
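A usage sketch of the constructor documented above, assuming an OPENAI_API_KEY:

.. code-block:: python

    from langchain.chains import QAGenerationChain
    from langchain.chat_models import ChatOpenAI

    qa_gen = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    pairs = qa_gen.run(
        "Chains encode sequences of calls to models, document retrievers,"
        " and other components."
    )
    # `pairs` holds generated question/answer dicts, one per text chunk.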
@@ -31,7 +31,7 @@ from langchain.schema.language_model import BaseLanguageModel


 class BaseQAWithSourcesChain(Chain, ABC):
-    """Question answering with sources over documents."""
+    """Question answering chain with sources over documents."""

     combine_documents_chain: BaseCombineDocumentsChain
     """Chain to use to combine documents."""
@@ -155,7 +155,7 @@ def load_qa_with_sources_chain(
     verbose: Optional[bool] = None,
     **kwargs: Any,
 ) -> BaseCombineDocumentsChain:
-    """Load question answering with sources chain.
+    """Load a question answering with sources chain.

     Args:
         llm: Language Model to use in the chain.
@@ -27,6 +27,8 @@ from langchain.schema.language_model import BaseLanguageModel


 class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
+    """Output parser that parses a structured query."""
+
     ast_parse: Callable
     """Callable that parses dict into internal representation of query language."""
@@ -57,6 +59,16 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
         allowed_comparators: Optional[Sequence[Comparator]] = None,
         allowed_operators: Optional[Sequence[Operator]] = None,
     ) -> StructuredQueryOutputParser:
+        """
+        Create a structured query output parser from components.
+
+        Args:
+            allowed_comparators: allowed comparators
+            allowed_operators: allowed operators
+
+        Returns:
+            a structured query output parser
+        """
         ast_parser = get_parser(
             allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
         )
@@ -53,7 +53,17 @@ def _to_snake_case(name: str) -> str:


 class Expr(BaseModel):
+    """Base class for all expressions."""
+
     def accept(self, visitor: Visitor) -> Any:
+        """Accept a visitor.
+
+        Args:
+            visitor: visitor to accept
+
+        Returns:
+            result of visiting
+        """
         return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")(
             self
         )
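The dispatch rule in `accept` turns a class name into a visitor method name via `_to_snake_case`; a sketch of that mapping:

.. code-block:: python

    import re

    def to_snake_case_sketch(name: str) -> str:
        # Insert "_" before each interior capital, then lowercase.
        return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

    # An Expr subclass named StructuredQuery is therefore dispatched
    # to visitor.visit_structured_query(...).
    assert to_snake_case_sketch("StructuredQuery") == "structured_query"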
@@ -99,6 +109,11 @@ class Operation(FilterDirective):


 class StructuredQuery(Expr):
+    """A structured query."""
+
     query: str
+    """Query string."""
     filter: Optional[FilterDirective]
+    """Filtering expression."""
     limit: Optional[int]
+    """Limit on the number of results."""
@@ -54,7 +54,8 @@ GRAMMAR = """
 @v_args(inline=True)
 class QueryTransformer(Transformer):
     """Transforms a query string into an IR representation
-    (intermediate representation)."""
+    (intermediate representation).
+    """

     def __init__(
         self,
@@ -25,12 +25,14 @@ from langchain.vectorstores.base import VectorStore


 class BaseRetrievalQA(Chain):
+    """Base class for question-answering chains."""
+
     combine_documents_chain: BaseCombineDocumentsChain
     """Chain to use to combine the documents."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
     return_source_documents: bool = False
-    """Return the source documents."""
+    """Return the source documents or not."""

     class Config:
         """Configuration for this pydantic object."""
@@ -41,7 +43,7 @@ class BaseRetrievalQA(Chain):

     @property
     def input_keys(self) -> List[str]:
-        """Return the input keys.
+        """Input keys.

         :meta private:
         """
@@ -49,7 +51,7 @@ class BaseRetrievalQA(Chain):

     @property
     def output_keys(self) -> List[str]:
-        """Return the output keys.
+        """Output keys.

         :meta private:
         """
@@ -27,6 +27,16 @@ class RouterChain(Chain, ABC):
         return ["destination", "next_inputs"]

     def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
+        """
+        Route inputs to a destination chain.
+
+        Args:
+            inputs: inputs to the chain
+            callbacks: callbacks to use for the chain
+
+        Returns:
+            a Route object
+        """
         result = self(inputs, callbacks=callbacks)
         return Route(result["destination"], result["next_inputs"])
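`Route` is the small named tuple returned above. A dispatch sketch; the "input" key, the chain mapping, and the "DEFAULT" fallback are illustrative:

.. code-block:: python

    from typing import Any, Dict

    from langchain.chains.router.base import RouterChain

    def dispatch(router: RouterChain, chains: Dict[str, Any], question: str) -> str:
        route = router.route({"input": question})
        # Fall back to a default chain when no destination matched.
        destination = chains.get(route.destination, chains["DEFAULT"])
        return destination.run(route.next_inputs)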
@@ -12,7 +12,7 @@ from langchain.vectorstores.base import VectorStore


 class EmbeddingRouterChain(RouterChain):
-    """Class that uses embeddings to route between options."""
+    """Chain that uses embeddings to route between options."""

     vectorstore: VectorStore
     routing_keys: List[str] = ["query"]