mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
cd52433ba0
- **Description:** A generic document loader adapter for SQLAlchemy on top of LangChain's `SQLDatabaseLoader`. - **Needed by:** https://github.com/crate-workbench/langchain/pull/1 - **Depends on:** GH-16655 - **Addressed to:** @baskaryan, @cbornet, @eyurtsev Hi from CrateDB again, in the same spirit as GH-16243 and GH-16244, this patch breaks out another commit from https://github.com/crate-workbench/langchain/pull/1, in order to reduce the size of this patch before submitting it, and to separate concerns. To accompany the SQLAlchemy adapter implementation, the patch includes integration tests for both SQLite and PostgreSQL. Let me know if corresponding utility resources should be added at different spots. With kind regards, Andreas. ### Software Tests ```console docker compose --file libs/community/tests/integration_tests/document_loaders/docker-compose/postgresql.yml up ``` ```console cd libs/community pip install psycopg2-binary pytest -vvv tests/integration_tests -k sqldatabase ``` ``` 14 passed ``` ![image](https://github.com/langchain-ai/langchain/assets/453543/42be233c-eb37-4c76-a830-474276e01436) --------- Co-authored-by: Andreas Motl <andreas.motl@crate.io>
140 lines
5.5 KiB
Python
140 lines
5.5 KiB
Python
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union
|
|
|
|
import sqlalchemy as sa
|
|
|
|
from langchain_community.docstore.document import Document
|
|
from langchain_community.document_loaders.base import BaseLoader
|
|
from langchain_community.utilities.sql_database import SQLDatabase
|
|
|
|
|
|
class SQLDatabaseLoader(BaseLoader):
    """
    Load documents by querying database tables supported by SQLAlchemy.

    For talking to the database, the document loader uses the `SQLDatabase`
    utility from the LangChain integration toolkit.

    Each document represents one row of the result.
    """

    def __init__(
        self,
        query: Union[str, sa.Select],
        db: SQLDatabase,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        page_content_mapper: Optional[Callable[..., str]] = None,
        metadata_mapper: Optional[Callable[..., Dict[str, Any]]] = None,
        source_columns: Optional[Sequence[str]] = None,
        include_rownum_into_metadata: bool = False,
        include_query_into_metadata: bool = False,
    ):
        """
        Args:
            query: The query to execute. Either an SQL string, or an SQLAlchemy
                selectable.
            db: A LangChain `SQLDatabase`, wrapping an SQLAlchemy engine.
            parameters: Optional. Parameters to pass to the query.
            page_content_mapper: Optional. Function to convert a row into a string
                to use as the `page_content` of the document. By default, the loader
                serializes the whole row into a string, including all columns.
            metadata_mapper: Optional. Function to convert a row into a dictionary
                to use as the `metadata` of the document. By default, no columns are
                selected into the metadata dictionary.
            source_columns: Optional. The names of the columns to use as the `source`
                within the metadata dictionary.
            include_rownum_into_metadata: Optional. Whether to include the row number
                into the metadata dictionary. Default: False.
            include_query_into_metadata: Optional. Whether to include the query
                expression into the metadata dictionary. Default: False.
        """
        self.query = query
        self.db: SQLDatabase = db
        self.parameters = parameters or {}
        self.page_content_mapper = (
            page_content_mapper or self.page_content_default_mapper
        )
        self.metadata_mapper = metadata_mapper or self.metadata_default_mapper
        self.source_columns = source_columns
        self.include_rownum_into_metadata = include_rownum_into_metadata
        self.include_query_into_metadata = include_query_into_metadata

    def lazy_load(self) -> Iterator[Document]:
        """
        Execute the query and yield one `Document` per result row.

        The page content and the base metadata of each document are produced by
        the configured mapper functions. Depending on the loader options, the
        zero-based row number (`row`), the SQL text of the query (`query`), and
        the comma-joined values of the source columns (`source`) are added to
        the metadata dictionary.

        Raises:
            TypeError: If the query is neither a string nor an SQLAlchemy
                selectable.
        """
        # NOTE: `sqlalchemy` is already imported at module level (the signature
        # annotations of this class need it), so no guarded import is needed here.

        # Querying in `cursor` fetch mode will return an SQLAlchemy `Result` instance.
        result: sa.Result[Any]

        # Invoke the database query.
        if isinstance(self.query, sa.SelectBase):
            result = self.db._execute(  # type: ignore[assignment]
                self.query, fetch="cursor", parameters=self.parameters
            )
            query_sql = str(self.query.compile(bind=self.db._engine))
        elif isinstance(self.query, str):
            result = self.db._execute(  # type: ignore[assignment]
                sa.text(self.query), fetch="cursor", parameters=self.parameters
            )
            query_sql = self.query
        else:
            raise TypeError(f"Unable to process query of unknown type: {self.query}")

        # Loop-invariant: resolve the optional column selection once.
        source_columns = self.source_columns or ()

        # Iterate database result rows and generate documents one by one.
        for i, row in enumerate(result.mappings()):
            page_content = self.page_content_mapper(row)
            metadata = self.metadata_mapper(row)

            if self.include_rownum_into_metadata:
                metadata["row"] = i
            if self.include_query_into_metadata:
                metadata["query"] = query_sql

            # Values are collected in the column order of the row. They are
            # converted with `str()` because `str.join` rejects non-string
            # items, e.g. integer primary keys used as the source.
            source_values = [
                str(value)
                for column, value in row.items()
                if column in source_columns
            ]
            if source_values:
                metadata["source"] = ",".join(source_values)

            yield Document(page_content=page_content, metadata=metadata)

    def load(self) -> List[Document]:
        """Run the query and return all resulting documents as a list."""
        return list(self.lazy_load())

    @staticmethod
    def page_content_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> str:
        """
        A reasonable default function to convert a record into a "page content" string.

        Args:
            row: The result row, as a mapping of column name to value.
            column_names: Optional. The columns to include. By default, all
                columns of the row are included.

        Returns:
            One `column: value` line per selected column, joined by newlines.
        """
        if column_names is None:
            column_names = list(row.keys())
        return "\n".join(
            f"{column}: {value}"
            for column, value in row.items()
            if column in column_names
        )

    @staticmethod
    def metadata_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        A reasonable default function to convert a record into a "metadata" dictionary.

        Args:
            row: The result row, as a mapping of column name to value.
            column_names: Optional. The columns to copy into the metadata. By
                default, no columns are selected.

        Returns:
            A new dictionary holding the selected columns of the row.
        """
        if column_names is None:
            return {}

        return {
            column: value for column, value in row.items() if column in column_names
        }
|