From d762a6b51feb3d1f8f60a9a9ccdb7d432aa07b53 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 29 Aug 2023 20:36:27 -0700 Subject: [PATCH] rm mutable defaults (#9974) --- libs/langchain/langchain/callbacks/mlflow_callback.py | 4 ++-- .../langchain/langchain/callbacks/promptlayer_callback.py | 4 ++-- libs/langchain/langchain/document_loaders/async_html.py | 4 ++-- libs/langchain/langchain/graphs/nebula_graph.py | 5 +++-- libs/langchain/langchain/llms/symblai_nebula.py | 3 ++- libs/langchain/langchain/vectorstores/marqo.py | 8 ++++---- libs/langchain/langchain/vectorstores/redis/base.py | 6 ++++-- libs/langchain/langchain/vectorstores/singlestoredb.py | 3 ++- libs/langchain/langchain/vectorstores/vectara.py | 4 ++-- libs/langchain/langchain/vectorstores/zilliz.py | 6 +++--- 10 files changed, 26 insertions(+), 21 deletions(-) diff --git a/libs/langchain/langchain/callbacks/mlflow_callback.py b/libs/langchain/langchain/callbacks/mlflow_callback.py index baa297af54..c51db69bf0 100644 --- a/libs/langchain/langchain/callbacks/mlflow_callback.py +++ b/libs/langchain/langchain/callbacks/mlflow_callback.py @@ -242,7 +242,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self, name: Optional[str] = "langchainrun-%", experiment: Optional[str] = "langchain", - tags: Optional[Dict] = {}, + tags: Optional[Dict] = None, tracking_uri: Optional[str] = None, ) -> None: """Initialize callback handler.""" @@ -254,7 +254,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.name = name self.experiment = experiment - self.tags = tags + self.tags = tags or {} self.tracking_uri = tracking_uri self.temp_dir = tempfile.TemporaryDirectory() diff --git a/libs/langchain/langchain/callbacks/promptlayer_callback.py b/libs/langchain/langchain/callbacks/promptlayer_callback.py index bd93d70879..749bb6b006 100644 --- a/libs/langchain/langchain/callbacks/promptlayer_callback.py +++ b/libs/langchain/langchain/callbacks/promptlayer_callback.py @@ -40,12 +40,12 @@ class PromptLayerCallbackHandler(BaseCallbackHandler): def __init__( self, pl_id_callback: Optional[Callable[..., Any]] = None, - pl_tags: Optional[List[str]] = [], + pl_tags: Optional[List[str]] = None, ) -> None: """Initialize the PromptLayerCallbackHandler.""" _lazy_import_promptlayer() self.pl_id_callback = pl_id_callback - self.pl_tags = pl_tags + self.pl_tags = pl_tags or [] self.runs: Dict[UUID, Dict[str, Any]] = {} def on_chat_model_start( diff --git a/libs/langchain/langchain/document_loaders/async_html.py b/libs/langchain/langchain/document_loaders/async_html.py index ce54573ff9..286319a5ee 100644 --- a/libs/langchain/langchain/document_loaders/async_html.py +++ b/libs/langchain/langchain/document_loaders/async_html.py @@ -33,7 +33,7 @@ class AsyncHtmlLoader(BaseLoader): verify_ssl: Optional[bool] = True, proxies: Optional[dict] = None, requests_per_second: int = 2, - requests_kwargs: Dict[str, Any] = {}, + requests_kwargs: Optional[Dict[str, Any]] = None, raise_for_status: bool = False, ): """Initialize with a webpage path.""" @@ -67,7 +67,7 @@ class AsyncHtmlLoader(BaseLoader): self.session.proxies.update(proxies) self.requests_per_second = requests_per_second - self.requests_kwargs = requests_kwargs + self.requests_kwargs = requests_kwargs or {} self.raise_for_status = raise_for_status async def _fetch( diff --git a/libs/langchain/langchain/graphs/nebula_graph.py b/libs/langchain/langchain/graphs/nebula_graph.py index c74b9fb854..8a031e372b 100644 --- a/libs/langchain/langchain/graphs/nebula_graph.py +++ b/libs/langchain/langchain/graphs/nebula_graph.py @@ -1,6 +1,6 @@ import logging from string import Template -from typing import Any, Dict +from typing import Any, Dict, Optional logger = logging.getLogger(__name__) @@ -106,11 +106,12 @@ class NebulaGraph: """Returns the schema of the NebulaGraph database""" return self.schema - def execute(self, query: str, params: dict = {}, retry: int = 0) -> Any: + def execute(self, query: str, params: Optional[dict] = None, retry: int = 0) -> Any: """Query NebulaGraph database.""" from nebula3.Exception import IOErrorException, NoValidSessionException from nebula3.fbthrift.transport.TTransport import TTransportException + params = params or {} try: result = self.session_pool.execute_parameter(query, params) if not result.is_succeeded(): diff --git a/libs/langchain/langchain/llms/symblai_nebula.py b/libs/langchain/langchain/llms/symblai_nebula.py index d368c3a1a0..8d33e1a42c 100644 --- a/libs/langchain/langchain/llms/symblai_nebula.py +++ b/libs/langchain/langchain/llms/symblai_nebula.py @@ -183,9 +183,10 @@ def make_request( instruction: str, conversation: str, url: str = f"{DEFAULT_NEBULA_SERVICE_URL}{DEFAULT_NEBULA_SERVICE_PATH}", - params: Dict = {}, + params: Optional[Dict] = None, ) -> Any: """Generate text from the model.""" + params = params or {} headers = { "Content-Type": "application/json", "ApiKey": f"{self.nebula_api_key}", diff --git a/libs/langchain/langchain/vectorstores/marqo.py b/libs/langchain/langchain/vectorstores/marqo.py index b18731e08c..43f7172071 100644 --- a/libs/langchain/langchain/vectorstores/marqo.py +++ b/libs/langchain/langchain/vectorstores/marqo.py @@ -372,10 +372,10 @@ class Marqo(VectorStore): index_name: str = "", url: str = "http://localhost:8882", api_key: str = "", - add_documents_settings: Optional[Dict[str, Any]] = {}, + add_documents_settings: Optional[Dict[str, Any]] = None, searchable_attributes: Optional[List[str]] = None, page_content_builder: Optional[Callable[[Dict[str, str]], str]] = None, - index_settings: Optional[Dict[str, Any]] = {}, + index_settings: Optional[Dict[str, Any]] = None, verbose: bool = True, **kwargs: Any, ) -> Marqo: @@ -435,7 +435,7 @@ class Marqo(VectorStore): client = marqo.Client(url=url, api_key=api_key) try: - client.create_index(index_name, settings_dict=index_settings) + client.create_index(index_name, settings_dict=index_settings or {}) if verbose: print(f"Created {index_name} successfully.") except Exception: @@ -446,7 +446,7 @@ class Marqo(VectorStore): client, index_name, searchable_attributes=searchable_attributes, - add_documents_settings=add_documents_settings, + add_documents_settings=add_documents_settings or {}, page_content_builder=page_content_builder, ) instance.add_texts(texts, metadatas) diff --git a/libs/langchain/langchain/vectorstores/redis/base.py b/libs/langchain/langchain/vectorstores/redis/base.py index f3e966d3c1..a09ba44cba 100644 --- a/libs/langchain/langchain/vectorstores/redis/base.py +++ b/libs/langchain/langchain/vectorstores/redis/base.py @@ -991,7 +991,7 @@ class Redis(VectorStore): self, k: int, filter: Optional[RedisFilterExpression] = None, - return_fields: List[str] = [], + return_fields: Optional[List[str]] = None, ) -> "Query": try: from redis.commands.search.query import Query @@ -1000,6 +1000,7 @@ class Redis(VectorStore): "Could not import redis python package. " "Please install it with `pip install redis`." ) from e + return_fields = return_fields or [] vector_key = self._schema.content_vector_key base_query = f"@{vector_key}:[VECTOR_RANGE $distance_threshold $vector]" @@ -1020,7 +1021,7 @@ class Redis(VectorStore): self, k: int, filter: Optional[RedisFilterExpression] = None, - return_fields: List[str] = [], + return_fields: Optional[List[str]] = None, ) -> "Query": """Prepare query for vector search. @@ -1038,6 +1039,7 @@ class Redis(VectorStore): "Could not import redis python package. " "Please install it with `pip install redis`." ) from e + return_fields = return_fields or [] query_prefix = "*" if filter: query_prefix = f"{str(filter)}" diff --git a/libs/langchain/langchain/vectorstores/singlestoredb.py b/libs/langchain/langchain/vectorstores/singlestoredb.py index 983f3f7f03..35a955807f 100644 --- a/libs/langchain/langchain/vectorstores/singlestoredb.py +++ b/libs/langchain/langchain/vectorstores/singlestoredb.py @@ -345,8 +345,9 @@ class SingleStoreDB(VectorStore): def build_where_clause( where_clause_values: List[Any], sub_filter: dict, - prefix_args: List[str] = [], + prefix_args: Optional[List[str]] = None, ) -> None: + prefix_args = prefix_args or [] for key in sub_filter.keys(): if isinstance(sub_filter[key], dict): build_where_clause( diff --git a/libs/langchain/langchain/vectorstores/vectara.py b/libs/langchain/langchain/vectorstores/vectara.py index 9af1124648..457511b104 100644 --- a/libs/langchain/langchain/vectorstores/vectara.py +++ b/libs/langchain/langchain/vectorstores/vectara.py @@ -463,7 +463,7 @@ class VectaraRetriever(VectorStoreRetriever): self, texts: List[str], metadatas: Optional[List[dict]] = None, - doc_metadata: Optional[dict] = {}, + doc_metadata: Optional[dict] = None, ) -> None: """Add text to the Vectara vectorstore. @@ -471,4 +471,4 @@ class VectaraRetriever(VectorStoreRetriever): texts (List[str]): The text metadatas (List[dict]): Metadata dicts, must line up with existing store """ - self.vectorstore.add_texts(texts, metadatas, doc_metadata) + self.vectorstore.add_texts(texts, metadatas, doc_metadata or {}) diff --git a/libs/langchain/langchain/vectorstores/zilliz.py b/libs/langchain/langchain/vectorstores/zilliz.py index 8a571aca3b..e62bfb1aa8 100644 --- a/libs/langchain/langchain/vectorstores/zilliz.py +++ b/libs/langchain/langchain/vectorstores/zilliz.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from langchain.embeddings.base import Embeddings from langchain.vectorstores.milvus import Milvus @@ -140,7 +140,7 @@ class Zilliz(Milvus): embedding: Embeddings, metadatas: Optional[List[dict]] = None, collection_name: str = "LangChainCollection", - connection_args: dict[str, Any] = {}, + connection_args: Optional[Dict[str, Any]] = None, consistency_level: str = "Session", index_params: Optional[dict] = None, search_params: Optional[dict] = None, @@ -173,7 +173,7 @@ class Zilliz(Milvus): vector_db = cls( embedding_function=embedding, collection_name=collection_name, - connection_args=connection_args, + connection_args=connection_args or {}, consistency_level=consistency_level, index_params=index_params, search_params=search_params,