docstrings cleanup (#11640)

Added missing docstrings. Some reformatting.
pull/11646/head
Leonid Ganeline 1 year ago committed by GitHub
parent 78b4c7d5a0
commit db67ccb0bb

@ -21,6 +21,7 @@ INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def trim_query(query: str) -> str:
"""Trim the query to only include Cypher keywords."""
keywords = (
"CALL",
"CREATE",

@ -42,6 +42,7 @@ def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a message to a dictionary that can be passed to the API."""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
@ -105,11 +106,12 @@ class QianfanChatEndpoint(BaseChatModel):
"""
model: str = "ERNIE-Bot-turbo"
"""Model name.
"""Model name.
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
preset models are mapping to an endpoint.
-`model` will be ignored if `endpoint` is set
+`model` will be ignored if `endpoint` is set.
Default is ERNIE-Bot-turbo.
"""
endpoint: Optional[str] = None

@ -6,7 +6,17 @@ from langchain.schema import Document
class MsWordParser(BaseBlobParser):
"""Parse the Microsoft Word documents from a blob."""
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Parse a Microsoft Word document into the Document iterator.
Args:
blob: The blob to parse.
Returns: An iterator of Documents.
"""
try:
from unstructured.partition.doc import partition_doc
from unstructured.partition.docx import partition_docx

@ -6,7 +6,8 @@ from langchain.schema.embeddings import Embeddings
class XinferenceEmbeddings(Embeddings):
"""Wrapper around xinference embedding models.
"""Xinference embedding models.
To use, you should have the xinference library installed:
.. code-block:: bash

@ -1,4 +1,4 @@
"""Push and pull to the LangChain Hub."""
"""Interface with the LangChain Hub."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional

@ -46,6 +46,15 @@ def _custom_parser(multiline_string: str) -> str:
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
"""Parse a JSON string that may be missing closing braces.
Args:
s: The JSON string to parse.
strict: Whether to use strict parsing. Defaults to False.
Returns:
The parsed JSON object as a Python dictionary.
"""
# Attempt to parse the string as-is.
try:
return json.loads(s, strict=strict)
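A quick sketch of the recovery behavior the docstring describes; the expected output is an assumption based on the stated purpose (missing closers get supplied):

    truncated = '{"name": "Ada", "scores": [1, 2'
    parse_partial_json(truncated)
    # Expected: {'name': 'Ada', 'scores': [1, 2]} -- the missing ] and }
    # should be appended by the parser before re-parsing.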

@ -22,7 +22,17 @@ UNARY_OPERATORS = [Operator.NOT]
def process_value(value: Union[int, float, str]) -> str:
# required for comparators involving strings
"""Convert a value to a string and add double quotes if it is a string.
It is required for comparators involving strings.
Args:
value: The value to convert.
Returns:
The converted value as a string.
"""
if isinstance(value, str):
# If the value is already a string, add double quotes
return f'"{value}"'

@ -11,6 +11,7 @@ from langchain.chains.query_constructor.ir import (
def process_value(value: Union[int, float, str]) -> str:
"""Convert a value to a string and add single quotes if it is a string."""
if isinstance(value, str):
return f"'{value}'"
else:

@ -30,11 +30,20 @@ Output = TypeVar("Output")
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
"""Run a coroutine with a semaphore.
Args:
semaphore: The semaphore to use.
coro: The coroutine to run.
Returns:
The result of the coroutine.
"""
async with semaphore:
return await coro
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
"""Gather coroutines with a limit on the number of concurrent coroutines."""
if n is None:
return await asyncio.gather(*coros)
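A minimal usage sketch: at most two of the five coroutines below hold the semaphore at once (helper names and behavior follow the code shown above):

    import asyncio

    async def work(i: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for real async work
        return i

    async def main() -> None:
        results = await gather_with_concurrency(2, *(work(i) for i in range(5)))
        print(results)  # [0, 1, 2, 3, 4]

    asyncio.run(main())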
@ -44,6 +53,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
"""Check if a callable accepts a run_manager argument."""
try:
return signature(callable).parameters.get("run_manager") is not None
except ValueError:
@ -51,6 +61,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
def accepts_config(callable: Callable[..., Any]) -> bool:
"""Check if a callable accepts a config argument."""
try:
return signature(callable).parameters.get("config") is not None
except ValueError:
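Both helpers only inspect the signature, so a plain function works as a quick check:

    def handler(inputs: dict, *, config: dict) -> dict:
        return inputs

    accepts_config(handler)       # -> True  (has a `config` parameter)
    accepts_run_manager(handler)  # -> False (no `run_manager` parameter)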
@ -58,6 +69,8 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
class IsLocalDict(ast.NodeVisitor):
"""Check if a name is a local dict."""
def __init__(self, name: str, keys: Set[str]) -> None:
self.name = name
self.keys = keys
@ -88,6 +101,8 @@ class IsLocalDict(ast.NodeVisitor):
class IsFunctionArgDict(ast.NodeVisitor):
"""Check if the first argument of a function is a dict."""
def __init__(self) -> None:
self.keys: Set[str] = set()
@ -105,17 +120,22 @@ class IsFunctionArgDict(ast.NodeVisitor):
class GetLambdaSource(ast.NodeVisitor):
"""Get the source code of a lambda function."""
def __init__(self) -> None:
"""Initialize the visitor."""
self.source: Optional[str] = None
self.count = 0
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function."""
self.count += 1
if hasattr(ast, "unparse"):
self.source = ast.unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
"""Get the keys of the first argument of a function if it is a dict."""
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
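As an illustration of the AST walk (the output shape is assumed from the visitors above; keys are collected into a set, so order is not guaranteed, and the source must be available, e.g. the function is defined in a file):

    extract_keys = lambda x: x["name"] + x["age"]
    get_function_first_arg_dict_keys(extract_keys)
    # -> ['name', 'age'] (in some order)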
@ -190,6 +210,8 @@ _T_contra = TypeVar("_T_contra", contravariant=True)
class SupportsAdd(Protocol[_T_contra, _T_co]):
"""Protocol for objects that support addition."""
def __add__(self, __x: _T_contra) -> _T_co:
...
@ -198,6 +220,7 @@ Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
def add(addables: Iterable[Addable]) -> Optional[Addable]:
"""Add a sequence of addable objects together."""
final = None
for chunk in addables:
if final is None:
@ -208,6 +231,7 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
"""Asynchronously add a sequence of addable objects together."""
final = None
async for chunk in addables:
if final is None:
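Both helpers fold a stream of addable chunks with `+`; a quick sketch of the sync version:

    add(["Hello", ", ", "world"])  # -> 'Hello, world'
    add([])                        # -> None (nothing to fold)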
@ -218,6 +242,8 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
class ConfigurableField(NamedTuple):
"""A field that can be configured by the user."""
id: str
name: Optional[str] = None
@ -226,6 +252,8 @@ class ConfigurableField(NamedTuple):
class ConfigurableFieldSingleOption(NamedTuple):
"""A field that can be configured by the user with a default value."""
id: str
options: Mapping[str, Any]
default: str
@ -235,6 +263,8 @@ class ConfigurableFieldSingleOption(NamedTuple):
class ConfigurableFieldMultiOption(NamedTuple):
"""A field that can be configured by the user with multiple default values."""
id: str
options: Mapping[str, Any]
default: Sequence[str]
@ -249,6 +279,8 @@ AnyConfigurableField = Union[
class ConfigurableFieldSpec(NamedTuple):
"""A field that can be configured by the user. It is a specification of a field."""
id: str
name: Optional[str]
description: Optional[str]
@ -260,6 +292,7 @@ class ConfigurableFieldSpec(NamedTuple):
def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],
) -> Sequence[ConfigurableFieldSpec]:
"""Get the unique config specs from a sequence of config specs."""
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
unique: List[ConfigurableFieldSpec] = []
for id, dupes in grouped:

@ -642,10 +642,16 @@ class HTMLHeaderTextSplitter:
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
"""Tokenizer data class."""
chunk_overlap: int
"""Overlap in tokens between chunks"""
tokens_per_chunk: int
"""Maximum number of tokens per chunk"""
decode: Callable[[list[int]], str]
""" Function to decode a list of token ids to a string"""
encode: Callable[[str], List[int]]
""" Function to encode a string to a list of token ids"""
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
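A toy, self-contained usage sketch of the data class: each character counts as one "token" (a real caller would plug in a model tokenizer such as a tiktoken encoding; the expected output assumes a stride of tokens_per_chunk - chunk_overlap):

    toy = Tokenizer(
        chunk_overlap=2,
        tokens_per_chunk=8,
        decode=lambda ids: "".join(chr(i) for i in ids),
        encode=lambda text: [ord(c) for c in text],
    )
    split_text_on_tokens(text="abcdefghijklmnop", tokenizer=toy)
    # Expected: ['abcdefgh', 'ghijklmn', 'mnop']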

@ -9,11 +9,15 @@ from langchain.tools.ainetwork.base import AINBaseTool
class AppOperationType(str, Enum):
"""Type of app operation as enumerator."""
SET_ADMIN = "SET_ADMIN"
GET_CONFIG = "GET_CONFIG"
class AppSchema(BaseModel):
"""Schema for app operations."""
type: AppOperationType = Field(...)
appName: str = Field(..., description="Name of the application on the blockchain")
address: Optional[Union[str, List[str]]] = Field(
@ -26,6 +30,8 @@ class AppSchema(BaseModel):
class AINAppOps(AINBaseTool):
"""Tool for app operations."""
name: str = "AINappOps"
description: str = """
Create an app in the AINetwork Blockchain database by creating the /apps/<appName> path.

@ -1,4 +1,3 @@
"""Base class for AINetwork tools."""
from __future__ import annotations
import asyncio
@ -16,6 +15,8 @@ if TYPE_CHECKING:
class OperationType(str, Enum):
"""Type of operation as enumerator."""
SET = "SET"
GET = "GET"

@ -8,6 +8,8 @@ from langchain.tools.ainetwork.base import AINBaseTool, OperationType
class RuleSchema(BaseModel):
"""Schema for owner operations."""
type: OperationType = Field(...)
path: str = Field(..., description="Blockchain reference path")
address: Optional[Union[str, List[str]]] = Field(
@ -28,6 +30,8 @@ class RuleSchema(BaseModel):
class AINOwnerOps(AINBaseTool):
"""Tool for owner operations."""
name: str = "AINownerOps"
description: str = """
Rules for `owner` in AINetwork Blockchain database.

@ -8,12 +8,16 @@ from langchain.tools.ainetwork.base import AINBaseTool, OperationType
class RuleSchema(BaseModel):
"""Schema for owner operations."""
type: OperationType = Field(...)
path: str = Field(..., description="Path on the blockchain where the rule applies")
eval: Optional[str] = Field(None, description="eval string to determine permission")
class AINRuleOps(AINBaseTool):
"""Tool for owner operations."""
name: str = "AINruleOps"
description: str = """
Covers the write `rule` for the AINetwork Blockchain database. The SET type specifies write permissions using the `eval` variable as a JavaScript eval string.

@ -7,11 +7,15 @@ from langchain.tools.ainetwork.base import AINBaseTool
class TransferSchema(BaseModel):
"""Schema for transfer operations."""
address: str = Field(..., description="Address to transfer AIN to")
amount: int = Field(..., description="Amount of AIN to transfer")
class AINTransfer(AINBaseTool):
"""Tool for transfer operations."""
name: str = "AINtransfer"
description: str = "Transfers AIN to a specified address"
args_schema: Type[TransferSchema] = TransferSchema

@ -8,6 +8,8 @@ from langchain.tools.ainetwork.base import AINBaseTool, OperationType
class ValueSchema(BaseModel):
"""Schema for value operations."""
type: OperationType = Field(...)
path: str = Field(..., description="Blockchain reference path")
value: Optional[Union[int, str, float, dict]] = Field(
@ -16,6 +18,8 @@ class ValueSchema(BaseModel):
class AINValueOps(AINBaseTool):
"""Tool for value operations."""
name: str = "AINvalueOps"
description: str = """
Covers the read and write value for the AINetwork Blockchain database.

@ -31,6 +31,15 @@ DEFAULT_LINK_REGEX = (
def find_all_links(
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
) -> List[str]:
"""Extract all links from a raw html string.
Args:
raw_html: original html.
pattern: Regex to use for extracting links from raw html.
Returns:
List[str]: all links
"""
pattern = pattern or DEFAULT_LINK_REGEX
return list(set(re.findall(pattern, raw_html)))
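An explicit-pattern sketch (so the example does not depend on the truncated DEFAULT_LINK_REGEX; with one capture group, re.findall returns just the captured hrefs, de-duplicated via the set):

    html = '<a href="https://example.com">x</a> <a href="/about">y</a>'
    find_all_links(html, pattern=r'href="([^"]+)"')
    # -> ['https://example.com', '/about'] (order may vary; duplicates removed)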

@ -21,6 +21,7 @@ def convert_pydantic_to_openai_function(
name: Optional[str] = None,
description: Optional[str] = None
) -> FunctionDescription:
"""Converts a Pydantic model to a function description for the OpenAI API."""
schema = dereference_refs(model.schema())
schema.pop("definitions", None)
return {
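A minimal usage sketch (assuming the pydantic-v1-style models this version of langchain uses):

    from pydantic import BaseModel, Field

    class Search(BaseModel):
        """Search the web for a query."""
        query: str = Field(..., description="The search query")

    convert_pydantic_to_openai_function(Search)
    # -> {"name": "Search",
    #     "description": "Search the web for a query.",
    #     "parameters": <the JSON schema of Search>}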

@ -16,7 +16,10 @@ from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
class LLMRails(VectorStore):
"""Implementation of Vector Store using LLMRails (https://llmrails.com/).
"""Implementation of Vector Store using LLMRails.
See https://llmrails.com/
Example:
.. code-block:: python
@ -224,6 +227,8 @@ class LLMRails(VectorStore):
class LLMRailsRetriever(VectorStoreRetriever):
"""Retriever for LLMRails."""
vectorstore: LLMRails
search_kwargs: dict = Field(default_factory=lambda: {"k": 5})
"""Search params.

@ -61,6 +61,7 @@ def _get_search_index_query(search_type: SearchType) -> str:
def check_if_not_null(props: List[str], values: List[Any]) -> None:
"""Check if the values are not None or empty string"""
for prop, value in zip(props, values):
if not value:
raise ValueError(f"Parameter `{prop}` must not be None or empty string")

@ -80,7 +80,7 @@ def check_index_exists(client: RedisType, index_name: str) -> bool:
class Redis(VectorStore):
"""Wrapper around Redis vector database.
"""Redis vector database.
To use, you should have the ``redis`` python package installed
and have a running Redis Enterprise or Redis-Stack server

@ -10,6 +10,8 @@ from langchain.utilities.redis import TokenEscaper
class RedisFilterOperator(Enum):
"""RedisFilterOperator enumerator is used to create RedisFilterExpressions."""
EQ = 1
NE = 2
LT = 3
@ -23,6 +25,8 @@ class RedisFilterOperator(Enum):
class RedisFilter:
"""Collection of RedisFilterFields."""
@staticmethod
def text(field: str) -> "RedisText":
return RedisText(field)
@ -37,6 +41,8 @@ class RedisFilter:
class RedisFilterField:
"""Base class for RedisFilterFields."""
escaper: "TokenEscaper" = TokenEscaper()
OPERATORS: Dict[RedisFilterOperator, str] = {}
@ -72,6 +78,8 @@ class RedisFilterField:
def check_operator_misuse(func: Callable) -> Callable:
"""Decorator to check for misuse of equality operators."""
@wraps(func)
def wrapper(instance: Any, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
# Extracting 'other' from positional arguments or keyword arguments
@ -93,7 +101,7 @@ def check_operator_misuse(func: Callable) -> Callable:
class RedisTag(RedisFilterField):
"""A RedisTag is a RedisFilterField representing a tag in a Redis index."""
"""A RedisFilterField representing a tag in a Redis index."""
OPERATORS: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "==",
@ -293,7 +301,7 @@ class RedisNum(RedisFilterField):
class RedisText(RedisFilterField):
"""A RedisText is a RedisFilterField representing a text field in a Redis index."""
"""A RedisFilterField representing a text field in a Redis index."""
OPERATORS = {
RedisFilterOperator.EQ: "==",
@ -361,7 +369,7 @@ class RedisText(RedisFilterField):
class RedisFilterExpression:
"""A RedisFilterExpression is a logical expression of RedisFilterFields.
"""A logical expression of RedisFilterFields.
RedisFilterExpressions can be combined using the & and | operators to create
complex logical expressions that evaluate to the Redis Query language.

@ -22,16 +22,22 @@ if TYPE_CHECKING:
class RedisDistanceMetric(str, Enum):
"""Distance metrics for Redis vector fields."""
l2 = "L2"
cosine = "COSINE"
ip = "IP"
class RedisField(BaseModel):
"""Base class for Redis fields."""
name: str = Field(...)
class TextFieldSchema(RedisField):
"""Schema for text fields in Redis."""
weight: float = 1
no_stem: bool = False
phonetic_matcher: Optional[str] = None
@ -53,6 +59,8 @@ class TextFieldSchema(RedisField):
class TagFieldSchema(RedisField):
"""Schema for tag fields in Redis."""
separator: str = ","
case_sensitive: bool = False
no_index: bool = False
@ -71,6 +79,8 @@ class TagFieldSchema(RedisField):
class NumericFieldSchema(RedisField):
"""Schema for numeric fields in Redis."""
no_index: bool = False
sortable: Optional[bool] = False
@ -81,6 +91,8 @@ class NumericFieldSchema(RedisField):
class RedisVectorField(RedisField):
"""Base class for Redis vector fields."""
dims: int = Field(...)
algorithm: object = Field(...)
datatype: str = Field(default="FLOAT32")
@ -101,6 +113,8 @@ class RedisVectorField(RedisField):
class FlatVectorField(RedisVectorField):
"""Schema for flat vector fields in Redis."""
algorithm: Literal["FLAT"] = "FLAT"
block_size: int = Field(default=1000)
@ -121,6 +135,8 @@ class FlatVectorField(RedisVectorField):
class HNSWVectorField(RedisVectorField):
"""Schema for HNSW vector fields in Redis."""
algorithm: Literal["HNSW"] = "HNSW"
m: int = Field(default=16)
ef_construction: int = Field(default=200)
@ -147,6 +163,8 @@ class HNSWVectorField(RedisVectorField):
class RedisModel(BaseModel):
"""Schema for Redis index."""
# always have a content field for text
text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
tag: Optional[List[TagFieldSchema]] = None
@ -268,8 +286,11 @@ class RedisModel(BaseModel):
def read_schema(
index_schema: Optional[Union[Dict[str, str], str, os.PathLike]]
) -> Dict[str, Any]:
-# check if its a dict and return RedisModel otherwise, check if it's a path and
-# read in the file assuming it's a yaml file and return a RedisModel
+"""Reads in the index schema from a dict or yaml file.
+Check if it is a dict and return a RedisModel; otherwise, check if it is a
+path, read in the file (assuming it is a yaml file), and return a RedisModel.
+"""
if isinstance(index_schema, dict):
return index_schema
elif isinstance(index_schema, Path):
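A usage sketch (field names follow the schema classes above; the string-path branch is assumed from the docstring, since the hunk is truncated):

    schema_dict = {
        "tag": [{"name": "genre"}],
        "numeric": [{"name": "year"}],
    }
    read_schema(schema_dict)          # dicts are returned as-is
    read_schema("redis_schema.yaml")  # assumed: parsed as yaml into a dict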
