# Mirror of https://github.com/hwchase17/langchain
# Synced 2024-11-08 07:10:35 +00:00 — 98 lines, 2.9 KiB, Python.
import asyncio
import json
import logging
from typing import Any, Dict, List, Optional
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


class SurrealDBLoader(BaseLoader):
    """Load SurrealDB documents.

    Every record of ``table`` (optionally narrowed by ``filter_criteria``) is
    turned into a :class:`Document`. The document's ``page_content`` is the
    JSON-encoded record. Its metadata combines:

    * the record ``id``;
    * the record's own ``metadata`` field, if present;
    * the ``ns``/``db``/``table`` the record came from.

    Later entries override earlier ones on key clashes.

    Args:
        filter_criteria: Optional mapping of field name -> value. Each pair
            becomes an equality condition, and conditions are joined with
            ``AND``. Values are bound as query parameters. Field names are
            interpolated into the query text verbatim, so they must come from
            trusted code, never from end-user input.
        **kwargs: Connection options. The recognised keys are:

            * ``dburl`` (default ``"ws://localhost:8000/rpc"``; must start
              with ``ws``);
            * ``ns`` (default ``"langchain"``);
            * ``db`` (default ``"database"``);
            * ``table`` (default ``"documents"``);
            * ``db_user`` and ``db_pass``, used together, for sign-in.

            The recognised keys other than ``db_user``/``db_pass`` are popped;
            everything else (including the credentials) is kept in
            ``self.kwargs``.

    Raises:
        ImportError: If the ``surrealdb`` package is not installed.
        ValueError: If ``dburl`` is not a websocket URL, or if
            ``filter_criteria`` contains the reserved key ``table``.
    """

    def __init__(
        self,
        filter_criteria: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        try:
            from surrealdb import Surreal
        except ImportError as e:
            raise ImportError(
                "Cannot import from surrealdb. "
                "Please install with `pip install surrealdb`."
            ) from e

        self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc")
        if not self.dburl.startswith("ws"):
            raise ValueError("Only websocket connections are supported at this time.")

        self.filter_criteria = filter_criteria or {}
        # `table` is reserved: `aload` binds it as the `$table` query
        # parameter, so a filter with that name would silently replace it.
        if "table" in self.filter_criteria:
            raise ValueError(
                "key `table` is not a valid criteria for `filter_criteria` argument."
            )

        self.ns = kwargs.pop("ns", "langchain")
        self.db = kwargs.pop("db", "database")
        self.table = kwargs.pop("table", "documents")
        # Build the client once, after all argument validation has passed.
        self.sdb = Surreal(self.dburl)
        self.kwargs = kwargs

        # Connect (and sign in, if credentials were given) right away, inside
        # a short-lived event loop of its own. Because this uses
        # `asyncio.run`, constructing the loader from code that is already
        # running inside an event loop raises RuntimeError.
        asyncio.run(self.initialize())

    async def initialize(self) -> None:
        """Open the websocket connection and select the namespace/database.

        Signs in first when both ``db_user`` and ``db_pass`` were supplied in
        the constructor's keyword arguments.
        """
        await self.sdb.connect()
        if "db_user" in self.kwargs and "db_pass" in self.kwargs:
            await self.sdb.signin(
                {"user": self.kwargs["db_user"], "pass": self.kwargs["db_pass"]}
            )
        await self.sdb.use(self.ns, self.db)

    def load(self) -> List[Document]:
        """Synchronously connect, load all matching documents, then disconnect.

        Returns:
            One :class:`Document` per matching record.
        """

        async def _load() -> List[Document]:
            await self.initialize()
            try:
                return await self.aload()
            finally:
                # The event loop created by `asyncio.run` below is discarded
                # on return, so this connection can never be reused. Close it
                # rather than leak the socket; the next `load()` reconnects.
                await self.sdb.close()

        return asyncio.run(_load())

    async def aload(self) -> List[Document]:
        """Load data into Document objects.

        Uses the existing ``self.sdb`` connection and does not connect by
        itself, so ``initialize`` must have been awaited on the current event
        loop beforehand. ``load`` does this for you.

        Returns:
            One :class:`Document` per record in the first statement's result.
        """
        query = "SELECT * FROM type::table($table)"
        if self.filter_criteria:
            # Field names go into the query text verbatim; values are bound
            # through the `$<field>` parameters passed to `query` below.
            conditions = " AND ".join(
                f"{key} = ${key}" for key in self.filter_criteria
            )
            query += f" WHERE {conditions}"

        source_metadata = {"ns": self.ns, "db": self.db, "table": self.table}
        results = await self.sdb.query(
            query, {"table": self.table, **self.filter_criteria}
        )

        return [
            Document(
                page_content=json.dumps(record),
                metadata={
                    "id": record["id"],
                    # Tolerate records whose `metadata` field is missing or null.
                    **(record.get("metadata") or {}),
                    **source_metadata,
                },
            )
            for record in results[0]["result"]
        ]