mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community: Enhance MongoDBLoader with flexible metadata and optimized field extraction (#23376)
### Description: This pull request significantly enhances the MongodbLoader class in the LangChain community package by adding robust metadata customization and improved field extraction capabilities. The updated class now allows users to specify additional metadata fields through the metadata_names parameter, enabling the extraction of both top-level and deeply nested document attributes as metadata. This flexibility is crucial for users who need to include detailed contextual information without altering the database schema. Moreover, the include_db_collection_in_metadata flag offers optional inclusion of database and collection names in the metadata, allowing for even greater customization depending on the user's needs. The loader's field extraction logic has been refined to handle missing or nested fields more gracefully. It now employs a safe access mechanism that avoids the KeyError previously encountered when a specified nested field was absent in a document. This update ensures that the loader can handle diverse and complex data structures without failure, making it more resilient and user-friendly. ### Issue: This pull request addresses a critical issue where the MongodbLoader class in the LangChain community package could throw a KeyError when attempting to access nested fields that may not exist in some documents. The previous implementation did not handle the absence of specified nested fields gracefully, leading to runtime errors and interruptions in data processing workflows. This enhancement ensures robust error handling by safely accessing nested document fields, using default values for missing data, thus preventing KeyError and ensuring smoother operation across various data structures in MongoDB. This improvement is crucial for users working with diverse and complex data sets, ensuring the loader can adapt to documents with varying structures without failing. ### Dependencies: Requires motor for asynchronous MongoDB interaction. ### Twitter handle: N/A ### Add tests and docs Tests: Unit tests have been added to verify that the metadata inclusion toggle works as expected and that the field extraction correctly handles nested fields. Docs: An example notebook demonstrating the use of the enhanced MongodbLoader is included in the docs/docs/integrations directory. This notebook includes setup instructions, example usage, and outputs. (Here is the notebook link : [colab link](https://colab.research.google.com/drive/1tp7nyUnzZa3dxEFF4Kc3KS7ACuNF6jzH?usp=sharing)) Lint and test Before submitting, I ran make format, make lint, and make test as per the contribution guidelines. All tests pass, and the code style adheres to the LangChain standards. ```python import unittest from unittest.mock import patch, MagicMock import asyncio from langchain_community.document_loaders.mongodb import MongodbLoader class TestMongodbLoader(unittest.TestCase): def setUp(self): """Setup the MongodbLoader test environment by mocking the motor client and database collection interactions.""" # Mocking the AsyncIOMotorClient self.mock_client = MagicMock() self.mock_db = MagicMock() self.mock_collection = MagicMock() self.mock_client.get_database.return_value = self.mock_db self.mock_db.get_collection.return_value = self.mock_collection # Initialize the MongodbLoader with test data self.loader = MongodbLoader( connection_string="mongodb://localhost:27017", db_name="testdb", collection_name="testcol" ) @patch('langchain_community.document_loaders.mongodb.AsyncIOMotorClient', return_value=MagicMock()) def test_constructor(self, mock_motor_client): """Test if the constructor properly initializes with the correct database and collection names.""" loader = MongodbLoader( connection_string="mongodb://localhost:27017", db_name="testdb", collection_name="testcol" ) self.assertEqual(loader.db_name, "testdb") self.assertEqual(loader.collection_name, "testcol") def test_aload(self): """Test the aload method to ensure it correctly queries and processes documents.""" # Setup mock data and responses for the database operations self.mock_collection.count_documents.return_value = asyncio.Future() self.mock_collection.count_documents.return_value.set_result(1) self.mock_collection.find.return_value = [ {"_id": "1", "content": "Test document content"} ] # Run the aload method and check responses loop = asyncio.get_event_loop() results = loop.run_until_complete(self.loader.aload()) self.assertEqual(len(results), 1) self.assertEqual(results[0].page_content, "Test document content") def test_construct_projection(self): """Verify that the projection dictionary is constructed correctly based on field names.""" self.loader.field_names = ['content', 'author'] self.loader.metadata_names = ['timestamp'] expected_projection = {'content': 1, 'author': 1, 'timestamp': 1} projection = self.loader._construct_projection() self.assertEqual(projection, expected_projection) if __name__ == '__main__': unittest.main() ``` ### Additional Example for Documentation Sample Data: ```json [ { "_id": "1", "title": "Artificial Intelligence in Medicine", "content": "AI is transforming the medical industry by providing personalized medicine solutions.", "author": { "name": "John Doe", "email": "john.doe@example.com" }, "tags": ["AI", "Healthcare", "Innovation"] }, { "_id": "2", "title": "Data Science in Sports", "content": "Data science provides insights into player performance and strategic planning in sports.", "author": { "name": "Jane Smith", "email": "jane.smith@example.com" }, "tags": ["Data Science", "Sports", "Analytics"] } ] ``` Example Code: ```python loader = MongodbLoader( connection_string="mongodb://localhost:27017", db_name="example_db", collection_name="articles", filter_criteria={"tags": "AI"}, field_names=["title", "content"], metadata_names=["author.name", "author.email"], include_db_collection_in_metadata=True ) documents = loader.load() for doc in documents: print("Page Content:", doc.page_content) print("Metadata:", doc.metadata) ``` Expected Output: ``` Page Content: Artificial Intelligence in Medicine AI is transforming the medical industry by providing personalized medicine solutions. Metadata: {'author_name': 'John Doe', 'author_email': 'john.doe@example.com', 'database': 'example_db', 'collection': 'articles'} ``` Thank you. --- Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
6758894af1
commit
0a177ec2cc
@ -20,13 +20,37 @@ class MongodbLoader(BaseLoader):
|
||||
*,
|
||||
filter_criteria: Optional[Dict] = None,
|
||||
field_names: Optional[Sequence[str]] = None,
|
||||
metadata_names: Optional[Sequence[str]] = None,
|
||||
include_db_collection_in_metadata: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the MongoDB loader with necessary database connection
|
||||
details and configurations.
|
||||
|
||||
Args:
|
||||
connection_string (str): MongoDB connection URI.
|
||||
db_name (str):Name of the database to connect to.
|
||||
collection_name (str): Name of the collection to fetch documents from.
|
||||
filter_criteria (Optional[Dict]): MongoDB filter criteria for querying
|
||||
documents.
|
||||
field_names (Optional[Sequence[str]]): List of field names to retrieve
|
||||
from documents.
|
||||
metadata_names (Optional[Sequence[str]]): Additional metadata fields to
|
||||
extract from documents.
|
||||
include_db_collection_in_metadata (bool): Flag to include database and
|
||||
collection names in metadata.
|
||||
|
||||
Raises:
|
||||
ImportError: If the motor library is not installed.
|
||||
ValueError: If any necessary argument is missing.
|
||||
"""
|
||||
try:
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Cannot import from motor, please install with `pip install motor`."
|
||||
) from e
|
||||
|
||||
if not connection_string:
|
||||
raise ValueError("connection_string must be provided.")
|
||||
|
||||
@ -39,8 +63,10 @@ class MongodbLoader(BaseLoader):
|
||||
self.client = AsyncIOMotorClient(connection_string)
|
||||
self.db_name = db_name
|
||||
self.collection_name = collection_name
|
||||
self.field_names = field_names
|
||||
self.field_names = field_names or []
|
||||
self.filter_criteria = filter_criteria or {}
|
||||
self.metadata_names = metadata_names or []
|
||||
self.include_db_collection_in_metadata = include_db_collection_in_metadata
|
||||
|
||||
self.db = self.client.get_database(db_name)
|
||||
self.collection = self.db.get_collection(collection_name)
|
||||
@ -60,36 +86,24 @@ class MongodbLoader(BaseLoader):
|
||||
return asyncio.run(self.aload())
|
||||
|
||||
async def aload(self) -> List[Document]:
|
||||
"""Load data into Document objects."""
|
||||
"""Asynchronously loads data into Document objects."""
|
||||
result = []
|
||||
total_docs = await self.collection.count_documents(self.filter_criteria)
|
||||
|
||||
# Construct the projection dictionary if field_names are specified
|
||||
projection = (
|
||||
{field: 1 for field in self.field_names} if self.field_names else None
|
||||
)
|
||||
projection = self._construct_projection()
|
||||
|
||||
async for doc in self.collection.find(self.filter_criteria, projection):
|
||||
metadata = {
|
||||
"database": self.db_name,
|
||||
"collection": self.collection_name,
|
||||
}
|
||||
metadata = self._extract_fields(doc, self.metadata_names, default="")
|
||||
|
||||
# Optionally add database and collection names to metadata
|
||||
if self.include_db_collection_in_metadata:
|
||||
metadata.update(
|
||||
{"database": self.db_name, "collection": self.collection_name}
|
||||
)
|
||||
|
||||
# Extract text content from filtered fields or use the entire document
|
||||
if self.field_names is not None:
|
||||
fields = {}
|
||||
for name in self.field_names:
|
||||
# Split the field names to handle nested fields
|
||||
keys = name.split(".")
|
||||
value = doc
|
||||
for key in keys:
|
||||
if key in value:
|
||||
value = value[key]
|
||||
else:
|
||||
value = ""
|
||||
break
|
||||
fields[name] = value
|
||||
|
||||
fields = self._extract_fields(doc, self.field_names, default="")
|
||||
texts = [str(value) for value in fields.values()]
|
||||
text = " ".join(texts)
|
||||
else:
|
||||
@ -104,3 +118,29 @@ class MongodbLoader(BaseLoader):
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _construct_projection(self) -> Optional[Dict]:
|
||||
"""Constructs the projection dictionary for MongoDB query based
|
||||
on the specified field names and metadata names."""
|
||||
field_names = list(self.field_names) or []
|
||||
metadata_names = list(self.metadata_names) or []
|
||||
all_fields = field_names + metadata_names
|
||||
return {field: 1 for field in all_fields} if all_fields else None
|
||||
|
||||
def _extract_fields(
|
||||
self,
|
||||
document: Dict,
|
||||
fields: Sequence[str],
|
||||
default: str = "",
|
||||
) -> Dict:
|
||||
"""Extracts and returns values for specified fields from a document."""
|
||||
extracted = {}
|
||||
for field in fields or []:
|
||||
value = document
|
||||
for key in field.split("."):
|
||||
value = value.get(key, default)
|
||||
if value == default:
|
||||
break
|
||||
new_field_name = field.replace(".", "_")
|
||||
extracted[new_field_name] = value
|
||||
return extracted
|
||||
|
@ -12,6 +12,7 @@ def raw_docs() -> List[Dict]:
|
||||
return [
|
||||
{"_id": "1", "address": {"building": "1", "room": "1"}},
|
||||
{"_id": "2", "address": {"building": "2", "room": "2"}},
|
||||
{"_id": "3", "address": {"building": "3", "room": "2"}},
|
||||
]
|
||||
|
||||
|
||||
@ -19,18 +20,23 @@ def raw_docs() -> List[Dict]:
|
||||
def expected_documents() -> List[Document]:
|
||||
return [
|
||||
Document(
|
||||
page_content="{'_id': '1', 'address': {'building': '1', 'room': '1'}}",
|
||||
page_content="{'_id': '2', 'address': {'building': '2', 'room': '2'}}",
|
||||
metadata={"database": "sample_restaurants", "collection": "restaurants"},
|
||||
),
|
||||
Document(
|
||||
page_content="{'_id': '2', 'address': {'building': '2', 'room': '2'}}",
|
||||
page_content="{'_id': '3', 'address': {'building': '3', 'room': '2'}}",
|
||||
metadata={"database": "sample_restaurants", "collection": "restaurants"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.requires("motor")
|
||||
async def test_load_mocked(expected_documents: List[Document]) -> None:
|
||||
async def test_load_mocked_with_filters(expected_documents: List[Document]) -> None:
|
||||
filter_criteria = {"address.room": {"$eq": "2"}}
|
||||
field_names = ["address.building", "address.room"]
|
||||
metadata_names = ["_id"]
|
||||
include_db_collection_in_metadata = True
|
||||
|
||||
mock_async_load = AsyncMock()
|
||||
mock_async_load.return_value = expected_documents
|
||||
|
||||
@ -51,7 +57,13 @@ async def test_load_mocked(expected_documents: List[Document]) -> None:
|
||||
new=mock_async_load,
|
||||
):
|
||||
loader = MongodbLoader(
|
||||
"mongodb://localhost:27017", "test_db", "test_collection"
|
||||
"mongodb://localhost:27017",
|
||||
"test_db",
|
||||
"test_collection",
|
||||
filter_criteria=filter_criteria,
|
||||
field_names=field_names,
|
||||
metadata_names=metadata_names,
|
||||
include_db_collection_in_metadata=include_db_collection_in_metadata,
|
||||
)
|
||||
loader.collection = mock_collection
|
||||
documents = await loader.aload()
|
||||
|
Loading…
Reference in New Issue
Block a user