removing client+namespace in favor of collection (#5610)

Removing client+namespace in favor of collection, for easier
instantiation and for consistency with the TypeScript library.

@dev2049
searx_updates
Paul-Emile Brotons 1 year ago committed by GitHub
parent ad09367a92
commit 92f218207b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -118,15 +118,14 @@
"\n", "\n",
"db_name = \"lanchain_db\"\n", "db_name = \"lanchain_db\"\n",
"collection_name = \"langchain_col\"\n", "collection_name = \"langchain_col\"\n",
"namespace = f\"{db_name}.{collection_name}\"\n", "collection = client[db_name][collection_name]\n",
"index_name = \"langchain_demo\"\n", "index_name = \"langchain_demo\"\n",
"\n", "\n",
"# insert the documents in MongoDB Atlas with their embedding\n", "# insert the documents in MongoDB Atlas with their embedding\n",
"docsearch = MongoDBAtlasVectorSearch.from_documents(\n", "docsearch = MongoDBAtlasVectorSearch.from_documents(\n",
" docs,\n", " docs,\n",
" embeddings,\n", " embeddings,\n",
" client=client,\n", " collection=collection,\n",
" namespace=namespace,\n",
" index_name=index_name\n", " index_name=index_name\n",
")\n", ")\n",
"\n", "\n",

@ -10,6 +10,7 @@ from typing import (
List, List,
Optional, Optional,
Tuple, Tuple,
TypeVar,
Union, Union,
) )
@ -18,7 +19,9 @@ from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
if TYPE_CHECKING: if TYPE_CHECKING:
from pymongo import MongoClient from pymongo.collection import Collection
MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,15 +44,14 @@ class MongoDBAtlasVectorSearch(VectorStore):
from pymongo import MongoClient from pymongo import MongoClient
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>") mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
namespace = "<db_name>.<collection_name>" collection = mongo_client["<db_name>"]["<collection_name>"]
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
vectorstore = MongoDBAtlasVectorSearch(mongo_client, namespace, embeddings) vectorstore = MongoDBAtlasVectorSearch(collection, embeddings)
""" """
def __init__( def __init__(
self, self,
client: MongoClient, collection: Collection[MongoDBDocumentType],
namespace: str,
embedding: Embeddings, embedding: Embeddings,
*, *,
index_name: str = "default", index_name: str = "default",
@ -58,17 +60,14 @@ class MongoDBAtlasVectorSearch(VectorStore):
): ):
""" """
Args: Args:
client: MongoDB client. collection: MongoDB collection to add the texts to.
namespace: MongoDB namespace to add the texts to.
embedding: Text embedding model to use. embedding: Text embedding model to use.
text_key: MongoDB field that will contain the text for each text_key: MongoDB field that will contain the text for each
document. document.
embedding_key: MongoDB field that will contain the embedding for embedding_key: MongoDB field that will contain the embedding for
each document. each document.
""" """
self._client = client self._collection = collection
db_name, collection_name = namespace.split(".")
self._collection = client[db_name][collection_name]
self._embedding = embedding self._embedding = embedding
self._index_name = index_name self._index_name = index_name
self._text_key = text_key self._text_key = text_key
@ -90,7 +89,9 @@ class MongoDBAtlasVectorSearch(VectorStore):
"`pip install pymongo`." "`pip install pymongo`."
) )
client: MongoClient = MongoClient(connection_string) client: MongoClient = MongoClient(connection_string)
return cls(client, namespace, embedding, **kwargs) db_name, collection_name = namespace.split(".")
collection = client[db_name][collection_name]
return cls(collection, embedding, **kwargs)
def add_texts( def add_texts(
self, self,
@ -232,8 +233,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
client: Optional[MongoClient] = None, collection: Optional[Collection[MongoDBDocumentType]] = None,
namespace: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> MongoDBAtlasVectorSearch: ) -> MongoDBAtlasVectorSearch:
"""Construct MongoDBAtlasVectorSearch wrapper from raw documents. """Construct MongoDBAtlasVectorSearch wrapper from raw documents.
@ -253,18 +253,17 @@ class MongoDBAtlasVectorSearch(VectorStore):
from langchain.embeddings import OpenAIEmbeddings from langchain.embeddings import OpenAIEmbeddings
client = MongoClient("<YOUR-CONNECTION-STRING>") client = MongoClient("<YOUR-CONNECTION-STRING>")
namespace = "<db_name>.<collection_name>" collection = client["<db_name>"]["<collection_name>"]
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
vectorstore = MongoDBAtlasVectorSearch.from_texts( vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts, texts,
embeddings, embeddings,
metadatas=metadatas, metadatas=metadatas,
client=client, collection=collection
namespace=namespace
) )
""" """
if not client or not namespace: if not collection:
raise ValueError("Must provide 'client' and 'namespace' named parameters.") raise ValueError("Must provide 'collection' named parameter.")
vecstore = cls(client, namespace, embedding, **kwargs) vecstore = cls(collection, embedding, **kwargs)
vecstore.add_texts(texts, metadatas=metadatas) vecstore.add_texts(texts, metadatas=metadatas)
return vecstore return vecstore

