@ -9,6 +9,7 @@ import numpy as np
from langchain . docstore . document import Document
from langchain . docstore . document import Document
from langchain . embeddings . base import Embeddings
from langchain . embeddings . base import Embeddings
from langchain . utils import xor_args
from langchain . vectorstores . base import VectorStore
from langchain . vectorstores . base import VectorStore
from langchain . vectorstores . utils import maximal_marginal_relevance
from langchain . vectorstores . utils import maximal_marginal_relevance
@ -96,6 +97,32 @@ class Chroma(VectorStore):
metadata = collection_metadata ,
metadata = collection_metadata ,
)
)
@xor_args ( ( " query_texts " , " query_embeddings " ) )
def __query_collection (
self ,
query_texts : Optional [ List [ str ] ] = None ,
query_embeddings : Optional [ List [ List [ float ] ] ] = None ,
n_results : int = 4 ,
where : Optional [ Dict [ str , str ] ] = None ,
) - > List [ Document ] :
""" Query the chroma collection. """
for i in range ( n_results , 0 , - 1 ) :
try :
return self . _collection . query (
query_texts = query_texts ,
query_embeddings = query_embeddings ,
n_results = n_results ,
where = where ,
)
except chromadb . errors . NotEnoughElementsException :
logger . error (
f " Chroma collection { self . _collection . name } "
f " contains fewer than { i } elements. "
)
raise chromadb . errors . NotEnoughElementsException (
f " No documents found for Chroma collection { self . _collection . name } "
)
def add_texts (
def add_texts (
self ,
self ,
texts : Iterable [ str ] ,
texts : Iterable [ str ] ,
@ -158,7 +185,7 @@ class Chroma(VectorStore):
Returns :
Returns :
List of Documents most similar to the query vector .
List of Documents most similar to the query vector .
"""
"""
results = self . _ collection. query (
results = self . _ _query_ collection(
query_embeddings = embedding , n_results = k , where = filter
query_embeddings = embedding , n_results = k , where = filter
)
)
return _results_to_docs ( results )
return _results_to_docs ( results )
@ -182,12 +209,12 @@ class Chroma(VectorStore):
text with distance in float .
text with distance in float .
"""
"""
if self . _embedding_function is None :
if self . _embedding_function is None :
results = self . _ collection. query (
results = self . _ _query_ collection(
query_texts = [ query ] , n_results = k , where = filter
query_texts = [ query ] , n_results = k , where = filter
)
)
else :
else :
query_embedding = self . _embedding_function . embed_query ( query )
query_embedding = self . _embedding_function . embed_query ( query )
results = self . _ collection. query (
results = self . _ _query_ collection(
query_embeddings = [ query_embedding ] , n_results = k , where = filter
query_embeddings = [ query_embedding ] , n_results = k , where = filter
)
)
@ -218,7 +245,7 @@ class Chroma(VectorStore):
List of Documents selected by maximal marginal relevance .
List of Documents selected by maximal marginal relevance .
"""
"""
results = self . _ collection. query (
results = self . _ _query_ collection(
query_embeddings = embedding ,
query_embeddings = embedding ,
n_results = fetch_k ,
n_results = fetch_k ,
where = filter ,
where = filter ,