langchain/libs/community/langchain_community/document_loaders/surrealdb.py
2024-01-23 19:46:19 -08:00

96 lines
2.9 KiB
Python

import asyncio
import json
import logging
from typing import Any, Dict, List, Optional
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
# Module-level logger. NOTE(review): SurrealDBLoader does not currently log
# anything through it.
logger = logging.getLogger(__name__)
class SurrealDBLoader(BaseLoader):
    """Load SurrealDB documents.

    Every record of the target table (optionally narrowed by
    ``filter_criteria``) becomes one ``Document``:

    * ``page_content`` is the whole record serialized with ``json.dumps``;
    * ``metadata`` merges the record ``id``, the record's own ``metadata``
      field (if present), and the ``ns``/``db``/``table`` it was read from.

    Args:
        filter_criteria: Optional mapping of field name to value. Only records
            whose fields equal every given value are loaded. Keys must consist
            of ASCII letters, digits and underscores, and may not be
            ``table`` (that name is reserved for the table query parameter).
        **kwargs: Connection options. Recognised keys:
            ``dburl`` (default ``"ws://localhost:8000/rpc"``; must be a
            websocket URL), ``ns`` (default ``"langchain"``), ``db`` (default
            ``"database"``), ``table`` (default ``"documents"``), and
            ``db_user`` / ``db_pass`` (both must be given to sign in).
            Keys other than ``dburl``/``ns``/``db``/``table`` - including the
            credentials - are kept in ``self.kwargs``.

    Raises:
        ImportError: If the ``surrealdb`` package is not installed.
        ValueError: If ``dburl`` is not a websocket URL, or if
            ``filter_criteria`` contains ``table`` or a non-identifier key.
    """

    def __init__(
        self,
        filter_criteria: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        try:
            from surrealdb import Surreal
        except ImportError as e:
            raise ImportError(
                "Cannot import from surrealdb. "
                "Please install with `pip install surrealdb`."
            ) from e
        import re

        self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc")
        # Only the websocket transport is supported; this admits ws:// and wss://.
        if not self.dburl.startswith("ws"):
            raise ValueError("Only websocket connections are supported at this time.")

        self.filter_criteria = filter_criteria or {}
        # `table` is already used as the bind parameter for the table name in
        # `aload`, so a filter key of that name would collide with it.
        if "table" in self.filter_criteria:
            raise ValueError(
                "key `table` is not a valid criteria for `filter_criteria` argument."
            )
        for key in self.filter_criteria:
            # Keys are spliced verbatim into the query text (as a field name and
            # as a `$parameter` name), so only plain identifiers are accepted to
            # rule out SurrealQL injection through the filter keys.
            if not isinstance(key, str) or not re.fullmatch(r"[A-Za-z0-9_]+", key):
                raise ValueError(
                    f"Invalid key {key!r} in `filter_criteria`: keys must consist "
                    "only of ASCII letters, digits and underscores."
                )

        self.ns = kwargs.pop("ns", "langchain")
        self.db = kwargs.pop("db", "database")
        self.table = kwargs.pop("table", "documents")
        # Client is only constructed here; the connection is opened in
        # `initialize`.
        self.sdb = Surreal(self.dburl)
        self.kwargs = kwargs

    async def initialize(self) -> None:
        """Connect to SurrealDB, sign in if credentials were given, and select
        the configured namespace and database.

        Sign-in only happens when both ``db_user`` and ``db_pass`` were passed
        as keyword arguments to the constructor.
        """
        await self.sdb.connect()
        if "db_user" in self.kwargs and "db_pass" in self.kwargs:
            await self.sdb.signin(
                {"user": self.kwargs["db_user"], "pass": self.kwargs["db_pass"]}
            )
        await self.sdb.use(self.ns, self.db)

    def load(self) -> List[Document]:
        """Synchronously connect, load all matching records, and disconnect.

        Runs its own event loop via ``asyncio.run``, so it must not be called
        from inside an already-running event loop.
        """

        async def _load() -> List[Document]:
            await self.initialize()
            try:
                return await self.aload()
            finally:
                # Release the websocket opened by `initialize`; previously it
                # was left open after every `load()` call.
                await self.sdb.close()

        return asyncio.run(_load())

    async def aload(self) -> List[Document]:
        """Load data into Document objects.

        Expects ``initialize()`` to have been awaited first (``load()`` does
        this) and does not close the connection itself.

        Returns:
            One ``Document`` per record returned by the SELECT statement.
        """
        query = "SELECT * FROM type::table($table)"
        if self.filter_criteria:
            # Values are always passed as bind parameters; keys were validated
            # as plain identifiers in `__init__`.
            conditions = " AND ".join(
                f"{key} = ${key}" for key in self.filter_criteria
            )
            query += f" WHERE {conditions}"

        results = await self.sdb.query(
            query, {"table": self.table, **self.filter_criteria}
        )

        source_metadata = {
            "ns": self.ns,
            "db": self.db,
            "table": self.table,
        }
        # Only one statement is sent, so its records are in the first entry.
        return [
            Document(
                page_content=json.dumps(record),
                metadata={
                    "id": record["id"],
                    # Records without (or with a null) `metadata` field would
                    # otherwise raise KeyError/TypeError here.
                    **(record.get("metadata") or {}),
                    **source_metadata,
                },
            )
            for record in results[0]["result"]
        ]