@ -3,7 +3,7 @@ from __future__ import annotations
import os import os
from time import sleep from time import sleep
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import pytest import pytest
@ -19,37 +19,27 @@ NAMESPACE = "langchain_test_db.langchain_test_collection"
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".") DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
def get_test_client() -> Optional[MongoClient]:
try:
from pymongo import MongoClient
client: MongoClient = MongoClient(CONNECTION_STRING)
return client
except: # noqa: E722
return None
# Instantiate as constant instead of pytest fixture to prevent needing to make multiple # Instantiate as constant instead of pytest fixture to prevent needing to make multiple
# connections. # connections.
TEST_CLIENT = get_test_client() TEST_CLIENT = MongoClient(CONNECTION_STRING)
collection = TEST_CLIENT[DB_NAME][COLLECTION_NAME]
class TestMongoDBAtlasVectorSearch: class TestMongoDBAtlasVectorSearch:
@classmethod @classmethod
def setup_class(cls) -> None: def setup_class(cls) -> None:
# ensure the test collection is empty # ensure the test collection is empty
assert TEST_CLIENT[DB_NAME][COLLECTION_NAME].count_documents({}) == 0 # type: ignore[index] # noqa: E501 assert collection.count_documents({}) == 0 # type: ignore[index] # noqa: E501
@classmethod @classmethod
def teardown_class(cls) -> None: def teardown_class(cls) -> None:
# delete all the documents in the collection # delete all the documents in the collection
TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index] collection.delete_many({}) # type: ignore[index]
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup(self) -> None: def setup(self) -> None:
# delete all the documents in the collection # delete all the documents in the collection
TEST_CLIENT[DB_NAME][COLLECTION_NAME].delete_many({}) # type: ignore[index] collection.delete_many({}) # type: ignore[index]
def test_from_documents(self, embedding_openai: Embeddings) -> None: def test_from_documents(self, embedding_openai: Embeddings) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
@ -62,8 +52,7 @@ class TestMongoDBAtlasVectorSearch:
vectorstore = MongoDBAtlasVectorSearch.from_documents( vectorstore = MongoDBAtlasVectorSearch.from_documents(
documents, documents,
embedding_openai, embedding_openai,
client=TEST_CLIENT, collection=collection,
namespace=NAMESPACE,
index_name=INDEX_NAME, index_name=INDEX_NAME,
) )
sleep(1) # waits for mongot to update Lucene's index sleep(1) # waits for mongot to update Lucene's index
@ -81,8 +70,7 @@ class TestMongoDBAtlasVectorSearch:
vectorstore = MongoDBAtlasVectorSearch.from_texts( vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts, texts,
embedding_openai, embedding_openai,
client=TEST_CLIENT, collection=collection,
namespace=NAMESPACE,
index_name=INDEX_NAME, index_name=INDEX_NAME,
) )
sleep(1) # waits for mongot to update Lucene's index sleep(1) # waits for mongot to update Lucene's index
@ -101,8 +89,7 @@ class TestMongoDBAtlasVectorSearch:
texts, texts,
embedding_openai, embedding_openai,
metadatas=metadatas, metadatas=metadatas,
client=TEST_CLIENT, collection=collection,
namespace=NAMESPACE,
index_name=INDEX_NAME, index_name=INDEX_NAME,
) )
sleep(1) # waits for mongot to update Lucene's index sleep(1) # waits for mongot to update Lucene's index
@ -124,8 +111,7 @@ class TestMongoDBAtlasVectorSearch:
texts, texts,
embedding_openai, embedding_openai,
metadatas=metadatas, metadatas=metadatas,
client=TEST_CLIENT, collection=collection,
namespace=NAMESPACE,
index_name=INDEX_NAME, index_name=INDEX_NAME,
) )
sleep(1) # waits for mongot to update Lucene's index sleep(1) # waits for mongot to update Lucene's index

Loading…
Cancel
Save