chore: Refactor embeddings instantiation to use a singleton pattern

pull/996/head
Alex 1 month ago
parent 558ecd84a6
commit 3454309cbc

@ -8,21 +8,21 @@ from langchain_community.embeddings import (
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
from application.core.settings import settings from application.core.settings import settings
class BaseVectorStore(ABC): class EmbeddingsSingleton:
def __init__(self): _instances = {}
pass
@abstractmethod
def search(self, *args, **kwargs):
pass
def is_azure_configured(self): @staticmethod
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME def get_instance(embeddings_name, *args, **kwargs):
if embeddings_name not in EmbeddingsSingleton._instances:
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
return EmbeddingsSingleton._instances[embeddings_name]
def _get_embeddings(self, embeddings_name, embeddings_key=None): @staticmethod
def _create_instance(embeddings_name, *args, **kwargs):
embeddings_factory = { embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings, "openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings, "huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings, "huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
"cohere_medium": CohereEmbeddings "cohere_medium": CohereEmbeddings
} }
@ -30,32 +30,50 @@ class BaseVectorStore(ABC):
if embeddings_name not in embeddings_factory: if embeddings_name not in embeddings_factory:
raise ValueError(f"Invalid embeddings_name: {embeddings_name}") raise ValueError(f"Invalid embeddings_name: {embeddings_name}")
return embeddings_factory[embeddings_name](*args, **kwargs)
class BaseVectorStore(ABC):
def __init__(self):
pass
@abstractmethod
def search(self, *args, **kwargs):
pass
def is_azure_configured(self):
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def _get_embeddings(self, embeddings_name, embeddings_key=None):
if embeddings_name == "openai_text-embedding-ada-002": if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured(): if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure" os.environ["OPENAI_API_TYPE"] = "azure"
embedding_instance = embeddings_factory[embeddings_name]( embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
) )
else: else:
embedding_instance = embeddings_factory[embeddings_name]( embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
openai_api_key=embeddings_key openai_api_key=embeddings_key
) )
elif embeddings_name == "cohere_medium": elif embeddings_name == "cohere_medium":
embedding_instance = embeddings_factory[embeddings_name]( embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
cohere_api_key=embeddings_key cohere_api_key=embeddings_key
) )
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./model/all-mpnet-base-v2"): if os.path.exists("./model/all-mpnet-base-v2"):
embedding_instance = embeddings_factory[embeddings_name]( embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model_name="./model/all-mpnet-base-v2", model_name="./model/all-mpnet-base-v2",
model_kwargs={"device": "cpu"}, model_kwargs={"device": "cpu"}
) )
else: else:
embedding_instance = embeddings_factory[embeddings_name]( embedding_instance = EmbeddingsSingleton.get_instance(
model_kwargs={"device": "cpu"}, embeddings_name,
model_kwargs={"device": "cpu"}
) )
else: else:
embedding_instance = embeddings_factory[embeddings_name]() embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
return embedding_instance return embedding_instance

Loading…
Cancel
Save