diff --git a/libs/community/langchain_community/vectorstores/surrealdb.py b/libs/community/langchain_community/vectorstores/surrealdb.py index a96c2410fe..773a00cc57 100644 --- a/libs/community/langchain_community/vectorstores/surrealdb.py +++ b/libs/community/langchain_community/vectorstores/surrealdb.py @@ -62,7 +62,7 @@ class SurrealDBStore(VectorStore): self.db = kwargs.pop("db", "database") self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc") self.embedding_function = embedding_function - self.sdb = Surreal() + self.sdb = Surreal(self.dburl) self.kwargs = kwargs async def initialize(self) -> None: @@ -103,8 +103,12 @@ class SurrealDBStore(VectorStore): embeddings = self.embedding_function.embed_documents(list(texts)) ids = [] for idx, text in enumerate(texts): + data = {"text": text, "embedding": embeddings[idx]} + if metadatas is not None and idx < len(metadatas): + data["metadata"] = metadatas[idx] record = await self.sdb.create( - self.collection, {"text": text, "embedding": embeddings[idx]} + self.collection, + data, ) ids.append(record[0]["id"]) return ids @@ -123,7 +127,16 @@ class SurrealDBStore(VectorStore): Returns: List of ids for the newly inserted documents """ - return asyncio.run(self.aadd_texts(texts, metadatas, **kwargs)) + + async def _add_texts( + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> List[str]: + await self.initialize() + return await self.aadd_texts(texts, metadatas, **kwargs) + + return asyncio.run(_add_texts(texts, metadatas, **kwargs)) async def adelete( self, @@ -195,7 +208,7 @@ class SurrealDBStore(VectorStore): "k": k, "score_threshold": kwargs.get("score_threshold", 0), } - query = """select id, text, + query = """select id, text, metadata, vector::similarity::cosine(embedding,{embedding}) as similarity from {collection} where vector::similarity::cosine(embedding,{embedding}) >= {score_threshold} @@ -208,7 +221,10 @@ class SurrealDBStore(VectorStore): return [ ( - Document(page_content=result["text"], metadata={"id": result["id"]}), + Document( + page_content=result["text"], + metadata={"id": result["id"], **result["metadata"]}, + ), result["similarity"], ) for result in results[0]["result"] @@ -401,7 +417,7 @@ class SurrealDBStore(VectorStore): sdb = cls(embedding, **kwargs) await sdb.initialize() - await sdb.aadd_texts(texts) + await sdb.aadd_texts(texts, metadatas, **kwargs) return sdb @classmethod