import json
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterator,
    List,
    Optional,
    Sequence,
    Union,
)

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader


def default_page_content_mapper(row: Any) -> str:
    """Serialize a result row to a JSON string used as page content.

    Rows exposing ``_asdict`` (e.g. named tuples) are converted to a dict
    first; any other row is passed to ``json.dumps`` as-is.

    Values that ``json`` cannot encode natively (for example UUID or
    datetime column values) are rendered with ``str()`` instead of raising
    ``TypeError``. This fallback is deliberate: the output is free-form page
    text, not a payload that is parsed back.

    Args:
        row: A row as produced by iterating the session's result set.

    Returns:
        The JSON representation of the row.
    """
    payload = row._asdict() if hasattr(row, "_asdict") else row
    return json.dumps(payload, default=str)


# Sentinel distinguishing "argument not passed" from an explicit ``None``,
# which is a meaningful value for some ``session.execute`` arguments.
_NOT_SET = object()

if TYPE_CHECKING:
    from cassandra.cluster import Session
    from cassandra.pool import Host
    from cassandra.query import Statement


class CassandraLoader(BaseLoader):
    """Load documents from Apache Cassandra, one document per result row."""

    def __init__(
        self,
        table: Optional[str] = None,
        session: Optional["Session"] = None,
        keyspace: Optional[str] = None,
        query: Optional[Union[str, "Statement"]] = None,
        page_content_mapper: Callable[[Any], str] = default_page_content_mapper,
        metadata_mapper: Callable[[Any], dict] = lambda _: {},
        *,
        query_parameters: Optional[Union[dict, Sequence]] = None,
        query_timeout: Optional[float] = _NOT_SET,  # type: ignore[assignment]
        query_trace: bool = False,
        query_custom_payload: Optional[dict] = None,
        query_execution_profile: Any = _NOT_SET,
        query_paging_state: Any = None,
        query_host: Optional["Host"] = None,
        query_execute_as: Optional[str] = None,
    ) -> None:
        """
        Document Loader for Apache Cassandra.

        Args:
            table: The table to load the data from.
                (do not use together with the query parameter)
            session: The cassandra driver session.
                If not provided, the cassio resolved session will be used.
            keyspace: The keyspace of the table.
                If not provided, the cassio resolved keyspace will be used.
            query: The query used to load the data.
                (do not use together with the table parameter)
            page_content_mapper: a function to convert a row to string page
                content.
            metadata_mapper: a function to convert a row to a metadata dict.
                Its result is merged over the loader-level metadata
                (table/keyspace when loading from a table).
            query_parameters: The query parameters used when calling
                session.execute .
            query_timeout: The query timeout used when calling
                session.execute . Only forwarded when explicitly provided.
            query_trace: The trace flag used when calling session.execute .
            query_custom_payload: The query custom_payload used when calling
                session.execute .
            query_execution_profile: The query execution_profile used when
                calling session.execute . Only forwarded when explicitly
                provided.
            query_paging_state: The query paging_state used when calling
                session.execute .
            query_host: The query host used when calling session.execute .
            query_execute_as: The query execute_as used when calling
                session.execute .

        Raises:
            ValueError: If both or neither of ``query`` and ``table`` are
                given.
            ImportError: If cassio is needed to resolve the session or the
                keyspace but cannot be imported.
        """
        if query and table:
            raise ValueError("Cannot specify both query and table.")

        if not query and not table:
            raise ValueError("Must specify query or table.")

        # cassio is only required when something has to be resolved from its
        # global configuration: the session, or the keyspace of a table.
        if not session or (table and not keyspace):
            try:
                from cassio.config import (
                    check_resolve_keyspace,
                    check_resolve_session,
                )
            except ImportError as err:
                raise ImportError(
                    "Could not import a recent cassio package. "
                    "Please install it with `pip install --upgrade cassio`."
                ) from err

        if table:
            # `or` short-circuits, so the cassio helper is only touched when
            # it was actually imported above.
            _keyspace = keyspace or check_resolve_keyspace(keyspace)
            # NOTE(security): identifiers cannot be bound as query parameters,
            # so keyspace/table are interpolated verbatim. They must come from
            # trusted input.
            self.query = f"SELECT * FROM {_keyspace}.{table};"
            self.metadata = {"table": table, "keyspace": _keyspace}
        else:
            self.query = query
            self.metadata = {}

        self.session = session or check_resolve_session(session)
        self.page_content_mapper = page_content_mapper
        self.metadata_mapper = metadata_mapper

        self.query_kwargs = {
            "parameters": query_parameters,
            "trace": query_trace,
            "custom_payload": query_custom_payload,
            "paging_state": query_paging_state,
            "host": query_host,
            "execute_as": query_execute_as,
        }
        # These two are forwarded only when the caller set them, so that an
        # omitted argument is not passed to session.execute at all.
        if query_timeout is not _NOT_SET:
            self.query_kwargs["timeout"] = query_timeout

        if query_execution_profile is not _NOT_SET:
            self.query_kwargs["execution_profile"] = query_execution_profile

    def load(self) -> List[Document]:
        """Execute the query and return all resulting documents as a list."""
        return list(self.lazy_load())

    def lazy_load(self) -> Iterator[Document]:
        """Execute the query and yield one document per result row.

        Each document's metadata is a fresh copy of the loader-level metadata
        updated with the output of ``metadata_mapper`` for that row.
        """
        for row in self.session.execute(self.query, **self.query_kwargs):
            metadata = self.metadata.copy()
            metadata.update(self.metadata_mapper(row))
            yield Document(
                page_content=self.page_content_mapper(row), metadata=metadata
            )