docstrings update (#12093)

Added missing docstrings. Added missing Args:, Returns:, and Raises: sections.
pull/12215/head
Leonid Ganeline 12 months ago committed by GitHub
parent ba20c14e28
commit 11f13aed53

@ -31,7 +31,7 @@ from langchain.schema.messages import (
async def aenumerate(
iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
"""Async version of enumerate."""
"""Async version of enumerate function."""
i = start
async for x in iterable:
yield i, x
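A minimal, self-contained usage sketch of this helper (assuming `aenumerate` as defined above is in scope):

    import asyncio
    from typing import AsyncIterator

    async def letters() -> AsyncIterator[str]:
        for ch in "abc":
            yield ch

    async def main() -> None:
        # Yields (index, item) pairs, like enumerate() for async iterators.
        async for i, x in aenumerate(letters(), start=1):
            print(i, x)  # 1 a, 2 b, 3 c

    asyncio.run(main())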
@ -39,6 +39,14 @@ async def aenumerate(
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Args:
_dict: The dictionary.
Returns:
The LangChain message.
"""
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
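A hedged usage sketch; the import path is assumed to be `langchain.adapters.openai` and may differ by version:

    from langchain.adapters.openai import convert_dict_to_message

    msg = convert_dict_to_message({"role": "user", "content": "Hello!"})
    print(type(msg).__name__)  # HumanMessage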
@ -60,6 +68,14 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
@ -122,6 +138,8 @@ def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str
class ChatCompletion:
"""Chat completion."""
@overload
@staticmethod
def create(
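A hedged sketch of the OpenAI-style adapter entry point this class provides; the `provider` parameter and exact call shape are assumptions based on this version of the adapter:

    from langchain.adapters.openai import ChatCompletion

    result = ChatCompletion.create(
        messages=[{"role": "user", "content": "Hi"}],
        provider="ChatOpenAI",
    )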
@ -217,7 +235,14 @@ def _has_assistant_message(session: ChatSession) -> bool:
def convert_messages_for_finetuning(
sessions: Iterable[ChatSession],
) -> List[List[dict]]:
"""Convert messages to a list of lists of dictionaries for fine-tuning."""
"""Convert messages to a list of lists of dictionaries for fine-tuning.
Args:
sessions: The chat sessions.
Returns:
The list of lists of dictionaries.
"""
return [
[convert_message_to_dict(s) for s in session["messages"]]
for session in sessions
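A minimal sketch, assuming `ChatSession` is a TypedDict with a "messages" key:

    from langchain.schema.messages import AIMessage, HumanMessage

    sessions = [{"messages": [HumanMessage(content="Hi"), AIMessage(content="Hello!")]}]
    records = convert_messages_for_finetuning(sessions)
    # records == [[{'role': 'user', ...}, {'role': 'assistant', ...}]]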

@ -6,6 +6,14 @@ from langchain.schema.agent import AgentAction
def format_xml(
intermediate_steps: List[Tuple[AgentAction, str]],
) -> str:
"""Format the intermediate steps as XML.
Args:
intermediate_steps: The intermediate steps.
Returns:
The intermediate steps as XML.
"""
log = ""
for action, observation in intermediate_steps:
log += (
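A hedged sketch of the call; the exact XML tags come from the implementation truncated above:

    from langchain.schema.agent import AgentAction

    steps = [(AgentAction(tool="search", tool_input="weather", log=""), "Sunny")]
    print(format_xml(steps))
    # e.g. <tool>search</tool><tool_input>weather</tool_input><observation>Sunny</observation>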

@ -108,6 +108,17 @@ def fix_filter_directive(
allowed_operators: Optional[Sequence[Operator]] = None,
allowed_attributes: Optional[Sequence[str]] = None,
) -> Optional[FilterDirective]:
"""Fix invalid filter directive.
Args:
filter: Filter directive to fix.
allowed_comparators: allowed comparators. Defaults to all comparators.
allowed_operators: allowed operators. Defaults to all operators.
allowed_attributes: allowed attributes. Defaults to all attributes.
Returns:
Fixed filter directive.
"""
if (
not (allowed_comparators or allowed_operators or allowed_attributes)
) or not filter:
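A hedged sketch using the query-constructor IR types (assumed importable from `langchain.chains.query_constructor.ir`):

    from langchain.chains.query_constructor.ir import Comparator, Comparison

    directive = Comparison(comparator=Comparator.GT, attribute="year", value=2000)
    # With only EQ allowed, the GT comparison is invalid and gets dropped.
    print(fix_filter_directive(directive, allowed_comparators=[Comparator.EQ]))  # None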
@ -154,6 +165,14 @@ def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
"""Construct examples from input-output pairs.
Args:
input_output_pairs: Sequence of input-output pairs.
Returns:
List of examples.
"""
examples = []
for i, (_input, output) in enumerate(input_output_pairs):
structured_request = (
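A minimal sketch; the exact keys of each example dict follow the implementation above:

    pairs = [
        ("Songs by Taylor Swift", {"query": "teenager love", "filter": 'eq("artist", "Taylor Swift")'}),
    ]
    examples = construct_examples(pairs)
    # each example pairs the user query with its JSON-formatted structured request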
@ -192,6 +211,9 @@ def get_query_constructor_prompt(
schema_prompt: Prompt for describing query schema. Should have string input
    variables allowed_comparators and allowed_operators.
**kwargs: Additional named params to pass to FewShotPromptTemplate init.

Returns:
    A prompt template that can be used to construct queries.
"""
default_schema_prompt = (
SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
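A hedged usage sketch; `AttributeInfo` is assumed importable from `langchain.chains.query_constructor.schema`:

    from langchain.chains.query_constructor.schema import AttributeInfo

    prompt = get_query_constructor_prompt(
        document_contents="Song lyrics",
        attribute_info=[AttributeInfo(name="artist", type="string", description="The artist name")],
    )
    print(prompt.format(query="songs by Taylor Swift"))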

@ -22,6 +22,17 @@ from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatRes
def get_role(message: BaseMessage) -> str:
"""Get the role of the message.
Args:
message: The message.
Returns:
The role of the message.
Raises:
ValueError: If the message is of an unknown type.
"""
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
return "User"
elif isinstance(message, AIMessage):
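A short sketch; the mapping of AI messages to the Cohere "Chatbot" role is an assumption based on this module:

    from langchain.schema.messages import AIMessage, HumanMessage

    print(get_role(HumanMessage(content="hi")))  # User
    print(get_role(AIMessage(content="hello")))  # Chatbot (assumed mapping)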
@ -38,6 +49,16 @@ def get_cohere_chat_request(
connectors: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Get the request for the Cohere chat API.
Args:
messages: The messages.
connectors: The connectors.
**kwargs: The keyword arguments.
Returns:
The request for the Cohere chat API.
"""
documents = (
None
if "source_documents" not in kwargs

@ -49,6 +49,17 @@ _PDF_FILTER_WITHOUT_LOSS = [
def extract_from_images_with_rapidocr(
images: Sequence[Union[Iterable[np.ndarray], bytes]]
) -> str:
"""Extract text from images with RapidOCR.
Args:
images: Images to extract text from.
Returns:
Text extracted from images.
Raises:
ImportError: If `rapidocr-onnxruntime` package is not installed.
"""
try:
from rapidocr_onnxruntime import RapidOCR
except ImportError:
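A hedged usage sketch; requires `pip install rapidocr-onnxruntime` and a real image file (the filename is hypothetical):

    with open("scan.png", "rb") as f:  # hypothetical input file
        text = extract_from_images_with_rapidocr([f.read()])
    print(text)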

@ -61,7 +61,8 @@ def create_llm_result(
class Anyscale(BaseOpenAI):
"""Wrapper around Anyscale Endpoint.
"""Anyscale large language models.
To use, you should have the environment variables ``ANYSCALE_API_BASE`` and
``ANYSCALE_API_KEY`` set with your Anyscale Endpoint, or pass them as named
parameters to the constructor.

@ -8,7 +8,8 @@ from langchain.llms.base import LLM
class NIBittensorLLM(LLM):
"""
"""NIBittensor LLMs
NIBittensorLLM is created by Neural Internet (https://neuralinternet.ai/),
powered by Bittensor, a decentralized network full of different AI models.

@ -18,6 +18,8 @@ from langchain.utils import get_from_dict_or_env
class TrainResult(TypedDict):
"""Train result."""
loss: float

@ -21,8 +21,7 @@ class Params(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
class JavelinAIGateway(LLM):
"""
Wrapper around completions LLMs in the Javelin AI Gateway.
"""Javelin AI Gateway LLMs.
To use, you should have the ``javelin_sdk`` python package installed.
For more information, see https://docs.getjavelin.io

@ -83,7 +83,16 @@ DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
def check_valid_template(
template: str, template_format: str, input_variables: List[str]
) -> None:
"""Check that template string is valid."""
"""Check that template string is valid.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
input_variables: The input variables.
Raises:
ValueError: If the template format is not supported.
"""
if template_format not in DEFAULT_FORMATTER_MAPPING:
valid_formats = list(DEFAULT_FORMATTER_MAPPING)
raise ValueError(
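A minimal sketch of the contract described above:

    check_valid_template("Hello {name}", "f-string", ["name"])  # passes silently
    check_valid_template("Hello {name}", "mustache", ["name"])  # raises ValueError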
@ -101,6 +110,18 @@ def check_valid_template(
def get_template_variables(template: str, template_format: str) -> List[str]:
"""Get the variables from the template.
Args:
template: The template string.
template_format: The template format. Should be one of "f-string" or "jinja2".
Returns:
The variables from the template.
Raises:
ValueError: If the template format is not supported.
"""
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
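A minimal sketch:

    print(get_template_variables("Hello {name}, you are {age}", "f-string"))
    # e.g. ['age', 'name'] (ordering may depend on the implementation)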

@ -8,6 +8,8 @@ from langchain.schema.retriever import BaseRetriever
class SearchDepth(Enum):
"""Search depth as enumerator."""
BASIC = "basic"
ADVANCED = "advanced"
@ -31,7 +33,7 @@ class TavilySearchAPIRetriever(BaseRetriever):
try:
from tavily import Client
except ImportError:
raise ValueError(
raise ImportError(
"Tavily python package not found. "
"Please install it with `pip install tavily-python`."
)

@ -1,4 +1,4 @@
"""LangChain Runnables and the LangChain Expression Language (LCEL).
"""LangChain **Runnable** and the **LangChain Expression Language (LCEL)**.
The LangChain Expression Language (LCEL) offers a declarative method to build
production-grade programs that harness the power of LLMs.
@ -6,10 +6,10 @@ production-grade programs that harness the power of LLMs.
Programs created using LCEL and LangChain Runnables inherently support
synchronous, asynchronous, batch, and streaming operations.
Support for async allows servers hosting LCEL based programs to scale better
Support for **async** allows servers hosting LCEL-based programs to scale better
for higher concurrent loads.
Streaming of intermediate outputs as they're being generated allows for
**Streaming** of intermediate outputs as they're being generated allows for
creating more responsive UX.
This module contains the schema and implementation of LangChain Runnable primitives.
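To ground the sync/batch/streaming claims above, a minimal sketch with a single Runnable (import path assumed for this version):

    from langchain.schema.runnable import RunnableLambda

    doubler = RunnableLambda(lambda x: x * 2)
    print(doubler.invoke(3))         # 6  (sync)
    print(doubler.batch([1, 2, 3]))  # [2, 4, 6]  (batch)
    for chunk in doubler.stream(4):  # streaming; one chunk for a plain function
        print(chunk)                 # 8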

@ -2627,6 +2627,14 @@ RunnableLike = Union[
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
"""Coerce a runnable-like object into a Runnable.
Args:
thing: A runnable-like object.
Returns:
A Runnable.
"""
if isinstance(thing, Runnable):
return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
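A minimal sketch:

    r = coerce_to_runnable(lambda x: x + 1)  # a plain callable
    print(r.invoke(1))  # 2 -- wrapped into a Runnable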

@ -35,6 +35,8 @@ if TYPE_CHECKING:
class EmptyDict(TypedDict, total=False):
"""Empty dict type."""
pass
@ -85,6 +87,15 @@ class RunnableConfig(TypedDict, total=False):
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
"""Ensure that a config is a dict with all keys present.
Args:
config (Optional[RunnableConfig], optional): The config to ensure.
Defaults to None.
Returns:
RunnableConfig: The ensured config.
"""
empty = RunnableConfig(
tags=[],
metadata={},
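A minimal sketch:

    cfg = ensure_config(None)
    print(cfg["tags"], cfg["metadata"])  # [] {}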
@ -101,9 +112,21 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
def get_config_list(
config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
"""
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""Get a list of configs from a single config or a list of configs.
It is useful for subclasses overriding batch() or abatch().
Args:
config (Optional[Union[RunnableConfig, List[RunnableConfig]]]):
The config or list of configs.
length (int): The length of the list.
Returns:
List[RunnableConfig]: The list of configs.
Raises:
ValueError: If the length of the list is not equal to the length of the inputs.
"""
if length < 0:
raise ValueError(f"length must be >= 0, but got {length}")
@ -129,9 +152,27 @@ def patch_config(
run_name: Optional[str] = None,
configurable: Optional[Dict[str, Any]] = None,
) -> RunnableConfig:
"""Patch a config with new values.
Args:
config (Optional[RunnableConfig]): The config to patch.
copy_locals (bool, optional): Whether to copy locals. Defaults to False.
callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
Defaults to None.
recursion_limit (Optional[int], optional): The recursion limit to set.
Defaults to None.
max_concurrency (Optional[int], optional): The max concurrency to set.
Defaults to None.
run_name (Optional[str], optional): The run name to set. Defaults to None.
configurable (Optional[Dict[str, Any]], optional): The configurable to set.
Defaults to None.
Returns:
RunnableConfig: The patched config.
"""
config = ensure_config(config)
if callbacks is not None:
# If we're replacing callbacks we need to unset run_name
# If we're replacing callbacks, we need to unset run_name
# As that should apply only to the same run as the original callbacks
config["callbacks"] = callbacks
if "run_name" in config:
@ -148,9 +189,17 @@ def patch_config(
def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
"""Merge multiple configs into one.
Args:
*configs (Optional[RunnableConfig]): The configs to merge.
Returns:
RunnableConfig: The merged config.
"""
base: RunnableConfig = {}
# Even though the keys aren't literals this is correct
# because both dicts are same type
# Even though the keys aren't literals, this is correct
# because both dicts are the same type
for config in (c for c in configs if c is not None):
for key in config:
if key == "metadata":
@ -184,7 +233,22 @@ def call_func_with_variable_args(
run_manager: Optional[CallbackManagerForChainRun] = None,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
"""Call function that may optionally accept a run_manager and/or config.
Args:
func (Union[Callable[[Input], Output],
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]):
The function to call.
input (Input): The input to the function.
run_manager (CallbackManagerForChainRun): The run manager to
pass to the function.
config (RunnableConfig): The config to pass to the function.
**kwargs (Any): The keyword arguments to pass to the function.
Returns:
Output: The output of the function.
"""
if accepts_config(func):
if run_manager is not None:
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
@ -210,7 +274,22 @@ async def acall_func_with_variable_args(
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
"""Call function that may optionally accept a run_manager and/or config.
Args:
func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input,
AsyncCallbackManagerForChainRun], Awaitable[Output]], Callable[[Input,
AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]):
The function to call.
input (Input): The input to the function.
run_manager (AsyncCallbackManagerForChainRun): The run manager
to pass to the function.
config (RunnableConfig): The config to pass to the function.
**kwargs (Any): The keyword arguments to pass to the function.
Returns:
Output: The output of the function.
"""
if accepts_config(func):
if run_manager is not None:
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
@ -222,6 +301,14 @@ async def acall_func_with_variable_args(
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
"""Get a callback manager for a config.
Args:
config (RunnableConfig): The config.
Returns:
CallbackManager: The callback manager.
"""
from langchain.callbacks.manager import CallbackManager
return CallbackManager.configure(
@ -234,6 +321,14 @@ def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
def get_async_callback_manager_for_config(
config: RunnableConfig,
) -> AsyncCallbackManager:
"""Get an async callback manager for a config.
Args:
config (RunnableConfig): The config.
Returns:
AsyncCallbackManager: The async callback manager.
"""
from langchain.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure(
@ -245,5 +340,13 @@ def get_async_callback_manager_for_config(
@contextmanager
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
"""Get an executor for a config.
Args:
config (RunnableConfig): The config.
Yields:
Generator[Executor, None, None]: The executor.
"""
with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
yield executor
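A minimal sketch of the context manager:

    with get_executor_for_config({"max_concurrency": 4}) as executor:
        futures = [executor.submit(pow, 2, n) for n in range(3)]
        print([f.result() for f in futures])  # [1, 2, 4]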

@ -7,6 +7,7 @@ def sanitize(
"""
Sanitize input string or dict of strings by replacing sensitive data with
placeholders.
It returns the sanitized input string or dict of strings and the secure
context as a dict following the format:
{
@ -29,6 +30,10 @@ def sanitize(
}
The `secure_context` needs to be passed to the `desanitize` function.

Raises:
    ValueError: If the input is not a string or dict of strings.
    ImportError: If the `opaqueprompts` Python package is not installed.
"""
try:
import opaqueprompts as op
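A hedged usage sketch; requires `pip install opaqueprompts`, and the exact shape of the return value follows the format described in the docstring above:

    result = sanitize("My email is jane@example.com")
    # result bundles the sanitized text plus a secure_context dict;
    # pass that secure_context to desanitize() to restore the original values.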
