community: fix for surrealdb client 0.3.2 update + store and retrieve metadata (#14997)

Surrealdb client changes from 0.3.1 to 0.3.2 broke the surrealdb vectore
integration.
This PR updates the code to work with the updated client. The change is
backwards compatible with previous versions of surrealdb client.
Also expanded the vector store implementation to store and retrieve
metadata that's included with the document object.
This commit is contained in:
Karim Lalani 2023-12-21 11:04:57 -06:00 committed by GitHub
parent c7be59c122
commit 228ddabc3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -62,7 +62,7 @@ class SurrealDBStore(VectorStore):
self.db = kwargs.pop("db", "database") self.db = kwargs.pop("db", "database")
self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc") self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc")
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.sdb = Surreal() self.sdb = Surreal(self.dburl)
self.kwargs = kwargs self.kwargs = kwargs
async def initialize(self) -> None: async def initialize(self) -> None:
@ -103,8 +103,12 @@ class SurrealDBStore(VectorStore):
embeddings = self.embedding_function.embed_documents(list(texts)) embeddings = self.embedding_function.embed_documents(list(texts))
ids = [] ids = []
for idx, text in enumerate(texts): for idx, text in enumerate(texts):
data = {"text": text, "embedding": embeddings[idx]}
if metadatas is not None and idx < len(metadatas):
data["metadata"] = metadatas[idx]
record = await self.sdb.create( record = await self.sdb.create(
self.collection, {"text": text, "embedding": embeddings[idx]} self.collection,
data,
) )
ids.append(record[0]["id"]) ids.append(record[0]["id"])
return ids return ids
@ -123,7 +127,16 @@ class SurrealDBStore(VectorStore):
Returns: Returns:
List of ids for the newly inserted documents List of ids for the newly inserted documents
""" """
return asyncio.run(self.aadd_texts(texts, metadatas, **kwargs))
async def _add_texts(
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
await self.initialize()
return await self.aadd_texts(texts, metadatas, **kwargs)
return asyncio.run(_add_texts(texts, metadatas, **kwargs))
async def adelete( async def adelete(
self, self,
@ -195,7 +208,7 @@ class SurrealDBStore(VectorStore):
"k": k, "k": k,
"score_threshold": kwargs.get("score_threshold", 0), "score_threshold": kwargs.get("score_threshold", 0),
} }
query = """select id, text, query = """select id, text, metadata,
vector::similarity::cosine(embedding,{embedding}) as similarity vector::similarity::cosine(embedding,{embedding}) as similarity
from {collection} from {collection}
where vector::similarity::cosine(embedding,{embedding}) >= {score_threshold} where vector::similarity::cosine(embedding,{embedding}) >= {score_threshold}
@ -208,7 +221,10 @@ class SurrealDBStore(VectorStore):
return [ return [
( (
Document(page_content=result["text"], metadata={"id": result["id"]}), Document(
page_content=result["text"],
metadata={"id": result["id"], **result["metadata"]},
),
result["similarity"], result["similarity"],
) )
for result in results[0]["result"] for result in results[0]["result"]
@ -401,7 +417,7 @@ class SurrealDBStore(VectorStore):
sdb = cls(embedding, **kwargs) sdb = cls(embedding, **kwargs)
await sdb.initialize() await sdb.initialize()
await sdb.aadd_texts(texts) await sdb.aadd_texts(texts, metadatas, **kwargs)
return sdb return sdb
@classmethod @classmethod