Add batch_size param to Weaviate vector store (#9890)

cc @mcantillon21 @hsm207 @cs0lar
pull/10805/head
Bagatur 12 months ago committed by GitHub
parent 720f6dbaac
commit 76dd7480e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,17 +1,29 @@
from __future__ import annotations from __future__ import annotations
import datetime import datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type import os
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
)
from uuid import uuid4 from uuid import uuid4
import numpy as np import numpy as np
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings from langchain.schema.embeddings import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance from langchain.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
import weaviate
def _default_schema(index_name: str) -> Dict: def _default_schema(index_name: str) -> Dict:
return { return {
@ -25,21 +37,11 @@ def _default_schema(index_name: str) -> Dict:
} }
def _create_weaviate_client(**kwargs: Any) -> Any: def _create_weaviate_client(
client = kwargs.get("client") url: Optional[str] = None,
if client is not None: api_key: Optional[str] = None,
return client **kwargs: Any,
) -> weaviate.Client:
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
try:
# the weaviate api key param should not be mandatory
weaviate_api_key = get_from_dict_or_env(
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
)
except ValueError:
weaviate_api_key = None
try: try:
import weaviate import weaviate
except ImportError: except ImportError:
@ -47,15 +49,10 @@ def _create_weaviate_client(**kwargs: Any) -> Any:
"Could not import weaviate python package. " "Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`" "Please install it with `pip install weaviate-client`"
) )
url = url or os.environ.get("WEAVIATE_URL")
auth = ( api_key = api_key or os.environ.get("WEAVIATE_API_KEY")
weaviate.auth.AuthApiKey(api_key=weaviate_api_key) auth = weaviate.auth.AuthApiKey(api_key=api_key) if api_key else None
if weaviate_api_key is not None return weaviate.Client(url=url, auth_client_secret=auth, **kwargs)
else None
)
client = weaviate.Client(weaviate_url, auth_client_secret=auth)
return client
def _default_score_normalizer(val: float) -> float: def _default_score_normalizer(val: float) -> float:
@ -78,6 +75,7 @@ class Weaviate(VectorStore):
import weaviate import weaviate
from langchain.vectorstores import Weaviate from langchain.vectorstores import Weaviate
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...) client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
weaviate = Weaviate(client, index_name, text_key) weaviate = Weaviate(client, index_name, text_key)
@ -375,10 +373,21 @@ class Weaviate(VectorStore):
@classmethod @classmethod
def from_texts( def from_texts(
cls: Type[Weaviate], cls,
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
*,
client: Optional[weaviate.Client] = None,
weaviate_url: Optional[str] = None,
weaviate_api_key: Optional[str] = None,
batch_size: Optional[int] = None,
index_name: Optional[str] = None,
text_key: str = "text",
by_text: bool = False,
relevance_score_fn: Optional[
Callable[[float], float]
] = _default_score_normalizer,
**kwargs: Any, **kwargs: Any,
) -> Weaviate: ) -> Weaviate:
"""Construct Weaviate wrapper from raw documents. """Construct Weaviate wrapper from raw documents.
@ -390,11 +399,34 @@ class Weaviate(VectorStore):
This is intended to be a quick way to get started. This is intended to be a quick way to get started.
Args:
texts: Texts to add to vector store.
embedding: Text embedding model to use.
metadatas: Metadata associated with each text.
client: weaviate.Client to use.
weaviate_url: The Weaviate URL. If using Weaviate Cloud Services get it
from the ``Details`` tab. Can be passed in as a named param or by
setting the environment variable ``WEAVIATE_URL``. Should not be
specified if client is provided.
weaviate_api_key: The Weaviate API key. If enabled and using Weaviate Cloud
Services, get it from ``Details`` tab. Can be passed in as a named param
or by setting the environment variable ``WEAVIATE_API_KEY``. Should
not be specified if client is provided.
batch_size: Size of batch operations.
index_name: Index name.
text_key: Key to use for uploading/retrieving text to/from vectorstore.
by_text: Whether to search by text or by embedding.
relevance_score_fn: Function for converting whatever distance function the
vector store uses to a relevance score, which is a normalized similarity
score (0 means dissimilar, 1 means similar).
**kwargs: Additional named parameters to pass to ``Weaviate.__init__()``.
Example: Example:
.. code-block:: python .. code-block:: python
from langchain.vectorstores.weaviate import Weaviate
from langchain.embeddings import OpenAIEmbeddings from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Weaviate
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
weaviate = Weaviate.from_texts( weaviate = Weaviate.from_texts(
texts, texts,
@ -403,20 +435,30 @@ class Weaviate(VectorStore):
) )
""" """
client = _create_weaviate_client(**kwargs) try:
from weaviate.util import get_valid_uuid
except ImportError as e:
raise ImportError(
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`"
) from e
from weaviate.util import get_valid_uuid client = client or _create_weaviate_client(
url=weaviate_url,
api_key=weaviate_api_key,
)
if batch_size:
client.batch.configure(batch_size=batch_size)
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}") index_name = index_name or f"LangChain_{uuid4().hex}"
embeddings = embedding.embed_documents(texts) if embedding else None
text_key = "text"
schema = _default_schema(index_name) schema = _default_schema(index_name)
attributes = list(metadatas[0].keys()) if metadatas else None
# check whether the index already exists # check whether the index already exists
if not client.schema.contains(schema): if not client.schema.contains(schema):
client.schema.create_class(schema) client.schema.create_class(schema)
embeddings = embedding.embed_documents(texts) if embedding else None
attributes = list(metadatas[0].keys()) if metadatas else None
with client.batch as batch: with client.batch as batch:
for i, text in enumerate(texts): for i, text in enumerate(texts):
data_properties = { data_properties = {
@ -449,9 +491,6 @@ class Weaviate(VectorStore):
batch.flush() batch.flush()
relevance_score_fn = kwargs.get("relevance_score_fn")
by_text: bool = kwargs.get("by_text", False)
return cls( return cls(
client, client,
index_name, index_name,
@ -460,6 +499,7 @@ class Weaviate(VectorStore):
attributes=attributes, attributes=attributes,
relevance_score_fn=relevance_score_fn, relevance_score_fn=relevance_score_fn,
by_text=by_text, by_text=by_text,
**kwargs,
) )
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None: def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:

Loading…
Cancel
Save