diff --git a/langchain/vectorstores/weaviate.py b/langchain/vectorstores/weaviate.py index dd5e79ca..0ad33b1a 100644 --- a/langchain/vectorstores/weaviate.py +++ b/langchain/vectorstores/weaviate.py @@ -25,6 +25,35 @@ def _default_schema(index_name: str) -> Dict: } +def _create_weaviate_client(**kwargs: Any) -> Any: + client = kwargs.get("client") + + if client is not None: + return client + + weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") + weaviate_api_key = get_from_dict_or_env( + kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None + ) + + try: + import weaviate + except ImportError: + raise ValueError( + "Could not import weaviate python package. " + "Please install it with `pip instal weaviate-client`" + ) + + auth = ( + weaviate.auth.AuthApiKey(api_key=weaviate_api_key) + if weaviate_api_key is not None + else None + ) + client = weaviate.Client(weaviate_url, auth_client_secret=auth) + + return client + + class Weaviate(VectorStore): """Wrapper around Weaviate vector database. @@ -248,18 +277,11 @@ class Weaviate(VectorStore): weaviate_url="http://localhost:8080" ) """ - weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") - try: - from weaviate import Client - from weaviate.util import get_valid_uuid - except ImportError: - raise ValueError( - "Could not import weaviate python package. " - "Please install it with `pip instal weaviate-client`" - ) + client = _create_weaviate_client(**kwargs) + + from weaviate.util import get_valid_uuid - client = Client(weaviate_url) index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}") embeddings = embedding.embed_documents(texts) if embedding else None text_key = "text"