rm mutable defaults (#9974)

Bagatur 2023-08-29 20:36:27 -07:00 committed by GitHub
parent 6a51672164
commit d762a6b51f
10 changed files with 26 additions and 21 deletions
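
Background for the change: in Python, a mutable default such as {} or [] is evaluated once, when the function is defined, so every call that omits the argument shares the same object and any mutation leaks across calls. The commit replaces those defaults with None and recreates the empty container inside the function. A minimal standalone sketch of the pitfall and of the pattern applied throughout this commit (the helper names are illustrative, not taken from the diff):

from typing import Dict, Optional

def add_tag_buggy(key: str, value: str, tags: Dict[str, str] = {}) -> Dict[str, str]:
    # The default dict is created once at definition time and shared by every call.
    tags[key] = value
    return tags

def add_tag_fixed(key: str, value: str, tags: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    # None is immutable; fall back to a fresh dict on each call instead.
    tags = tags or {}
    tags[key] = value
    return tags

print(add_tag_buggy("a", "1"))  # {'a': '1'}
print(add_tag_buggy("b", "2"))  # {'a': '1', 'b': '2'}  <- state leaked from the first call
print(add_tag_fixed("a", "1"))  # {'a': '1'}
print(add_tag_fixed("b", "2"))  # {'b': '2'}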

@@ -242,7 +242,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self,
name: Optional[str] = "langchainrun-%",
experiment: Optional[str] = "langchain",
- tags: Optional[Dict] = {},
+ tags: Optional[Dict] = None,
tracking_uri: Optional[str] = None,
) -> None:
"""Initialize callback handler."""
@@ -254,7 +254,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.name = name
self.experiment = experiment
- self.tags = tags
+ self.tags = tags or {}
self.tracking_uri = tracking_uri
self.temp_dir = tempfile.TemporaryDirectory()

@@ -40,12 +40,12 @@ class PromptLayerCallbackHandler(BaseCallbackHandler):
def __init__(
self,
pl_id_callback: Optional[Callable[..., Any]] = None,
- pl_tags: Optional[List[str]] = [],
+ pl_tags: Optional[List[str]] = None,
) -> None:
"""Initialize the PromptLayerCallbackHandler."""
_lazy_import_promptlayer()
self.pl_id_callback = pl_id_callback
- self.pl_tags = pl_tags
+ self.pl_tags = pl_tags or []
self.runs: Dict[UUID, Dict[str, Any]] = {}
def on_chat_model_start(

@@ -33,7 +33,7 @@ class AsyncHtmlLoader(BaseLoader):
verify_ssl: Optional[bool] = True,
proxies: Optional[dict] = None,
requests_per_second: int = 2,
- requests_kwargs: Dict[str, Any] = {},
+ requests_kwargs: Optional[Dict[str, Any]] = None,
raise_for_status: bool = False,
):
"""Initialize with a webpage path."""
@@ -67,7 +67,7 @@ class AsyncHtmlLoader(BaseLoader):
self.session.proxies.update(proxies)
self.requests_per_second = requests_per_second
- self.requests_kwargs = requests_kwargs
+ self.requests_kwargs = requests_kwargs or {}
self.raise_for_status = raise_for_status
async def _fetch(

@@ -1,6 +1,6 @@
import logging
from string import Template
- from typing import Any, Dict
+ from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
@@ -106,11 +106,12 @@ class NebulaGraph:
"""Returns the schema of the NebulaGraph database"""
return self.schema
- def execute(self, query: str, params: dict = {}, retry: int = 0) -> Any:
+ def execute(self, query: str, params: Optional[dict] = None, retry: int = 0) -> Any:
"""Query NebulaGraph database."""
from nebula3.Exception import IOErrorException, NoValidSessionException
from nebula3.fbthrift.transport.TTransport import TTransportException
+ params = params or {}
try:
result = self.session_pool.execute_parameter(query, params)
if not result.is_succeeded():

@@ -183,9 +183,10 @@ def make_request(
instruction: str,
conversation: str,
url: str = f"{DEFAULT_NEBULA_SERVICE_URL}{DEFAULT_NEBULA_SERVICE_PATH}",
- params: Dict = {},
+ params: Optional[Dict] = None,
) -> Any:
"""Generate text from the model."""
+ params = params or {}
headers = {
"Content-Type": "application/json",
"ApiKey": f"{self.nebula_api_key}",

@@ -372,10 +372,10 @@ class Marqo(VectorStore):
index_name: str = "",
url: str = "http://localhost:8882",
api_key: str = "",
- add_documents_settings: Optional[Dict[str, Any]] = {},
+ add_documents_settings: Optional[Dict[str, Any]] = None,
searchable_attributes: Optional[List[str]] = None,
page_content_builder: Optional[Callable[[Dict[str, str]], str]] = None,
- index_settings: Optional[Dict[str, Any]] = {},
+ index_settings: Optional[Dict[str, Any]] = None,
verbose: bool = True,
**kwargs: Any,
) -> Marqo:
@@ -435,7 +435,7 @@ class Marqo(VectorStore):
client = marqo.Client(url=url, api_key=api_key)
try:
- client.create_index(index_name, settings_dict=index_settings)
+ client.create_index(index_name, settings_dict=index_settings or {})
if verbose:
print(f"Created {index_name} successfully.")
except Exception:
@@ -446,7 +446,7 @@ class Marqo(VectorStore):
client,
index_name,
searchable_attributes=searchable_attributes,
- add_documents_settings=add_documents_settings,
+ add_documents_settings=add_documents_settings or {},
page_content_builder=page_content_builder,
)
instance.add_texts(texts, metadatas)
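
Note that the Marqo hunks apply the fallback where the value is consumed (index_settings or {}, add_documents_settings or {}) instead of rebinding the parameter at the top of the function as the other files in this commit do; the two spellings are equivalent as long as every use of the parameter gets the fallback. A small runnable sketch of both styles under that assumption (function names are illustrative only):

from typing import Any, Dict, Optional

def build_settings_rebind(index_settings: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Rebind once at the top, then use the parameter normally.
    index_settings = index_settings or {}
    return {"settings_dict": index_settings}

def build_settings_at_call_site(index_settings: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Apply the fallback at the point of use, as in the create_index call above.
    return {"settings_dict": index_settings or {}}

assert build_settings_rebind() == build_settings_at_call_site() == {"settings_dict": {}}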

@@ -991,7 +991,7 @@ class Redis(VectorStore):
self,
k: int,
filter: Optional[RedisFilterExpression] = None,
- return_fields: List[str] = [],
+ return_fields: Optional[List[str]] = None,
) -> "Query":
try:
from redis.commands.search.query import Query
@@ -1000,6 +1000,7 @@ class Redis(VectorStore):
"Could not import redis python package. "
"Please install it with `pip install redis`."
) from e
+ return_fields = return_fields or []
vector_key = self._schema.content_vector_key
base_query = f"@{vector_key}:[VECTOR_RANGE $distance_threshold $vector]"
@@ -1020,7 +1021,7 @@ class Redis(VectorStore):
self,
k: int,
filter: Optional[RedisFilterExpression] = None,
- return_fields: List[str] = [],
+ return_fields: Optional[List[str]] = None,
) -> "Query":
"""Prepare query for vector search.
@@ -1038,6 +1039,7 @@ class Redis(VectorStore):
"Could not import redis python package. "
"Please install it with `pip install redis`."
) from e
+ return_fields = return_fields or []
query_prefix = "*"
if filter:
query_prefix = f"{str(filter)}"

@@ -345,8 +345,9 @@ class SingleStoreDB(VectorStore):
def build_where_clause(
where_clause_values: List[Any],
sub_filter: dict,
- prefix_args: List[str] = [],
+ prefix_args: Optional[List[str]] = None,
) -> None:
+ prefix_args = prefix_args or []
for key in sub_filter.keys():
if isinstance(sub_filter[key], dict):
build_where_clause(

@@ -463,7 +463,7 @@ class VectaraRetriever(VectorStoreRetriever):
self,
texts: List[str],
metadatas: Optional[List[dict]] = None,
- doc_metadata: Optional[dict] = {},
+ doc_metadata: Optional[dict] = None,
) -> None:
"""Add text to the Vectara vectorstore.
@@ -471,4 +471,4 @@ class VectaraRetriever(VectorStoreRetriever):
texts (List[str]): The text
metadatas (List[dict]): Metadata dicts, must line up with existing store
"""
- self.vectorstore.add_texts(texts, metadatas, doc_metadata)
+ self.vectorstore.add_texts(texts, metadatas, doc_metadata or {})

@@ -1,7 +1,7 @@
from __future__ import annotations
import logging
- from typing import Any, List, Optional
+ from typing import Any, Dict, List, Optional
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.milvus import Milvus
@@ -140,7 +140,7 @@ class Zilliz(Milvus):
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
collection_name: str = "LangChainCollection",
- connection_args: dict[str, Any] = {},
+ connection_args: Optional[Dict[str, Any]] = None,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
@@ -173,7 +173,7 @@ class Zilliz(Milvus):
vector_db = cls(
embedding_function=embedding,
collection_name=collection_name,
- connection_args=connection_args,
+ connection_args=connection_args or {},
consistency_level=consistency_level,
index_params=index_params,
search_params=search_params,