class EmbeddingsSingleton:
    """Process-wide cache of embedding client instances.

    Instances are cached per *configuration*: the embeddings name together
    with the constructor arguments. Keying on the name alone would hand a
    caller that supplies a different API key / model path the instance that
    was built for whoever called first.
    """

    # Maps a cache key (see ``_cache_key``) to the constructed instance.
    # NOTE: the cache is unbounded; one entry is kept per distinct
    # configuration requested during the life of the process.
    _instances = {}

    @staticmethod
    def get_instance(embeddings_name, *args, **kwargs):
        """Return the cached embeddings object for this configuration.

        Args:
            embeddings_name: One of the names known to ``_create_instance``.
            *args: Positional arguments forwarded to the embeddings class.
            **kwargs: Keyword arguments forwarded to the embeddings class.

        Returns:
            The embeddings instance; created on first use of this exact
            (name, args, kwargs) combination and reused afterwards.

        Raises:
            ValueError: If ``embeddings_name`` is not a supported name.
        """
        key = EmbeddingsSingleton._cache_key(embeddings_name, args, kwargs)
        instances = EmbeddingsSingleton._instances
        if key not in instances:
            instances[key] = EmbeddingsSingleton._create_instance(
                embeddings_name, *args, **kwargs
            )
        return instances[key]

    @staticmethod
    def _cache_key(embeddings_name, args, kwargs):
        """Build a hashable key identifying one embeddings configuration."""
        freeze = EmbeddingsSingleton._freeze
        return (embeddings_name, freeze(args), freeze(kwargs))

    @staticmethod
    def _freeze(value):
        """Return a hashable, order-independent snapshot of ``value``.

        Containers are converted recursively (e.g. ``model_kwargs`` is a
        dict and therefore not hashable as-is). Anything else that is not
        hashable falls back to its ``repr``.
        """
        freeze = EmbeddingsSingleton._freeze
        if isinstance(value, dict):
            items = ((repr(k), freeze(v)) for k, v in value.items())
            # Sort by key so keyword order does not affect the cache key.
            return ("dict", tuple(sorted(items, key=lambda item: item[0])))
        if isinstance(value, (list, tuple)):
            return ("seq", tuple(freeze(v) for v in value))
        if isinstance(value, (set, frozenset)):
            return ("set", frozenset(freeze(v) for v in value))
        try:
            hash(value)
        except TypeError:
            return ("repr", repr(value))
        return value

    @staticmethod
    def _create_instance(embeddings_name, *args, **kwargs):
        """Construct a new embeddings object for ``embeddings_name``.

        Raises:
            ValueError: If ``embeddings_name`` is not a supported name.
        """
        embeddings_factory = {
            "openai_text-embedding-ada-002": OpenAIEmbeddings,
            "huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
            "huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings,
            "huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
            "cohere_medium": CohereEmbeddings,
        }

        if embeddings_name not in embeddings_factory:
            raise ValueError(f"Invalid embeddings_name: {embeddings_name}")

        return embeddings_factory[embeddings_name](*args, **kwargs)


class BaseVectorStore(ABC):
    def __init__(self):
        pass

    # ... members between ``__init__`` and ``is_azure_configured`` are not
    # shown in the patch this was reconstructed from ...

    # NOTE(review): only the return statement of this method was visible;
    # the ``def`` line is reconstructed from the ``self.is_azure_configured()``
    # call below - confirm against the real file.
    def is_azure_configured(self):
        return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME

    def _get_embeddings(self, embeddings_name, embeddings_key=None):
        """Return the shared embeddings object for ``embeddings_name``.

        Each branch only selects the constructor keyword arguments; the
        instance itself comes from ``EmbeddingsSingleton`` so that identical
        configurations are constructed once.

        Args:
            embeddings_name: Name of the embeddings backend to use.
            embeddings_key: API key passed to the OpenAI (non-Azure) and
                Cohere backends; unused by the other branches.

        Returns:
            The embeddings instance for the selected backend.

        Raises:
            ValueError: If ``embeddings_name`` is not a supported name.
        """
        if embeddings_name == "openai_text-embedding-ada-002":
            if self.is_azure_configured():
                # Side effect: sets OPENAI_API_TYPE for the whole process.
                os.environ["OPENAI_API_TYPE"] = "azure"
                kwargs = {"model": settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME}
            else:
                kwargs = {"openai_api_key": embeddings_key}
        elif embeddings_name == "cohere_medium":
            kwargs = {"cohere_api_key": embeddings_key}
        elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
            kwargs = {"model_kwargs": {"device": "cpu"}}
            # Use the local copy of the model when it exists on disk.
            if os.path.exists("./model/all-mpnet-base-v2"):
                kwargs["model_name"] = "./model/all-mpnet-base-v2"
        else:
            # NOTE(review): the dash-separated
            # "huggingface_sentence-transformers-all-mpnet-base-v2" alias
            # lands here and is built with default arguments - confirm that
            # is intended.
            kwargs = {}

        return EmbeddingsSingleton.get_instance(embeddings_name, **kwargs)