docstrings chains (#7892)

Added/updated docstrings.

@baskaryan
This commit is contained in:
Leonid Ganeline 2023-07-18 18:25:27 -07:00 committed by GitHub
parent f2ef3ff54a
commit 4a810756f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 170 additions and 40 deletions

View File

@ -11,6 +11,7 @@ from langchain.schema import (
def import_context() -> Any:
"""Import the `getcontext` package."""
try:
import getcontext # noqa: F401
from getcontext.generated.models import (
@ -30,7 +31,9 @@ def import_context() -> Any:
class ContextCallbackHandler(BaseCallbackHandler):
"""Callback Handler that records transcripts to Context (https://getcontext.ai).
"""Callback Handler that records transcripts to the Context service.
(https://getcontext.ai).
Keyword Args:
token (optional): The token with which to authenticate requests to Context.

View File

@ -1,4 +1,17 @@
"""Chains are easily reusable components which can be linked together."""
"""Chains are easily reusable components which can be linked together.
Chains should be used to encode a sequence of calls to components like
models, document retrievers, other chains, etc., and provide a simple interface
to this sequence.
The Chain interface makes it easy to create apps that are:
- Stateful: add Memory to any Chain to give it state,
- Observable: pass Callbacks to a Chain to execute additional functionality,
like logging, outside the main sequence of component calls,
- Composable: the Chain API is flexible enough that it is easy to combine
Chains with other components, including other Chains.
"""
from langchain.chains.api.base import APIChain
from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
from langchain.chains.combine_documents.base import AnalyzeDocumentChain

View File

@ -72,13 +72,13 @@ class Chain(Serializable, ABC):
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to `langchain.verbose` value."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the chain. Defaults to None
"""Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None
"""Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
@ -118,12 +118,12 @@ class Chain(Serializable, ABC):
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Return the keys expected to be in the chain input."""
"""Keys expected to be in the chain input."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Return the keys expected to be in the chain output."""
"""Keys expected to be in the chain output."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
@ -391,7 +391,7 @@ class Chain(Serializable, ABC):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Convenience method for executing chain when there's a single string output.
"""Execute chain when there's a single string output.
The main difference between this method and `Chain.__call__` is that this method
can only be used for chains that return a single string output. If a Chain
@ -465,7 +465,7 @@ class Chain(Serializable, ABC):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Convenience method for executing chain when there's a single string output.
"""Execute chain when there's a single string output.
The main difference between this method and `Chain.__call__` is that this method
can only be used for chains that return a single string output. If a Chain
@ -532,7 +532,7 @@ class Chain(Serializable, ABC):
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of chain.
"""Dictionary representation of chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
null.

View File

@ -22,7 +22,7 @@ class AsyncCombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
"""Async nterface for the combine_docs method."""
"""Async interface for the combine_docs method."""
def _split_list_of_docs(
@ -78,7 +78,7 @@ async def _acollapse_docs(
class ReduceDocumentsChain(BaseCombineDocumentsChain):
"""Combining documents by recursively reducing them.
"""Combine documents by recursively reducing them.
This involves
@ -206,7 +206,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
"""Combine multiple documents recursively.
"""Async combine multiple documents recursively.
Args:
docs: List of documents to combine, assumed that each one is less than

View File

@ -1,4 +1,4 @@
"""Combining documents by doing a first pass and then refining on more documents."""
"""Combine documents by doing a first pass and then refining on more documents."""
from __future__ import annotations
@ -161,7 +161,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain.
"""Async combine by mapping a first chain over all, then stuffing
into a final chain.
Args:
docs: List of documents to combine

View File

@ -167,7 +167,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM.
"""Async stuff all documents into one prompt and pass to LLM.
Args:
docs: List of documents to join together into one variable

View File

@ -80,12 +80,12 @@ class ConstitutionalChain(Chain):
@property
def input_keys(self) -> List[str]:
"""Defines the input keys."""
"""Input keys."""
return self.chain.input_keys
@property
def output_keys(self) -> List[str]:
"""Defines the output keys."""
"""Output keys."""
if self.return_intermediate_steps:
return ["output", "critiques_and_revisions", "initial_output"]
return ["output"]

View File

@ -23,6 +23,8 @@ from langchain.schema.language_model import BaseLanguageModel
class _ResponseChain(LLMChain):
"""Base class for chains that generate responses."""
prompt: BasePromptTemplate = PROMPT
@property
@ -46,6 +48,8 @@ class _ResponseChain(LLMChain):
class _OpenAIResponseChain(_ResponseChain):
"""Chain that generates responses from user input and context."""
llm: OpenAI = Field(
default_factory=lambda: OpenAI(
max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
@ -66,10 +70,14 @@ class _OpenAIResponseChain(_ResponseChain):
class QuestionGeneratorChain(LLMChain):
"""Chain that generates questions from uncertain spans."""
prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
"""Prompt template for the chain."""
@property
def input_keys(self) -> List[str]:
"""Input keys for the chain."""
return ["user_input", "context", "response"]
@ -95,22 +103,36 @@ def _low_confidence_spans(
class FlareChain(Chain):
"""Chain that combines a retriever, a question generator,
and a response generator."""
question_generator_chain: QuestionGeneratorChain
"""Chain that generates questions from uncertain spans."""
response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
"""Chain that generates responses from user input and context."""
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
"""Parser that determines whether the chain is finished."""
retriever: BaseRetriever
"""Retriever that retrieves relevant documents from a user input."""
min_prob: float = 0.2
"""Minimum probability for a token to be considered low confidence."""
min_token_gap: int = 5
"""Minimum number of tokens between two low confidence spans."""
num_pad_tokens: int = 2
"""Number of tokens to pad around a low confidence span."""
max_iter: int = 10
"""Maximum number of iterations."""
start_with_retrieval: bool = True
"""Whether to start with retrieval."""
@property
def input_keys(self) -> List[str]:
"""Input keys for the chain."""
return ["user_input"]
@property
def output_keys(self) -> List[str]:
"""Output keys for the chain."""
return ["response"]
def _do_generation(
@ -213,6 +235,16 @@ class FlareChain(Chain):
def from_llm(
cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
) -> FlareChain:
"""Creates a FlareChain from a language model.
Args:
llm: Language model to use.
max_generation_len: Maximum length of the generated response.
**kwargs: Additional arguments to pass to the constructor.
Returns:
FlareChain class with the given language model.
"""
question_gen_chain = QuestionGeneratorChain(llm=llm)
response_llm = OpenAI(
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0

View File

@ -5,7 +5,10 @@ from langchain.schema import BaseOutputParser
class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):
"""Output parser that checks if the output is finished."""
finished_value: str = "FINISHED"
"""Value that indicates the output is finished."""
def parse(self, text: str) -> Tuple[str, bool]:
cleaned = text.strip()

View File

@ -25,7 +25,7 @@ class GraphQAChain(Chain):
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
"""Input keys.
:meta private:
"""
@ -33,7 +33,7 @@ class GraphQAChain(Chain):
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
"""Output keys.
:meta private:
"""

View File

@ -18,8 +18,8 @@ INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str:
"""
Extract Cypher code from a text.
"""Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.

View File

@ -28,7 +28,7 @@ class HugeGraphQAChain(Chain):
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
"""Input keys.
:meta private:
"""
@ -36,7 +36,7 @@ class HugeGraphQAChain(Chain):
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
"""Output keys.
:meta private:
"""

View File

@ -1,4 +1,4 @@
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
"""Chain that interprets a prompt and executes bash operations."""
from __future__ import annotations
import logging
@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
class LLMBashChain(Chain):
"""Chain that interprets a prompt and executes bash code to perform bash operations.
"""Chain that interprets a prompt and executes bash operations.
Example:
.. code-block:: python

View File

@ -16,7 +16,7 @@ DEFAULT_HEADERS = {
class LLMRequestsChain(Chain):
"""Chain that hits a URL and then uses an LLM to parse results."""
"""Chain that requests a URL and then uses an LLM to parse results."""
llm_chain: LLMChain
requests_wrapper: TextRequestsWrapper = Field(

View File

@ -1,4 +1,4 @@
"""Chain that interprets a prompt and executes python code to do math."""
"""Chain that interprets a prompt and executes python code to do symbolic math."""
from __future__ import annotations
import re
@ -18,7 +18,7 @@ from langchain.prompts.base import BasePromptTemplate
class LLMSymbolicMathChain(Chain):
"""Chain that interprets a prompt and executes python code to do math.
"""Chain that interprets a prompt and executes python code to do symbolic math.
Example:
.. code-block:: python

View File

@ -13,7 +13,7 @@ from langchain.schema.messages import HumanMessage, SystemMessage
class FactWithEvidence(BaseModel):
"""Class representing single statement.
"""Class representing a single statement.
Each fact has a body and a list of sources.
If there are multiple facts make sure to break them apart

View File

@ -188,9 +188,14 @@ def openapi_spec_to_openai_fn(
class SimpleRequestChain(Chain):
"""Chain for making a simple request to an API endpoint."""
request_method: Callable
"""Method to use for making the request."""
output_key: str = "response"
"""Key to use for the output of the request."""
input_key: str = "function"
"""Key to use for the input of the request."""
@property
def input_keys(self) -> List[str]:

View File

@ -16,7 +16,7 @@ from langchain.schema.messages import HumanMessage, SystemMessage
class AnswerWithSources(BaseModel):
"""An answer to the question being asked, with sources."""
"""An answer to the question, with sources."""
answer: str = Field(..., description="Answer to the question that was asked")
sources: List[str] = Field(
@ -30,7 +30,8 @@ def create_qa_with_structure_chain(
output_parser: str = "base",
prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
) -> LLMChain:
"""Create a question answering chain that returns an answer with sources.
"""Create a question answering chain that returns an answer with sources
based on schema.
Args:
llm: Language model to use for the chain.

View File

@ -34,7 +34,8 @@ def create_tagging_chain(
prompt: Optional[ChatPromptTemplate] = None,
**kwargs: Any
) -> Chain:
"""Creates a chain that extracts information from a passage.
"""Creates a chain that extracts information from a passage
based on a schema.
Args:
schema: The schema of the entities to extract.
@ -63,7 +64,8 @@ def create_tagging_chain_pydantic(
prompt: Optional[ChatPromptTemplate] = None,
**kwargs: Any
) -> Chain:
"""Creates a chain that extracts information from a passage.
"""Creates a chain that extracts information from a passage
based on a pydantic schema.
Args:
pydantic_schema: The pydantic schema of the entities to extract.

View File

@ -10,6 +10,8 @@ from langchain.schema.language_model import BaseLanguageModel
class BasePromptSelector(BaseModel, ABC):
"""Base class for prompt selectors."""
@abstractmethod
def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
"""Get default prompt for a language model."""
@ -19,11 +21,21 @@ class ConditionalPromptSelector(BasePromptSelector):
"""Prompt collection that goes through conditionals."""
default_prompt: BasePromptTemplate
"""Default prompt to use if no conditionals match."""
conditionals: List[
Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
] = Field(default_factory=list)
"""List of conditionals and prompts to use if the conditionals match."""
def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
"""Get default prompt for a language model.
Args:
llm: Language model to get prompt for.
Returns:
Prompt to use for the language model.
"""
for condition, prompt in self.conditionals:
if condition(llm):
return prompt

View File

@ -15,13 +15,20 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
class QAGenerationChain(Chain):
"""Base class for question-answer generation chains."""
llm_chain: LLMChain
"""LLM Chain that generates responses from user input and context."""
text_splitter: TextSplitter = Field(
default=RecursiveCharacterTextSplitter(chunk_overlap=500)
)
"""Text splitter that splits the input into chunks."""
input_key: str = "text"
"""Key of the input to the chain."""
output_key: str = "questions"
"""Key of the output of the chain."""
k: Optional[int] = None
"""Number of questions to generate."""
@classmethod
def from_llm(
@ -30,6 +37,17 @@ class QAGenerationChain(Chain):
prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any,
) -> QAGenerationChain:
"""
Create a QAGenerationChain from a language model.
Args:
llm: a language model
prompt: a prompt template
**kwargs: additional arguments
Returns:
a QAGenerationChain class
"""
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
chain = LLMChain(llm=llm, prompt=_prompt)
return cls(llm_chain=chain, **kwargs)

View File

@ -31,7 +31,7 @@ from langchain.schema.language_model import BaseLanguageModel
class BaseQAWithSourcesChain(Chain, ABC):
"""Question answering with sources over documents."""
"""Question answering chain with sources over documents."""
combine_documents_chain: BaseCombineDocumentsChain
"""Chain to use to combine documents."""

View File

@ -155,7 +155,7 @@ def load_qa_with_sources_chain(
verbose: Optional[bool] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering with sources chain.
"""Load a question answering with sources chain.
Args:
llm: Language Model to use in the chain.

View File

@ -27,6 +27,8 @@ from langchain.schema.language_model import BaseLanguageModel
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
"""Output parser that parses a structured query."""
ast_parse: Callable
"""Callable that parses dict into internal representation of query language."""
@ -57,6 +59,16 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
) -> StructuredQueryOutputParser:
"""
Create a structured query output parser from components.
Args:
allowed_comparators: allowed comparators
allowed_operators: allowed operators
Returns:
a structured query output parser
"""
ast_parser = get_parser(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
)

View File

@ -53,7 +53,17 @@ def _to_snake_case(name: str) -> str:
class Expr(BaseModel):
"""Base class for all expressions."""
def accept(self, visitor: Visitor) -> Any:
"""Accept a visitor.
Args:
visitor: visitor to accept
Returns:
result of visiting
"""
return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")(
self
)
@ -99,6 +109,11 @@ class Operation(FilterDirective):
class StructuredQuery(Expr):
"""A structured query."""
query: str
"""Query string."""
filter: Optional[FilterDirective]
"""Filtering expression."""
limit: Optional[int]
"""Limit on the number of results."""

View File

@ -54,7 +54,8 @@ GRAMMAR = """
@v_args(inline=True)
class QueryTransformer(Transformer):
"""Transforms a query string into an IR representation
(intermediate representation)."""
(intermediate representation).
"""
def __init__(
self,

View File

@ -25,12 +25,14 @@ from langchain.vectorstores.base import VectorStore
class BaseRetrievalQA(Chain):
"""Base class for question-answering chains."""
combine_documents_chain: BaseCombineDocumentsChain
"""Chain to use to combine the documents."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
return_source_documents: bool = False
"""Return the source documents."""
"""Return the source documents or not."""
class Config:
"""Configuration for this pydantic object."""
@ -41,7 +43,7 @@ class BaseRetrievalQA(Chain):
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
"""Input keys.
:meta private:
"""
@ -49,7 +51,7 @@ class BaseRetrievalQA(Chain):
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
"""Output keys.
:meta private:
"""

View File

@ -27,6 +27,16 @@ class RouterChain(Chain, ABC):
return ["destination", "next_inputs"]
def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
"""
Route inputs to a destination chain.
Args:
inputs: inputs to the chain
callbacks: callbacks to use for the chain
Returns:
a Route object
"""
result = self(inputs, callbacks=callbacks)
return Route(result["destination"], result["next_inputs"])

View File

@ -12,7 +12,7 @@ from langchain.vectorstores.base import VectorStore
class EmbeddingRouterChain(RouterChain):
"""Class that uses embeddings to route between options."""
"""Chain that uses embeddings to route between options."""
vectorstore: VectorStore
routing_keys: List[str] = ["query"]