|
|
@ -10,6 +10,7 @@ from typing import (
|
|
|
|
List,
|
|
|
|
List,
|
|
|
|
Optional,
|
|
|
|
Optional,
|
|
|
|
Tuple,
|
|
|
|
Tuple,
|
|
|
|
|
|
|
|
TypeVar,
|
|
|
|
Union,
|
|
|
|
Union,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
@ -18,7 +19,9 @@ from langchain.embeddings.base import Embeddings
|
|
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from pymongo import MongoClient
|
|
|
|
from pymongo.collection import Collection
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
@ -41,15 +44,14 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|
|
|
from pymongo import MongoClient
|
|
|
|
from pymongo import MongoClient
|
|
|
|
|
|
|
|
|
|
|
|
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
|
|
|
|
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
|
|
|
|
namespace = "<db_name>.<collection_name>"
|
|
|
|
collection = mongo_client["<db_name>"]["<collection_name>"]
|
|
|
|
embeddings = OpenAIEmbeddings()
|
|
|
|
embeddings = OpenAIEmbeddings()
|
|
|
|
vectorstore = MongoDBAtlasVectorSearch(mongo_client, namespace, embeddings)
|
|
|
|
vectorstore = MongoDBAtlasVectorSearch(collection, embeddings)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
client: MongoClient,
|
|
|
|
collection: Collection[MongoDBDocumentType],
|
|
|
|
namespace: str,
|
|
|
|
|
|
|
|
embedding: Embeddings,
|
|
|
|
embedding: Embeddings,
|
|
|
|
*,
|
|
|
|
*,
|
|
|
|
index_name: str = "default",
|
|
|
|
index_name: str = "default",
|
|
|
@ -58,17 +60,14 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|
|
|
):
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
client: MongoDB client.
|
|
|
|
collection: MongoDB collection to add the texts to.
|
|
|
|
namespace: MongoDB namespace to add the texts to.
|
|
|
|
|
|
|
|
embedding: Text embedding model to use.
|
|
|
|
embedding: Text embedding model to use.
|
|
|
|
text_key: MongoDB field that will contain the text for each
|
|
|
|
text_key: MongoDB field that will contain the text for each
|
|
|
|
document.
|
|
|
|
document.
|
|
|
|
embedding_key: MongoDB field that will contain the embedding for
|
|
|
|
embedding_key: MongoDB field that will contain the embedding for
|
|
|
|
each document.
|
|
|
|
each document.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self._client = client
|
|
|
|
self._collection = collection
|
|
|
|
db_name, collection_name = namespace.split(".")
|
|
|
|
|
|
|
|
self._collection = client[db_name][collection_name]
|
|
|
|
|
|
|
|
self._embedding = embedding
|
|
|
|
self._embedding = embedding
|
|
|
|
self._index_name = index_name
|
|
|
|
self._index_name = index_name
|
|
|
|
self._text_key = text_key
|
|
|
|
self._text_key = text_key
|
|
|
@ -90,7 +89,9 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|
|
|
"`pip install pymongo`."
|
|
|
|
"`pip install pymongo`."
|
|
|
|
)
|
|
|
|
)
|
|
|
|
client: MongoClient = MongoClient(connection_string)
|
|
|
|
client: MongoClient = MongoClient(connection_string)
|
|
|
|
return cls(client, namespace, embedding, **kwargs)
|
|
|
|
db_name, collection_name = namespace.split(".")
|
|
|
|
|
|
|
|
collection = client[db_name][collection_name]
|
|
|
|
|
|
|
|
return cls(collection, embedding, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def add_texts(
|
|
|
|
def add_texts(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
@ -232,8 +233,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|
|
|
texts: List[str],
|
|
|
|
texts: List[str],
|
|
|
|
embedding: Embeddings,
|
|
|
|
embedding: Embeddings,
|
|
|
|
metadatas: Optional[List[dict]] = None,
|
|
|
|
metadatas: Optional[List[dict]] = None,
|
|
|
|
client: Optional[MongoClient] = None,
|
|
|
|
collection: Optional[Collection[MongoDBDocumentType]] = None,
|
|
|
|
namespace: Optional[str] = None,
|
|
|
|
|
|
|
|
**kwargs: Any,
|
|
|
|
**kwargs: Any,
|
|
|
|
) -> MongoDBAtlasVectorSearch:
|
|
|
|
) -> MongoDBAtlasVectorSearch:
|
|
|
|
"""Construct MongoDBAtlasVectorSearch wrapper from raw documents.
|
|
|
|
"""Construct MongoDBAtlasVectorSearch wrapper from raw documents.
|
|
|
@ -253,18 +253,17 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
|
|
|
|
|
|
|
|
client = MongoClient("<YOUR-CONNECTION-STRING>")
|
|
|
|
client = MongoClient("<YOUR-CONNECTION-STRING>")
|
|
|
|
namespace = "<db_name>.<collection_name>"
|
|
|
|
collection = client["<db_name>"]["<collection_name>"]
|
|
|
|
embeddings = OpenAIEmbeddings()
|
|
|
|
embeddings = OpenAIEmbeddings()
|
|
|
|
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
|
|
|
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
|
|
|
texts,
|
|
|
|
texts,
|
|
|
|
embeddings,
|
|
|
|
embeddings,
|
|
|
|
metadatas=metadatas,
|
|
|
|
metadatas=metadatas,
|
|
|
|
client=client,
|
|
|
|
collection=collection
|
|
|
|
namespace=namespace
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if not client or not namespace:
|
|
|
|
if not collection:
|
|
|
|
raise ValueError("Must provide 'client' and 'namespace' named parameters.")
|
|
|
|
raise ValueError("Must provide 'collection' named parameter.")
|
|
|
|
vecstore = cls(client, namespace, embedding, **kwargs)
|
|
|
|
vecstore = cls(collection, embedding, **kwargs)
|
|
|
|
vecstore.add_texts(texts, metadatas=metadatas)
|
|
|
|
vecstore.add_texts(texts, metadatas=metadatas)
|
|
|
|
return vecstore
|
|
|
|
return vecstore
|
|
|
|