Add batch_size param to Weaviate vector store (#9890)

cc @mcantillon21 @hsm207 @cs0lar
pull/10805/head
Bagatur 11 months ago committed by GitHub
parent 720f6dbaac
commit 76dd7480e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,17 +1,29 @@
from __future__ import annotations
import datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
import os
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
)
from uuid import uuid4
import numpy as np
from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
import weaviate
def _default_schema(index_name: str) -> Dict:
return {
@ -25,21 +37,11 @@ def _default_schema(index_name: str) -> Dict:
}
def _create_weaviate_client(**kwargs: Any) -> Any:
client = kwargs.get("client")
if client is not None:
return client
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
try:
# the weaviate api key param should not be mandatory
weaviate_api_key = get_from_dict_or_env(
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
)
except ValueError:
weaviate_api_key = None
def _create_weaviate_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs: Any,
) -> weaviate.Client:
try:
import weaviate
except ImportError:
@ -47,15 +49,10 @@ def _create_weaviate_client(**kwargs: Any) -> Any:
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`"
)
auth = (
weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
if weaviate_api_key is not None
else None
)
client = weaviate.Client(weaviate_url, auth_client_secret=auth)
return client
url = url or os.environ.get("WEAVIATE_URL")
api_key = api_key or os.environ.get("WEAVIATE_API_KEY")
auth = weaviate.auth.AuthApiKey(api_key=api_key) if api_key else None
return weaviate.Client(url=url, auth_client_secret=auth, **kwargs)
def _default_score_normalizer(val: float) -> float:
@ -78,6 +75,7 @@ class Weaviate(VectorStore):
import weaviate
from langchain.vectorstores import Weaviate
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
weaviate = Weaviate(client, index_name, text_key)
@ -375,10 +373,21 @@ class Weaviate(VectorStore):
@classmethod
def from_texts(
cls: Type[Weaviate],
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
*,
client: Optional[weaviate.Client] = None,
weaviate_url: Optional[str] = None,
weaviate_api_key: Optional[str] = None,
batch_size: Optional[int] = None,
index_name: Optional[str] = None,
text_key: str = "text",
by_text: bool = False,
relevance_score_fn: Optional[
Callable[[float], float]
] = _default_score_normalizer,
**kwargs: Any,
) -> Weaviate:
"""Construct Weaviate wrapper from raw documents.
@ -390,11 +399,34 @@ class Weaviate(VectorStore):
This is intended to be a quick way to get started.
Args:
texts: Texts to add to vector store.
embedding: Text embedding model to use.
metadatas: Metadata associated with each text.
client: weaviate.Client to use.
weaviate_url: The Weaviate URL. If using Weaviate Cloud Services, get it
from the ``Details`` tab. Can be passed in as a named param or by
setting the environment variable ``WEAVIATE_URL``. Should not be
specified if client is provided.
weaviate_api_key: The Weaviate API key. If enabled and using Weaviate Cloud
Services, get it from the ``Details`` tab. Can be passed in as a named param
or by setting the environment variable ``WEAVIATE_API_KEY``. Should
not be specified if client is provided.
batch_size: Size of batch operations.
index_name: Index name.
text_key: Key to use for uploading/retrieving text to/from vectorstore.
by_text: Whether to search by text or by embedding.
relevance_score_fn: Function for converting whatever distance function the
vector store uses to a relevance score, which is a normalized similarity
score (0 means dissimilar, 1 means similar).
**kwargs: Additional named parameters to pass to ``Weaviate.__init__()``.
Example:
.. code-block:: python
from langchain.vectorstores.weaviate import Weaviate
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Weaviate
embeddings = OpenAIEmbeddings()
weaviate = Weaviate.from_texts(
texts,
@ -403,20 +435,30 @@ class Weaviate(VectorStore):
)
"""
client = _create_weaviate_client(**kwargs)
try:
from weaviate.util import get_valid_uuid
except ImportError as e:
raise ImportError(
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`"
) from e
from weaviate.util import get_valid_uuid
client = client or _create_weaviate_client(
url=weaviate_url,
api_key=weaviate_api_key,
)
if batch_size:
client.batch.configure(batch_size=batch_size)
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
embeddings = embedding.embed_documents(texts) if embedding else None
text_key = "text"
index_name = index_name or f"LangChain_{uuid4().hex}"
schema = _default_schema(index_name)
attributes = list(metadatas[0].keys()) if metadatas else None
# check whether the index already exists
if not client.schema.contains(schema):
client.schema.create_class(schema)
embeddings = embedding.embed_documents(texts) if embedding else None
attributes = list(metadatas[0].keys()) if metadatas else None
with client.batch as batch:
for i, text in enumerate(texts):
data_properties = {
@ -449,9 +491,6 @@ class Weaviate(VectorStore):
batch.flush()
relevance_score_fn = kwargs.get("relevance_score_fn")
by_text: bool = kwargs.get("by_text", False)
return cls(
client,
index_name,
@ -460,6 +499,7 @@ class Weaviate(VectorStore):
attributes=attributes,
relevance_score_fn=relevance_score_fn,
by_text=by_text,
**kwargs,
)
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:

Loading…
Cancel
Save