From 76dd7480e6e1769f86afa913aa288068ee94f038 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 19 Sep 2023 16:20:23 -0700 Subject: [PATCH] Add batch_size param to Weaviate vector store (#9890) cc @mcantillon21 @hsm207 @cs0lar --- .../langchain/vectorstores/weaviate.py | 116 ++++++++++++------ 1 file changed, 78 insertions(+), 38 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/weaviate.py b/libs/langchain/langchain/vectorstores/weaviate.py index d2e4de0118..49a1d8f58c 100644 --- a/libs/langchain/langchain/vectorstores/weaviate.py +++ b/libs/langchain/langchain/vectorstores/weaviate.py @@ -1,17 +1,29 @@ from __future__ import annotations import datetime -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type +import os +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, +) from uuid import uuid4 import numpy as np from langchain.docstore.document import Document from langchain.schema.embeddings import Embeddings -from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance +if TYPE_CHECKING: + import weaviate + def _default_schema(index_name: str) -> Dict: return { @@ -25,21 +37,11 @@ def _default_schema(index_name: str) -> Dict: } -def _create_weaviate_client(**kwargs: Any) -> Any: - client = kwargs.get("client") - if client is not None: - return client - - weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") - - try: - # the weaviate api key param should not be mandatory - weaviate_api_key = get_from_dict_or_env( - kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None - ) - except ValueError: - weaviate_api_key = None - +def _create_weaviate_client( + url: Optional[str] = None, + api_key: Optional[str] = None, + **kwargs: Any, +) -> weaviate.Client: try: import weaviate except ImportError: @@ -47,15 +49,10 @@ def _create_weaviate_client(**kwargs: Any) -> Any: "Could not import weaviate python package. " "Please install it with `pip install weaviate-client`" ) - - auth = ( - weaviate.auth.AuthApiKey(api_key=weaviate_api_key) - if weaviate_api_key is not None - else None - ) - client = weaviate.Client(weaviate_url, auth_client_secret=auth) - - return client + url = url or os.environ.get("WEAVIATE_URL") + api_key = api_key or os.environ.get("WEAVIATE_API_KEY") + auth = weaviate.auth.AuthApiKey(api_key=api_key) if api_key else None + return weaviate.Client(url=url, auth_client_secret=auth, **kwargs) def _default_score_normalizer(val: float) -> float: @@ -78,6 +75,7 @@ class Weaviate(VectorStore): import weaviate from langchain.vectorstores import Weaviate + client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...) weaviate = Weaviate(client, index_name, text_key) @@ -375,10 +373,21 @@ class Weaviate(VectorStore): @classmethod def from_texts( - cls: Type[Weaviate], + cls, texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, + *, + client: Optional[weaviate.Client] = None, + weaviate_url: Optional[str] = None, + weaviate_api_key: Optional[str] = None, + batch_size: Optional[int] = None, + index_name: Optional[str] = None, + text_key: str = "text", + by_text: bool = False, + relevance_score_fn: Optional[ + Callable[[float], float] + ] = _default_score_normalizer, **kwargs: Any, ) -> Weaviate: """Construct Weaviate wrapper from raw documents. @@ -390,11 +399,34 @@ class Weaviate(VectorStore): This is intended to be a quick way to get started. + Args: + texts: Texts to add to vector store. + embedding: Text embedding model to use. + metadatas: Metadata associated with each text. + client: weaviate.Client to use. + weaviate_url: The Weaviate URL. If using Weaviate Cloud Services get it + from the ``Details`` tab. Can be passed in as a named param or by + setting the environment variable ``WEAVIATE_URL``. Should not be + specified if client is provided. + weaviate_api_key: The Weaviate API key. If enabled and using Weaviate Cloud + Services, get it from ``Details`` tab. Can be passed in as a named param + or by setting the environment variable ``WEAVIATE_API_KEY``. Should + not be specified if client is provided. + batch_size: Size of batch operations. + index_name: Index name. + text_key: Key to use for uploading/retrieving text to/from vectorstore. + by_text: Whether to search by text or by embedding. + relevance_score_fn: Function for converting whatever distance function the + vector store uses to a relevance score, which is a normalized similarity + score (0 means dissimilar, 1 means similar). + **kwargs: Additional named parameters to pass to ``Weaviate.__init__()``. + Example: .. code-block:: python - from langchain.vectorstores.weaviate import Weaviate from langchain.embeddings import OpenAIEmbeddings + from langchain.vectorstores import Weaviate + embeddings = OpenAIEmbeddings() weaviate = Weaviate.from_texts( texts, @@ -403,20 +435,30 @@ class Weaviate(VectorStore): ) """ - client = _create_weaviate_client(**kwargs) + try: + from weaviate.util import get_valid_uuid + except ImportError as e: + raise ImportError( + "Could not import weaviate python package. " + "Please install it with `pip install weaviate-client`" + ) from e - from weaviate.util import get_valid_uuid + client = client or _create_weaviate_client( + url=weaviate_url, + api_key=weaviate_api_key, + ) + if batch_size: + client.batch.configure(batch_size=batch_size) - index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}") - embeddings = embedding.embed_documents(texts) if embedding else None - text_key = "text" + index_name = index_name or f"LangChain_{uuid4().hex}" schema = _default_schema(index_name) - attributes = list(metadatas[0].keys()) if metadatas else None - # check whether the index already exists if not client.schema.contains(schema): client.schema.create_class(schema) + embeddings = embedding.embed_documents(texts) if embedding else None + attributes = list(metadatas[0].keys()) if metadatas else None + with client.batch as batch: for i, text in enumerate(texts): data_properties = { @@ -449,9 +491,6 @@ class Weaviate(VectorStore): batch.flush() - relevance_score_fn = kwargs.get("relevance_score_fn") - by_text: bool = kwargs.get("by_text", False) - return cls( client, index_name, @@ -460,6 +499,7 @@ class Weaviate(VectorStore): attributes=attributes, relevance_score_fn=relevance_score_fn, by_text=by_text, + **kwargs, ) def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None: