diff --git a/docs/modules/indexes/vectorstores/examples/mongodb_atlas_vector_search.ipynb b/docs/modules/indexes/vectorstores/examples/mongodb_atlas_vector_search.ipynb index 4af84211..1cba5676 100644 --- a/docs/modules/indexes/vectorstores/examples/mongodb_atlas_vector_search.ipynb +++ b/docs/modules/indexes/vectorstores/examples/mongodb_atlas_vector_search.ipynb @@ -118,15 +118,14 @@ "\n", "db_name = \"lanchain_db\"\n", "collection_name = \"langchain_col\"\n", - "namespace = f\"{db_name}.{collection_name}\"\n", + "collection = client[db_name][collection_name]\n", "index_name = \"langchain_demo\"\n", "\n", "# insert the documents in MongoDB Atlas with their embedding\n", "docsearch = MongoDBAtlasVectorSearch.from_documents(\n", " docs,\n", " embeddings,\n", - " client=client,\n", - " namespace=namespace,\n", + " collection=collection,\n", " index_name=index_name\n", ")\n", "\n", diff --git a/langchain/vectorstores/mongodb_atlas.py b/langchain/vectorstores/mongodb_atlas.py index 3bd9abc3..78df5317 100644 --- a/langchain/vectorstores/mongodb_atlas.py +++ b/langchain/vectorstores/mongodb_atlas.py @@ -10,6 +10,7 @@ from typing import ( List, Optional, Tuple, + TypeVar, Union, ) @@ -18,7 +19,9 @@ from langchain.embeddings.base import Embeddings from langchain.vectorstores.base import VectorStore if TYPE_CHECKING: - from pymongo import MongoClient + from pymongo.collection import Collection + +MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any]) logger = logging.getLogger(__name__) @@ -41,15 +44,14 @@ class MongoDBAtlasVectorSearch(VectorStore): from pymongo import MongoClient mongo_client = MongoClient("") - namespace = "." 
+ collection = mongo_client[""][""] embeddings = OpenAIEmbeddings() - vectorstore = MongoDBAtlasVectorSearch(mongo_client, namespace, embeddings) + vectorstore = MongoDBAtlasVectorSearch(collection, embeddings) """ def __init__( self, - client: MongoClient, - namespace: str, + collection: Collection[MongoDBDocumentType], embedding: Embeddings, *, index_name: str = "default", @@ -58,17 +60,14 @@ class MongoDBAtlasVectorSearch(VectorStore): ): """ Args: - client: MongoDB client. - namespace: MongoDB namespace to add the texts to. + collection: MongoDB collection to add the texts to. embedding: Text embedding model to use. text_key: MongoDB field that will contain the text for each document. embedding_key: MongoDB field that will contain the embedding for each document. """ - self._client = client - db_name, collection_name = namespace.split(".") - self._collection = client[db_name][collection_name] + self._collection = collection self._embedding = embedding self._index_name = index_name self._text_key = text_key @@ -90,7 +89,9 @@ class MongoDBAtlasVectorSearch(VectorStore): "`pip install pymongo`." ) client: MongoClient = MongoClient(connection_string) - return cls(client, namespace, embedding, **kwargs) + db_name, collection_name = namespace.split(".") + collection = client[db_name][collection_name] + return cls(collection, embedding, **kwargs) def add_texts( self, @@ -232,8 +233,7 @@ class MongoDBAtlasVectorSearch(VectorStore): texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, - client: Optional[MongoClient] = None, - namespace: Optional[str] = None, + collection: Optional[Collection[MongoDBDocumentType]] = None, **kwargs: Any, ) -> MongoDBAtlasVectorSearch: """Construct MongoDBAtlasVectorSearch wrapper from raw documents. @@ -253,18 +253,17 @@ class MongoDBAtlasVectorSearch(VectorStore): from langchain.embeddings import OpenAIEmbeddings client = MongoClient("") - namespace = "." 
+ collection = client[""][""] embeddings = OpenAIEmbeddings() vectorstore = MongoDBAtlasVectorSearch.from_texts( texts, embeddings, metadatas=metadatas, - client=client, - namespace=namespace + collection=collection ) """ - if not client or not namespace: - raise ValueError("Must provide 'client' and 'namespace' named parameters.") - vecstore = cls(client, namespace, embedding, **kwargs) + if not collection: + raise ValueError("Must provide 'collection' named parameter.") + vecstore = cls(collection, embedding, **kwargs) vecstore.add_texts(texts, metadatas=metadatas) return vecstore diff --git a/tests/integration_tests/vectorstores/test_mongodb_atlas.py b/tests/integration_tests/vectorstores/test_mongodb_atlas.py index 715b5a30..d36bb0e0 100644 --- a/tests/integration_tests/vectorstores/test_mongodb_atlas.py +++ b/tests/integration_tests/vectorstores/test_mongodb_atlas.py @@ -3,7 +3,7 @@ from __future__ import annotations import os from time import sleep -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pytest @@ -19,37 +19,27 @@ NAMESPACE = "langchain_test_db.langchain_test_collection" CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") DB_NAME, COLLECTION_NAME = NAMESPACE.split(".") - -def get_test_client() -> Optional[MongoClient]: - try: - from pymongo import MongoClient - - client: MongoClient = MongoClient(CONNECTION_STRING) - return client - except: # noqa: E722 - return None - - # Instantiate as constant instead of pytest fixture to prevent needing to make multiple # connections. 
-TEST_CLIENT = get_test_client() +TEST_CLIENT = MongoClient(CONNECTION_STRING) +collection = TEST_CLIENT[DB_NAME][COLLECTION_NAME] class TestMongoDBAtlasVectorSearch: @classmethod def setup_class(cls) -> None: # insure the test collection is empty - assert TEST_CLIENT[DB_NAME][COLLECTION_NAME].count_documents({}) == 0 # type: ignore[index] # noqa: E501 + assert collection.count_documents({}) == 0 # type: ignore[index] # noqa: E501 @classmethod def teardown_class(cls) -> None: # delete all the documents in the collection - TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index] + collection.delete_many({}) # type: ignore[index] @pytest.fixture(autouse=True) def setup(self) -> None: # delete all the documents in the collection - TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index] + collection.delete_many({}) # type: ignore[index] def test_from_documents(self, embedding_openai: Embeddings) -> None: """Test end to end construction and search.""" @@ -62,8 +52,7 @@ class TestMongoDBAtlasVectorSearch: vectorstore = MongoDBAtlasVectorSearch.from_documents( documents, embedding_openai, - client=TEST_CLIENT, - namespace=NAMESPACE, + collection=collection, index_name=INDEX_NAME, ) sleep(1) # waits for mongot to update Lucene's index @@ -81,8 +70,7 @@ class TestMongoDBAtlasVectorSearch: vectorstore = MongoDBAtlasVectorSearch.from_texts( texts, embedding_openai, - client=TEST_CLIENT, - namespace=NAMESPACE, + collection=collection, index_name=INDEX_NAME, ) sleep(1) # waits for mongot to update Lucene's index @@ -101,8 +89,7 @@ class TestMongoDBAtlasVectorSearch: texts, embedding_openai, metadatas=metadatas, - client=TEST_CLIENT, - namespace=NAMESPACE, + collection=collection, index_name=INDEX_NAME, ) sleep(1) # waits for mongot to update Lucene's index @@ -124,8 +111,7 @@ class TestMongoDBAtlasVectorSearch: texts, embedding_openai, metadatas=metadatas, - client=TEST_CLIENT, - namespace=NAMESPACE, + collection=collection, 
index_name=INDEX_NAME, ) sleep(1) # waits for mongot to update Lucene's index