diff --git a/langchain/vectorstores/weaviate.py b/langchain/vectorstores/weaviate.py index 9d2a168b..a7398ae7 100644 --- a/langchain/vectorstores/weaviate.py +++ b/langchain/vectorstores/weaviate.py @@ -63,6 +63,12 @@ def _default_score_normalizer(val: float) -> float: return 1 - 1 / (1 + np.exp(val)) +def _json_serializable(value: Any) -> Any: + if isinstance(value, datetime.datetime): + return value.isoformat() + return value + + class Weaviate(VectorStore): """Wrapper around Weaviate vector database. @@ -121,42 +127,30 @@ class Weaviate(VectorStore): """Upload texts with metadata (properties) to Weaviate.""" from weaviate.util import get_valid_uuid - def json_serializable(value: Any) -> Any: - if isinstance(value, datetime.datetime): - return value.isoformat() - return value - + ids = [] with self._client.batch as batch: - ids = [] - for i, doc in enumerate(texts): - data_properties = { - self._text_key: doc, - } + for i, text in enumerate(texts): + data_properties = {self._text_key: text} if metadatas is not None: - for key in metadatas[i].keys(): - data_properties[key] = json_serializable(metadatas[i][key]) + for key, val in metadatas[i].items(): + data_properties[key] = _json_serializable(val) # If the UUID of one of the objects already exists - # then the existing objectwill be replaced by the new object. - if "uuids" in kwargs: - _id = kwargs["uuids"][i] - else: - _id = get_valid_uuid(uuid4()) + # then the existing object will be replaced by the new object. + _id = ( + kwargs["uuids"][i] if "uuids" in kwargs else get_valid_uuid(uuid4()) + ) if self._embedding is not None: - embeddings = self._embedding.embed_documents(list(doc)) - batch.add_data_object( - data_object=data_properties, - class_name=self._index_name, - uuid=_id, - vector=embeddings[0], - ) + vector = self._embedding.embed_documents([text])[0] else: - batch.add_data_object( - data_object=data_properties, - class_name=self._index_name, - uuid=_id, - ) + vector = None + batch.add_data_object( + data_object=data_properties, + class_name=self._index_name, + uuid=_id, + vector=vector, + ) ids.append(_id) return ids