diff --git a/libs/community/langchain_community/chat_loaders/imessage.py b/libs/community/langchain_community/chat_loaders/imessage.py index b8d4e610d6..e69f0a0425 100644 --- a/libs/community/langchain_community/chat_loaders/imessage.py +++ b/libs/community/langchain_community/chat_loaders/imessage.py @@ -13,6 +13,14 @@ if TYPE_CHECKING: def nanoseconds_from_2001_to_datetime(nanoseconds: int) -> datetime: + """Convert nanoseconds since 2001 to a datetime object. + + Args: + nanoseconds (int): Nanoseconds since January 1, 2001. + + Returns: + datetime: Datetime object. + """ # Convert nanoseconds to seconds (1 second = 1e9 nanoseconds) timestamp_in_seconds = nanoseconds / 1e9 diff --git a/libs/community/langchain_community/chat_message_histories/zep_cloud.py b/libs/community/langchain_community/chat_message_histories/zep_cloud.py index a97cfcfbcc..0fc36b737d 100644 --- a/libs/community/langchain_community/chat_message_histories/zep_cloud.py +++ b/libs/community/langchain_community/chat_message_histories/zep_cloud.py @@ -26,6 +26,14 @@ logger = logging.getLogger(__name__) def condense_zep_memory_into_human_message(zep_memory: Memory) -> BaseMessage: + """Condense Zep memory into a human message. + + Args: + zep_memory: The Zep memory object. + + Returns: + BaseMessage: The human message. + """ prompt = "" if zep_memory.facts: prompt = "\n".join(zep_memory.facts) @@ -37,6 +45,16 @@ def condense_zep_memory_into_human_message(zep_memory: Memory) -> BaseMessage: def get_zep_message_role_type(role: str) -> RoleType: + """Get the Zep role type from the role string. + + Args: + role: The role string. One of "human", "ai", "system", + "function", "tool". + + Returns: + RoleType: The Zep role type. One of "user", "assistant", + "system", "function", "tool". + """ if role == "human": return "user" elif role == "ai": diff --git a/libs/community/langchain_community/chat_models/minimax.py b/libs/community/langchain_community/chat_models/minimax.py index d79e3499a6..8761b04d39 100644 --- a/libs/community/langchain_community/chat_models/minimax.py +++ b/libs/community/langchain_community/chat_models/minimax.py @@ -1,4 +1,5 @@ """Wrapper around Minimax chat models.""" + import json import logging from contextlib import asynccontextmanager, contextmanager @@ -32,6 +33,17 @@ logger = logging.getLogger(__name__) @contextmanager def connect_httpx_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator: + """Context manager for connecting to an SSE stream. + + Args: + client: The httpx client. + method: The HTTP method. + url: The URL to connect to. + kwargs: Additional keyword arguments to pass to the client. + + Yields: + An EventSource object. + """ from httpx_sse import EventSource with client.stream(method, url, **kwargs) as response: @@ -42,6 +54,17 @@ def connect_httpx_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iter async def aconnect_httpx_sse( client: Any, method: str, url: str, **kwargs: Any ) -> AsyncIterator: + """Async context manager for connecting to an SSE stream. + + Args: + client: The httpx client. + method: The HTTP method. + url: The URL to connect to. + kwargs: Additional keyword arguments to pass to the client. + + Yields: + An EventSource object. 
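A hedged, self-contained sketch of the arithmetic the new imessage docstring describes. The hunk only shows the first lines of the helper's body, so the 2001-01-01 reference date is assumed from the docstring; the function name below is illustrative, the real helper lives in langchain_community.chat_loaders.imessage.

```python
from datetime import datetime, timedelta

# Sketch of the conversion described by the docstring above.
# Assumption: the epoch is January 1, 2001 (per the docstring).
def ns_since_2001_to_datetime(nanoseconds: int) -> datetime:
    seconds = nanoseconds / 1e9  # 1 second = 1e9 nanoseconds
    return datetime(2001, 1, 1) + timedelta(seconds=seconds)

print(ns_since_2001_to_datetime(0))              # 2001-01-01 00:00:00
print(ns_since_2001_to_datetime(86_400 * 10**9))  # 2001-01-02 00:00:00
```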
+ """ from httpx_sse import EventSource async with client.stream(method, url, **kwargs) as response: diff --git a/libs/community/langchain_community/chat_models/zhipuai.py b/libs/community/langchain_community/chat_models/zhipuai.py index e13851f0e7..062878ed2d 100644 --- a/libs/community/langchain_community/chat_models/zhipuai.py +++ b/libs/community/langchain_community/chat_models/zhipuai.py @@ -42,6 +42,17 @@ ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions" @contextmanager def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator: + """Context manager for connecting to an SSE stream. + + Args: + client: The HTTP client. + method: The HTTP method. + url: The URL. + **kwargs: Additional keyword arguments. + + Yields: + The event source. + """ from httpx_sse import EventSource with client.stream(method, url, **kwargs) as response: @@ -52,6 +63,17 @@ def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator: async def aconnect_sse( client: Any, method: str, url: str, **kwargs: Any ) -> AsyncIterator: + """Async context manager for connecting to an SSE stream. + + Args: + client: The HTTP client. + method: The HTTP method. + url: The URL. + **kwargs: Additional keyword arguments. + + Yields: + The event source. + """ from httpx_sse import EventSource async with client.stream(method, url, **kwargs) as response: @@ -59,7 +81,9 @@ async def aconnect_sse( def _get_jwt_token(api_key: str) -> str: - """Gets JWT token for ZhipuAI API, see 'https://open.bigmodel.cn/dev/api#nosdk'. + """Gets JWT token for ZhipuAI API. + + See 'https://open.bigmodel.cn/dev/api#nosdk'. Args: api_key: The API key for ZhipuAI API. diff --git a/libs/community/langchain_community/embeddings/jina.py b/libs/community/langchain_community/embeddings/jina.py index d62a36924f..73f75b24b8 100644 --- a/libs/community/langchain_community/embeddings/jina.py +++ b/libs/community/langchain_community/embeddings/jina.py @@ -12,6 +12,14 @@ JINA_API_URL: str = "https://api.jina.ai/v1/embeddings" def is_local(url: str) -> bool: + """Check if a URL is a local file. + + Args: + url (str): The URL to check. + + Returns: + bool: True if the URL is a local file, False otherwise. + """ url_parsed = urlparse(url) if url_parsed.scheme in ("file", ""): # Possibly a local file return exists(url_parsed.path) @@ -19,6 +27,14 @@ def is_local(url: str) -> bool: def get_bytes_str(file_path: str) -> str: + """Get the bytes string of a file. + + Args: + file_path (str): The path to the file. + + Returns: + str: The bytes string of the file. 
+ """ with open(file_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") diff --git a/libs/community/langchain_community/embeddings/sparkllm.py b/libs/community/langchain_community/embeddings/sparkllm.py index fe82f9b312..d8578810f8 100644 --- a/libs/community/langchain_community/embeddings/sparkllm.py +++ b/libs/community/langchain_community/embeddings/sparkllm.py @@ -30,6 +30,8 @@ logger = logging.getLogger(__name__) class Url: + """URL class for parsing the URL.""" + def __init__(self, host: str, path: str, schema: str) -> None: self.host = host self.path = path diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index a8b7bf6b0b..cd2791d646 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -52,6 +52,16 @@ include_docs_query = ( def clean_string_values(text: str) -> str: + """Clean string values for schema. + + Cleans the input text by replacing newline and carriage return characters. + + Args: + text (str): The input text to clean. + + Returns: + str: The cleaned text. + """ return text.replace("\n", " ").replace("\r", " ") @@ -63,6 +73,12 @@ def value_sanitize(d: Any) -> Any: generating answers in a LLM context. These properties, if left in results, can occupy significant context space and detract from the LLM's performance by introducing unnecessary noise and cost. + + Args: + d (Any): The input dictionary or list to sanitize. + + Returns: + Any: The sanitized dictionary or list. """ if isinstance(d, dict): new_dict = {} @@ -382,7 +398,15 @@ class Neo4jGraph(GraphStore): return self.structured_schema def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: - """Query Neo4j database.""" + """Query Neo4j database. + + Args: + query (str): The Cypher query to execute. + params (dict): The parameters to pass to the query. + + Returns: + List[Dict[str, Any]]: The list of dictionaries containing the query results. + """ from neo4j import Query from neo4j.exceptions import CypherSyntaxError diff --git a/libs/community/langchain_community/storage/cassandra.py b/libs/community/langchain_community/storage/cassandra.py index 280ce5b3a5..d2d97a3557 100644 --- a/libs/community/langchain_community/storage/cassandra.py +++ b/libs/community/langchain_community/storage/cassandra.py @@ -35,6 +35,17 @@ DELETE_TABLE_CQL_TEMPLATE = """DELETE FROM {keyspace}.{table} WHERE row_id IN ?; class CassandraByteStore(ByteStore): + """A ByteStore implementation using Cassandra as the backend. + + Parameters: + table: The name of the table to use. + session: A Cassandra session object. If not provided, it will be resolved + from the cassio config. + keyspace: The keyspace to use. If not provided, it will be resolved + from the cassio config. + setup_mode: The setup mode to use. Default is SYNC (SetupMode.SYNC). + """ + def __init__( self, table: str, @@ -75,6 +86,7 @@ class CassandraByteStore(ByteStore): self.session.execute(create_cql) def ensure_db_setup(self) -> None: + """Ensure that the DB setup is finished. If not, raise a ValueError.""" if self.db_setup_task: try: self.db_setup_task.result() @@ -86,10 +98,17 @@ class CassandraByteStore(ByteStore): ) async def aensure_db_setup(self) -> None: + """Ensure that the DB setup is finished. 
If not, wait for it.""" if self.db_setup_task: await self.db_setup_task def get_select_statement(self) -> PreparedStatement: + """Get the prepared select statement for the table. + If not available, prepare it. + + Returns: + PreparedStatement: The prepared statement. + """ if not self.select_statement: self.select_statement = self.session.prepare( SELECT_TABLE_CQL_TEMPLATE.format( @@ -99,6 +118,12 @@ class CassandraByteStore(ByteStore): return self.select_statement def get_insert_statement(self) -> PreparedStatement: + """Get the prepared insert statement for the table. + If not available, prepare it. + + Returns: + PreparedStatement: The prepared statement. + """ if not self.insert_statement: self.insert_statement = self.session.prepare( INSERT_TABLE_CQL_TEMPLATE.format( @@ -108,6 +133,13 @@ class CassandraByteStore(ByteStore): return self.insert_statement def get_delete_statement(self) -> PreparedStatement: + """Get the prepared delete statement for the table. + If not available, prepare it. + + Returns: + PreparedStatement: The prepared statement. + """ + if not self.delete_statement: self.delete_statement = self.session.prepare( DELETE_TABLE_CQL_TEMPLATE.format( diff --git a/libs/community/langchain_community/utilities/cassandra.py b/libs/community/langchain_community/utilities/cassandra.py index 52b0963c89..7aa2d66ba3 100644 --- a/libs/community/langchain_community/utilities/cassandra.py +++ b/libs/community/langchain_community/utilities/cassandra.py @@ -36,6 +36,16 @@ async def wrapped_response_future( async def aexecute_cql(session: Session, query: str, **kwargs: Any) -> Any: + """Execute a CQL query asynchronously. + + Args: + session: The Cassandra session to use. + query: The CQL query to execute. + **kwargs: Additional keyword arguments to pass to the session execute method. + + Returns: + The result of the query. + """ return await wrapped_response_future(session.execute_async, query, **kwargs) diff --git a/libs/community/langchain_community/vectorstores/oraclevs.py b/libs/community/langchain_community/vectorstores/oraclevs.py index 0ebf027e82..ab7da5c8ac 100644 --- a/libs/community/langchain_community/vectorstores/oraclevs.py +++ b/libs/community/langchain_community/vectorstores/oraclevs.py @@ -173,6 +173,16 @@ def create_index( vector_store: OracleVS, params: Optional[dict[str, Any]] = None, ) -> None: + """Create an index on the vector store. + + Args: + client: The OracleDB connection object. + vector_store: The vector store object. + params: Optional parameters for the index creation. + + Raises: + ValueError: If an invalid parameter is provided. + """ if params: if params["idx_type"] == "HNSW": _create_hnsw_index( @@ -351,6 +361,15 @@ def _create_ivf_index( @_handle_exceptions def drop_table_purge(client: Connection, table_name: str) -> None: + """Drop a table and purge it from the database. + + Args: + client: The OracleDB connection object. + table_name: The name of the table to drop. + + Raises: + RuntimeError: If an error occurs while dropping the table. + """ if _table_exists(client, table_name): cursor = client.cursor() with cursor: @@ -364,6 +383,15 @@ def drop_table_purge(client: Connection, table_name: str) -> None: @_handle_exceptions def drop_index_if_exists(client: Connection, index_name: str) -> None: + """Drop an index if it exists. + + Args: + client: The OracleDB connection object. + index_name: The name of the index to drop. + + Raises: + RuntimeError: If an error occurs while dropping the index. 
+ """ if _index_exists(client, index_name): drop_query = f"DROP INDEX {index_name}" with client.cursor() as cursor: diff --git a/libs/community/langchain_community/vectorstores/vectara.py b/libs/community/langchain_community/vectorstores/vectara.py index b78cc8cc85..49dd94e181 100644 --- a/libs/community/langchain_community/vectorstores/vectara.py +++ b/libs/community/langchain_community/vectorstores/vectara.py @@ -748,6 +748,14 @@ class VectaraRetriever(VectorStoreRetriever): class VectaraRAG(Runnable): + """Vectara RAG runnable. + + Parameters: + vectara: Vectara object + config: VectaraQueryConfig object + chat: bool, default False + """ + def __init__( self, vectara: Vectara, config: VectaraQueryConfig, chat: bool = False ): @@ -762,10 +770,12 @@ class VectaraRAG(Runnable): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Iterator[dict]: - """get streaming output from Vectara RAG + """Get streaming output from Vectara RAG. Args: - query: The input query + input: The input query + config: RunnableConfig object + kwargs: Any additional arguments Returns: The output dictionary with question, answer and context