community: fix for surrealdb client 0.3.2 update + store and retrieve metadata (#14997)

Surrealdb client changes from 0.3.1 to 0.3.2 broke the surrealdb vectore
integration.
This PR updates the code to work with the updated client. The change is
backwards compatible with previous versions of surrealdb client.
Also expanded the vector store implementation to store and retrieve
metadata that's included with the document object.
pull/15018/head
Karim Lalani 9 months ago committed by GitHub
parent c7be59c122
commit 228ddabc3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -62,7 +62,7 @@ class SurrealDBStore(VectorStore):
self.db = kwargs.pop("db", "database")
self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc")
self.embedding_function = embedding_function
self.sdb = Surreal()
self.sdb = Surreal(self.dburl)
self.kwargs = kwargs
async def initialize(self) -> None:
@ -103,8 +103,12 @@ class SurrealDBStore(VectorStore):
embeddings = self.embedding_function.embed_documents(list(texts))
ids = []
for idx, text in enumerate(texts):
data = {"text": text, "embedding": embeddings[idx]}
if metadatas is not None and idx < len(metadatas):
data["metadata"] = metadatas[idx]
record = await self.sdb.create(
self.collection, {"text": text, "embedding": embeddings[idx]}
self.collection,
data,
)
ids.append(record[0]["id"])
return ids
@ -123,7 +127,16 @@ class SurrealDBStore(VectorStore):
Returns:
List of ids for the newly inserted documents
"""
return asyncio.run(self.aadd_texts(texts, metadatas, **kwargs))
async def _add_texts(
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
await self.initialize()
return await self.aadd_texts(texts, metadatas, **kwargs)
return asyncio.run(_add_texts(texts, metadatas, **kwargs))
async def adelete(
self,
@ -195,7 +208,7 @@ class SurrealDBStore(VectorStore):
"k": k,
"score_threshold": kwargs.get("score_threshold", 0),
}
query = """select id, text,
query = """select id, text, metadata,
vector::similarity::cosine(embedding,{embedding}) as similarity
from {collection}
where vector::similarity::cosine(embedding,{embedding}) >= {score_threshold}
@ -208,7 +221,10 @@ class SurrealDBStore(VectorStore):
return [
(
Document(page_content=result["text"], metadata={"id": result["id"]}),
Document(
page_content=result["text"],
metadata={"id": result["id"], **result["metadata"]},
),
result["similarity"],
)
for result in results[0]["result"]
@ -401,7 +417,7 @@ class SurrealDBStore(VectorStore):
sdb = cls(embedding, **kwargs)
await sdb.initialize()
await sdb.aadd_texts(texts)
await sdb.aadd_texts(texts, metadatas, **kwargs)
return sdb
@classmethod

Loading…
Cancel
Save