docs: langchain docstrings updates (#21032)

Added missed docstings. Formatted docstrings into a consistent format.
pull/21144/head
Leonid Ganeline 1 month ago committed by GitHub
parent 85094cbb3a
commit 08d08d7c83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -16,6 +16,15 @@ from langchain.chains.llm import LLMChain
def remove_prefix(text: str, prefix: str) -> str:
"""Remove a prefix from a text.
Args:
text: Text to remove the prefix from.
prefix: Prefix to remove from the text.
Returns:
Text with the prefix removed.
"""
if text.startswith(prefix):
return text[len(prefix) :]
return text

@ -54,6 +54,14 @@ SPARQL_GENERATION_PROMPT = PromptTemplate(
def extract_sparql(query: str) -> str:
"""Extract SPARQL code from a text.
Args:
query: Text to extract SPARQL code from.
Returns:
SPARQL code extracted from the text.
"""
query = query.strip()
querytoks = query.split("```")
if len(querytoks) == 3:

@ -35,7 +35,7 @@ def create_tagging_chain(
prompt: Optional[ChatPromptTemplate] = None,
**kwargs: Any,
) -> Chain:
"""Creates a chain that extracts information from a passage
"""Create a chain that extracts information from a passage
based on a schema.
Args:
@ -65,7 +65,7 @@ def create_tagging_chain_pydantic(
prompt: Optional[ChatPromptTemplate] = None,
**kwargs: Any,
) -> Chain:
"""Creates a chain that extracts information from a passage
"""Create a chain that extracts information from a passage
based on a pydantic schema.
Args:

@ -3,7 +3,7 @@ from typing import Any, Dict
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
"""
Resolves the $ref keys in a JSON schema object using the provided definitions.
Resolve the $ref keys in a JSON schema object using the provided definitions.
"""
if isinstance(schema, list):
for i, item in enumerate(schema):
@ -29,7 +29,7 @@ def _convert_schema(schema: dict) -> dict:
def get_llm_kwargs(function: dict) -> dict:
"""Returns the kwargs for the LLMChain constructor.
"""Return the kwargs for the LLMChain constructor.
Args:
function: The function to use.

@ -63,7 +63,7 @@ class ISO8601Date(TypedDict):
@v_args(inline=True)
class QueryTransformer(Transformer):
"""Transforms a query string into an intermediate representation."""
"""Transform a query string into an intermediate representation."""
def __init__(
self,
@ -159,8 +159,7 @@ def get_parser(
allowed_operators: Optional[Sequence[Operator]] = None,
allowed_attributes: Optional[Sequence[str]] = None,
) -> Lark:
"""
Returns a parser for the query language.
"""Return a parser for the query language.
Args:
allowed_comparators: Optional[Sequence[Comparator]]

@ -9,7 +9,7 @@ from langchain.evaluation.schema import StringEvaluator
class JsonValidityEvaluator(StringEvaluator):
"""Evaluates whether the prediction is valid JSON.
"""Evaluate whether the prediction is valid JSON.
This evaluator checks if the prediction is a valid JSON string. It does not
require any input or reference.
@ -77,7 +77,7 @@ class JsonValidityEvaluator(StringEvaluator):
class JsonEqualityEvaluator(StringEvaluator):
"""Evaluates whether the prediction is equal to the reference after
"""Evaluate whether the prediction is equal to the reference after
parsing both as JSON.
This evaluator checks if the prediction, after parsing as JSON, is equal

@ -37,7 +37,7 @@ def push(
new_repo_description: str = "",
) -> str:
"""
Pushes an object to the hub and returns the URL it can be viewed at in a browser.
Push an object to the hub and returns the URL it can be viewed at in a browser.
:param repo_full_name: The full name of the repo to push to in the format of
`owner/repo`.
@ -71,7 +71,7 @@ def pull(
api_key: Optional[str] = None,
) -> Any:
"""
Pulls an object from the hub and returns it as a LangChain object.
Pull an object from the hub and returns it as a LangChain object.
:param owner_repo_commit: The full name of the repo to pull from in the format of
`owner/repo:commit_hash`.

@ -4,7 +4,7 @@ from langchain_core.memory import BaseMemory
class ReadOnlySharedMemory(BaseMemory):
"""A memory wrapper that is read-only and cannot be changed."""
"""Memory wrapper that is read-only and cannot be changed."""
memory: BaseMemory

@ -13,7 +13,7 @@ T = TypeVar("T")
class OutputFixingParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors."""
"""Wrap a parser and try to fix parsing errors."""
@classmethod
def is_lc_serializable(cls) -> bool:

@ -34,7 +34,7 @@ T = TypeVar("T")
class RetryOutputParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors.
"""Wrap a parser and try to fix parsing errors.
Does this by passing the original prompt and the completion to another
LLM, and telling it the completion did not satisfy criteria in the prompt.
@ -138,7 +138,7 @@ class RetryOutputParser(BaseOutputParser[T]):
class RetryWithErrorOutputParser(BaseOutputParser[T]):
"""Wraps a parser and tries to fix parsing errors.
"""Wrap a parser and try to fix parsing errors.
Does this by passing the original prompt, the completion, AND the error
that was raised to another language model and telling it that the completion

@ -15,7 +15,7 @@ line_template = '\t"{name}": {type} // {description}'
class ResponseSchema(BaseModel):
"""A schema for a response from a structured output parser."""
"""Schema for a response from a structured output parser."""
name: str
"""The name of the schema."""

@ -38,6 +38,15 @@ H = TypeVar("H", bound=Hashable)
def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
"""Yield unique elements of an iterable based on a key function.
Args:
iterable: The iterable to filter.
key: A function that returns a hashable key for each element.
Yields:
Unique elements of the iterable based on the key function.
"""
seen = set()
for e in iterable:
if (k := key(e)) not in seen:

@ -13,6 +13,8 @@ from langchain_core.structured_query import (
class TencentVectorDBTranslator(Visitor):
"""Translate StructuredQuery to Tencent VectorDB query."""
COMPARATOR_MAP = {
Comparator.EQ: "=",
Comparator.NE: "!=",
@ -32,9 +34,22 @@ class TencentVectorDBTranslator(Visitor):
]
def __init__(self, meta_keys: Optional[Sequence[str]] = None):
"""Initialize the translator.
Args:
meta_keys: List of meta keys to be used in the query. Default: [].
"""
self.meta_keys = meta_keys or []
def visit_operation(self, operation: Operation) -> str:
"""Visit an operation node and return the translated query.
Args:
operation: Operation node to be visited.
Returns:
Translated query.
"""
if operation.operator in (Operator.AND, Operator.OR):
ret = f" {operation.operator.value} ".join(
[arg.accept(self) for arg in operation.arguments]
@ -46,6 +61,14 @@ class TencentVectorDBTranslator(Visitor):
return f"not ({operation.arguments[0].accept(self)})"
def visit_comparison(self, comparison: Comparison) -> str:
"""Visit a comparison node and return the translated query.
Args:
comparison: Comparison node to be visited.
Returns:
Translated query.
"""
if self.meta_keys and comparison.attribute not in self.meta_keys:
raise ValueError(
f"Expr Filtering found Unsupported attribute: {comparison.attribute}"
@ -78,6 +101,14 @@ class TencentVectorDBTranslator(Visitor):
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
"""Visit a structured query node and return the translated query.
Args:
structured_query: StructuredQuery node to be visited.
Returns:
Translated query and query kwargs.
"""
if structured_query.filter is None:
kwargs = {}
else:

@ -281,6 +281,12 @@ def _get_prompt(inputs: Dict[str, Any]) -> str:
class ChatModelInput(TypedDict):
"""Input for a chat model.
Parameters:
messages: List of chat messages.
"""
messages: List[BaseMessage]

Loading…
Cancel
Save