mirror of
https://github.com/arc53/DocsGPT
synced 2024-11-17 21:26:26 +00:00
Merge pull request #588 from asoderlind/fix/as/embedding-size-mismatch
raise more legible error if the word embedding dimensions don't match
This commit is contained in:
commit
d899b6a7e1
@ -104,3 +104,4 @@ urllib3==1.26.17
|
||||
vine==5.0.0
|
||||
wcwidth==0.2.6
|
||||
yarl==1.8.2
|
||||
sentence-transformers==2.2.2
|
@ -1,5 +1,5 @@
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from langchain.vectorstores import FAISS
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from application.core.settings import settings
|
||||
|
||||
class FaissStore(BaseVectorStore):
|
||||
@ -7,14 +7,16 @@ class FaissStore(BaseVectorStore):
|
||||
def __init__(self, path, embeddings_key, docs_init=None):
    """
    Create the FAISS-backed store, either from documents or from disk.

    Args:
        path: Location handed to ``FAISS.load_local`` when no initial
            documents are supplied.
        embeddings_key: Key forwarded to ``self._get_embeddings`` together
            with ``settings.EMBEDDINGS_NAME``.
        docs_init: Optional documents. When truthy, the index is built from
            them with ``FAISS.from_documents`` instead of being loaded.

    Raises:
        ValueError: Raised by ``assert_embedding_dimensions`` when the index
            dimension differs from the embedding dimension.
        AttributeError: Raised by ``assert_embedding_dimensions`` when the
            embedding dimension cannot be read from the embeddings object.
    """
    super().__init__()
    self.path = path

    # Build the embeddings object once, with the caller-supplied key, so the
    # very same instance is used to create/load the index and to validate
    # its dimension afterwards.
    embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)

    if docs_init:
        self.docsearch = FAISS.from_documents(docs_init, embeddings)
    else:
        self.docsearch = FAISS.load_local(self.path, embeddings)

    # Fail early with a readable error instead of a low-level one at query time.
    self.assert_embedding_dimensions(embeddings)
|
||||
|
||||
def search(self, *args, **kwargs):
    """
    Run a similarity search against the wrapped index.

    Every positional and keyword argument is forwarded unchanged to
    ``self.docsearch.similarity_search`` and its result is returned as is.
    """
    matches = self.docsearch.similarity_search(*args, **kwargs)
    return matches
|
||||
@ -24,3 +26,19 @@ class FaissStore(BaseVectorStore):
|
||||
|
||||
def save_local(self, *args, **kwargs):
    """
    Persist the wrapped index.

    Every positional and keyword argument is forwarded unchanged to
    ``self.docsearch.save_local`` and its result is returned as is.
    """
    outcome = self.docsearch.save_local(*args, **kwargs)
    return outcome
|
||||
|
||||
def assert_embedding_dimensions(self, embeddings):
    """
    Verify that the index and the embedding model agree on vector size.

    The comparison only happens when ``settings.EMBEDDINGS_NAME`` equals
    ``"huggingface_sentence-transformers/all-mpnet-base-v2"``; for any other
    value the method returns without checking anything.

    Raises:
        AttributeError: ``word_embedding_dimension`` could not be read from
            ``embeddings.client[1]``.
        ValueError: the embedding dimension differs from ``self.docsearch.index.d``.
    """
    checked_model = "huggingface_sentence-transformers/all-mpnet-base-v2"
    if settings.EMBEDDINGS_NAME != checked_model:
        return

    try:
        model_dim = embeddings.client[1].word_embedding_dimension
    except AttributeError as e:
        raise AttributeError("word_embedding_dimension not found in embeddings.client[1]") from e

    # Dimension the stored index was built with.
    index_dim = self.docsearch.index.d
    if model_dim != index_dim:
        raise ValueError(
            f"word_embedding_dimension ({model_dim}) "
            f"!= docsearch_index_word_embedding_dimension ({index_dim})"
        )
|
||||
|
||||
|
19
tests/test_vector_store.py
Normal file
19
tests/test_vector_store.py
Normal file
@ -0,0 +1,19 @@
|
||||
"""
|
||||
Tests regarding the vector store class, including checking
|
||||
compatibility between different transformers and local vector
|
||||
stores (index.faiss)
|
||||
"""
|
||||
import pytest
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
from application.core.settings import settings
|
||||
|
||||
def test_init_local_faiss_store_huggingface():
    """
    Test that asserts that trying to initialize a FaissStore with
    the huggingface sentence transformer below together with the
    index.faiss file in the application/ folder results in a
    dimension mismatch error.
    """
    # ``settings`` is a shared module-level object: remember the current
    # value and put it back afterwards so the override cannot leak into
    # tests that run later in the same session.
    previous_name = settings.EMBEDDINGS_NAME
    settings.EMBEDDINGS_NAME = "huggingface_sentence-transformers/all-mpnet-base-v2"
    try:
        with pytest.raises(ValueError):
            FaissStore("application/", "", None)
    finally:
        settings.EMBEDDINGS_NAME = previous_name
|
Loading…
Reference in New Issue
Block a user