community: docstrings (#23202)

Added missed docstrings. Format docstrings to the consistent format
(used in the API Reference)
This commit is contained in:
Leonid Ganeline 2024-06-20 08:08:13 -07:00 committed by GitHub
parent 6a1a0d977a
commit 51e75cf59d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 199 additions and 4 deletions

View File

@ -13,6 +13,14 @@ if TYPE_CHECKING:
def nanoseconds_from_2001_to_datetime(nanoseconds: int) -> datetime: def nanoseconds_from_2001_to_datetime(nanoseconds: int) -> datetime:
"""Convert nanoseconds since 2001 to a datetime object.
Args:
nanoseconds (int): Nanoseconds since January 1, 2001.
Returns:
datetime: Datetime object.
"""
# Convert nanoseconds to seconds (1 second = 1e9 nanoseconds) # Convert nanoseconds to seconds (1 second = 1e9 nanoseconds)
timestamp_in_seconds = nanoseconds / 1e9 timestamp_in_seconds = nanoseconds / 1e9

View File

@ -26,6 +26,14 @@ logger = logging.getLogger(__name__)
def condense_zep_memory_into_human_message(zep_memory: Memory) -> BaseMessage: def condense_zep_memory_into_human_message(zep_memory: Memory) -> BaseMessage:
"""Condense Zep memory into a human message.
Args:
zep_memory: The Zep memory object.
Returns:
BaseMessage: The human message.
"""
prompt = "" prompt = ""
if zep_memory.facts: if zep_memory.facts:
prompt = "\n".join(zep_memory.facts) prompt = "\n".join(zep_memory.facts)
@ -37,6 +45,16 @@ def condense_zep_memory_into_human_message(zep_memory: Memory) -> BaseMessage:
def get_zep_message_role_type(role: str) -> RoleType: def get_zep_message_role_type(role: str) -> RoleType:
"""Get the Zep role type from the role string.
Args:
role: The role string. One of "human", "ai", "system",
"function", "tool".
Returns:
RoleType: The Zep role type. One of "user", "assistant",
"system", "function", "tool".
"""
if role == "human": if role == "human":
return "user" return "user"
elif role == "ai": elif role == "ai":

View File

@ -1,4 +1,5 @@
"""Wrapper around Minimax chat models.""" """Wrapper around Minimax chat models."""
import json import json
import logging import logging
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
@ -32,6 +33,17 @@ logger = logging.getLogger(__name__)
@contextmanager @contextmanager
def connect_httpx_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator: def connect_httpx_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator:
"""Context manager for connecting to an SSE stream.
Args:
client: The httpx client.
method: The HTTP method.
url: The URL to connect to.
kwargs: Additional keyword arguments to pass to the client.
Yields:
An EventSource object.
"""
from httpx_sse import EventSource from httpx_sse import EventSource
with client.stream(method, url, **kwargs) as response: with client.stream(method, url, **kwargs) as response:
@ -42,6 +54,17 @@ def connect_httpx_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iter
async def aconnect_httpx_sse( async def aconnect_httpx_sse(
client: Any, method: str, url: str, **kwargs: Any client: Any, method: str, url: str, **kwargs: Any
) -> AsyncIterator: ) -> AsyncIterator:
"""Async context manager for connecting to an SSE stream.
Args:
client: The httpx client.
method: The HTTP method.
url: The URL to connect to.
kwargs: Additional keyword arguments to pass to the client.
Yields:
An EventSource object.
"""
from httpx_sse import EventSource from httpx_sse import EventSource
async with client.stream(method, url, **kwargs) as response: async with client.stream(method, url, **kwargs) as response:

View File

@ -42,6 +42,17 @@ ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
@contextmanager @contextmanager
def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator: def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator:
"""Context manager for connecting to an SSE stream.
Args:
client: The HTTP client.
method: The HTTP method.
url: The URL.
**kwargs: Additional keyword arguments.
Yields:
The event source.
"""
from httpx_sse import EventSource from httpx_sse import EventSource
with client.stream(method, url, **kwargs) as response: with client.stream(method, url, **kwargs) as response:
@ -52,6 +63,17 @@ def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator:
async def aconnect_sse( async def aconnect_sse(
client: Any, method: str, url: str, **kwargs: Any client: Any, method: str, url: str, **kwargs: Any
) -> AsyncIterator: ) -> AsyncIterator:
"""Async context manager for connecting to an SSE stream.
Args:
client: The HTTP client.
method: The HTTP method.
url: The URL.
**kwargs: Additional keyword arguments.
Yields:
The event source.
"""
from httpx_sse import EventSource from httpx_sse import EventSource
async with client.stream(method, url, **kwargs) as response: async with client.stream(method, url, **kwargs) as response:
@ -59,7 +81,9 @@ async def aconnect_sse(
def _get_jwt_token(api_key: str) -> str: def _get_jwt_token(api_key: str) -> str:
"""Gets JWT token for ZhipuAI API, see 'https://open.bigmodel.cn/dev/api#nosdk'. """Gets JWT token for ZhipuAI API.
See 'https://open.bigmodel.cn/dev/api#nosdk'.
Args: Args:
api_key: The API key for ZhipuAI API. api_key: The API key for ZhipuAI API.

View File

@ -12,6 +12,14 @@ JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
def is_local(url: str) -> bool: def is_local(url: str) -> bool:
"""Check if a URL is a local file.
Args:
url (str): The URL to check.
Returns:
bool: True if the URL is a local file, False otherwise.
"""
url_parsed = urlparse(url) url_parsed = urlparse(url)
if url_parsed.scheme in ("file", ""): # Possibly a local file if url_parsed.scheme in ("file", ""): # Possibly a local file
return exists(url_parsed.path) return exists(url_parsed.path)
@ -19,6 +27,14 @@ def is_local(url: str) -> bool:
def get_bytes_str(file_path: str) -> str: def get_bytes_str(file_path: str) -> str:
"""Get the bytes string of a file.
Args:
file_path (str): The path to the file.
Returns:
str: The bytes string of the file.
"""
with open(file_path, "rb") as image_file: with open(file_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8") return base64.b64encode(image_file.read()).decode("utf-8")

View File

@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
class Url: class Url:
"""URL class for parsing the URL."""
def __init__(self, host: str, path: str, schema: str) -> None: def __init__(self, host: str, path: str, schema: str) -> None:
self.host = host self.host = host
self.path = path self.path = path

View File

@ -52,6 +52,16 @@ include_docs_query = (
def clean_string_values(text: str) -> str: def clean_string_values(text: str) -> str:
"""Clean string values for schema.
Cleans the input text by replacing newline and carriage return characters.
Args:
text (str): The input text to clean.
Returns:
str: The cleaned text.
"""
return text.replace("\n", " ").replace("\r", " ") return text.replace("\n", " ").replace("\r", " ")
@ -63,6 +73,12 @@ def value_sanitize(d: Any) -> Any:
generating answers in a LLM context. These properties, if left in generating answers in a LLM context. These properties, if left in
results, can occupy significant context space and detract from results, can occupy significant context space and detract from
the LLM's performance by introducing unnecessary noise and cost. the LLM's performance by introducing unnecessary noise and cost.
Args:
d (Any): The input dictionary or list to sanitize.
Returns:
Any: The sanitized dictionary or list.
""" """
if isinstance(d, dict): if isinstance(d, dict):
new_dict = {} new_dict = {}
@ -382,7 +398,15 @@ class Neo4jGraph(GraphStore):
return self.structured_schema return self.structured_schema
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
"""Query Neo4j database.""" """Query Neo4j database.
Args:
query (str): The Cypher query to execute.
params (dict): The parameters to pass to the query.
Returns:
List[Dict[str, Any]]: The list of dictionaries containing the query results.
"""
from neo4j import Query from neo4j import Query
from neo4j.exceptions import CypherSyntaxError from neo4j.exceptions import CypherSyntaxError

View File

@ -35,6 +35,17 @@ DELETE_TABLE_CQL_TEMPLATE = """DELETE FROM {keyspace}.{table} WHERE row_id IN ?;
class CassandraByteStore(ByteStore): class CassandraByteStore(ByteStore):
"""A ByteStore implementation using Cassandra as the backend.
Parameters:
table: The name of the table to use.
session: A Cassandra session object. If not provided, it will be resolved
from the cassio config.
keyspace: The keyspace to use. If not provided, it will be resolved
from the cassio config.
setup_mode: The setup mode to use. Default is SYNC (SetupMode.SYNC).
"""
def __init__( def __init__(
self, self,
table: str, table: str,
@ -75,6 +86,7 @@ class CassandraByteStore(ByteStore):
self.session.execute(create_cql) self.session.execute(create_cql)
def ensure_db_setup(self) -> None: def ensure_db_setup(self) -> None:
"""Ensure that the DB setup is finished. If not, raise a ValueError."""
if self.db_setup_task: if self.db_setup_task:
try: try:
self.db_setup_task.result() self.db_setup_task.result()
@ -86,10 +98,17 @@ class CassandraByteStore(ByteStore):
) )
async def aensure_db_setup(self) -> None: async def aensure_db_setup(self) -> None:
"""Ensure that the DB setup is finished. If not, wait for it."""
if self.db_setup_task: if self.db_setup_task:
await self.db_setup_task await self.db_setup_task
def get_select_statement(self) -> PreparedStatement: def get_select_statement(self) -> PreparedStatement:
"""Get the prepared select statement for the table.
If not available, prepare it.
Returns:
PreparedStatement: The prepared statement.
"""
if not self.select_statement: if not self.select_statement:
self.select_statement = self.session.prepare( self.select_statement = self.session.prepare(
SELECT_TABLE_CQL_TEMPLATE.format( SELECT_TABLE_CQL_TEMPLATE.format(
@ -99,6 +118,12 @@ class CassandraByteStore(ByteStore):
return self.select_statement return self.select_statement
def get_insert_statement(self) -> PreparedStatement: def get_insert_statement(self) -> PreparedStatement:
"""Get the prepared insert statement for the table.
If not available, prepare it.
Returns:
PreparedStatement: The prepared statement.
"""
if not self.insert_statement: if not self.insert_statement:
self.insert_statement = self.session.prepare( self.insert_statement = self.session.prepare(
INSERT_TABLE_CQL_TEMPLATE.format( INSERT_TABLE_CQL_TEMPLATE.format(
@ -108,6 +133,13 @@ class CassandraByteStore(ByteStore):
return self.insert_statement return self.insert_statement
def get_delete_statement(self) -> PreparedStatement: def get_delete_statement(self) -> PreparedStatement:
"""Get the prepared delete statement for the table.
If not available, prepare it.
Returns:
PreparedStatement: The prepared statement.
"""
if not self.delete_statement: if not self.delete_statement:
self.delete_statement = self.session.prepare( self.delete_statement = self.session.prepare(
DELETE_TABLE_CQL_TEMPLATE.format( DELETE_TABLE_CQL_TEMPLATE.format(

View File

@ -36,6 +36,16 @@ async def wrapped_response_future(
async def aexecute_cql(session: Session, query: str, **kwargs: Any) -> Any: async def aexecute_cql(session: Session, query: str, **kwargs: Any) -> Any:
"""Execute a CQL query asynchronously.
Args:
session: The Cassandra session to use.
query: The CQL query to execute.
**kwargs: Additional keyword arguments to pass to the session execute method.
Returns:
The result of the query.
"""
return await wrapped_response_future(session.execute_async, query, **kwargs) return await wrapped_response_future(session.execute_async, query, **kwargs)

View File

@ -173,6 +173,16 @@ def create_index(
vector_store: OracleVS, vector_store: OracleVS,
params: Optional[dict[str, Any]] = None, params: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
"""Create an index on the vector store.
Args:
client: The OracleDB connection object.
vector_store: The vector store object.
params: Optional parameters for the index creation.
Raises:
ValueError: If an invalid parameter is provided.
"""
if params: if params:
if params["idx_type"] == "HNSW": if params["idx_type"] == "HNSW":
_create_hnsw_index( _create_hnsw_index(
@ -351,6 +361,15 @@ def _create_ivf_index(
@_handle_exceptions @_handle_exceptions
def drop_table_purge(client: Connection, table_name: str) -> None: def drop_table_purge(client: Connection, table_name: str) -> None:
"""Drop a table and purge it from the database.
Args:
client: The OracleDB connection object.
table_name: The name of the table to drop.
Raises:
RuntimeError: If an error occurs while dropping the table.
"""
if _table_exists(client, table_name): if _table_exists(client, table_name):
cursor = client.cursor() cursor = client.cursor()
with cursor: with cursor:
@ -364,6 +383,15 @@ def drop_table_purge(client: Connection, table_name: str) -> None:
@_handle_exceptions @_handle_exceptions
def drop_index_if_exists(client: Connection, index_name: str) -> None: def drop_index_if_exists(client: Connection, index_name: str) -> None:
"""Drop an index if it exists.
Args:
client: The OracleDB connection object.
index_name: The name of the index to drop.
Raises:
RuntimeError: If an error occurs while dropping the index.
"""
if _index_exists(client, index_name): if _index_exists(client, index_name):
drop_query = f"DROP INDEX {index_name}" drop_query = f"DROP INDEX {index_name}"
with client.cursor() as cursor: with client.cursor() as cursor:

View File

@ -748,6 +748,14 @@ class VectaraRetriever(VectorStoreRetriever):
class VectaraRAG(Runnable): class VectaraRAG(Runnable):
"""Vectara RAG runnable.
Parameters:
vectara: Vectara object
config: VectaraQueryConfig object
chat: bool, default False
"""
def __init__( def __init__(
self, vectara: Vectara, config: VectaraQueryConfig, chat: bool = False self, vectara: Vectara, config: VectaraQueryConfig, chat: bool = False
): ):
@ -762,10 +770,12 @@ class VectaraRAG(Runnable):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[dict]: ) -> Iterator[dict]:
"""get streaming output from Vectara RAG """Get streaming output from Vectara RAG.
Args: Args:
query: The input query input: The input query
config: RunnableConfig object
kwargs: Any additional arguments
Returns: Returns:
The output dictionary with question, answer and context The output dictionary with question, answer and context