chore: Refactor embeddings instantiation to use a singleton pattern

This commit is contained in:
Alex 2024-06-14 12:58:35 +01:00
parent 558ecd84a6
commit 3454309cbc

View File

@ -8,6 +8,30 @@ from langchain_community.embeddings import (
from langchain_openai import OpenAIEmbeddings
from application.core.settings import settings
class EmbeddingsSingleton:
_instances = {}
@staticmethod
def get_instance(embeddings_name, *args, **kwargs):
if embeddings_name not in EmbeddingsSingleton._instances:
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
return EmbeddingsSingleton._instances[embeddings_name]
@staticmethod
def _create_instance(embeddings_name, *args, **kwargs):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
"cohere_medium": CohereEmbeddings
}
if embeddings_name not in embeddings_factory:
raise ValueError(f"Invalid embeddings_name: {embeddings_name}")
return embeddings_factory[embeddings_name](*args, **kwargs)
class BaseVectorStore(ABC):
def __init__(self):
pass
@ -20,42 +44,36 @@ class BaseVectorStore(ABC):
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def _get_embeddings(self, embeddings_name, embeddings_key=None):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
"cohere_medium": CohereEmbeddings
}
if embeddings_name not in embeddings_factory:
raise ValueError(f"Invalid embeddings_name: {embeddings_name}")
if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure"
embedding_instance = embeddings_factory[embeddings_name](
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
)
else:
embedding_instance = embeddings_factory[embeddings_name](
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
openai_api_key=embeddings_key
)
elif embeddings_name == "cohere_medium":
embedding_instance = embeddings_factory[embeddings_name](
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
cohere_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./model/all-mpnet-base-v2"):
embedding_instance = embeddings_factory[embeddings_name](
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model_name="./model/all-mpnet-base-v2",
model_kwargs={"device": "cpu"},
model_kwargs={"device": "cpu"}
)
else:
embedding_instance = embeddings_factory[embeddings_name](
model_kwargs={"device": "cpu"},
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model_kwargs={"device": "cpu"}
)
else:
embedding_instance = embeddings_factory[embeddings_name]()
return embedding_instance
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
return embedding_instance