You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
DocsGPT/application/vectorstore/base.py

79 lines
3.1 KiB
Python

1 year ago
from abc import ABC, abstractmethod
import os
from langchain_community.embeddings import (
1 year ago
HuggingFaceEmbeddings,
1 year ago
CohereEmbeddings,
HuggingFaceInstructEmbeddings,
)
from langchain_openai import OpenAIEmbeddings
1 year ago
from application.core.settings import settings
class EmbeddingsSingleton:
    """Cache that keeps a single embeddings object per embeddings name."""

    # Maps an embeddings name to the object that was built for it.
    _instances = {}

    @staticmethod
    def get_instance(embeddings_name, *args, **kwargs):
        """Return the embeddings object registered under ``embeddings_name``.

        On the first request for a name the object is built from ``args`` and
        ``kwargs`` and stored. Every later request for the same name returns
        the stored object; the arguments of those later requests are not used.

        Raises:
            ValueError: if the name is not cached and is not a known
                embeddings name.
        """
        if embeddings_name in EmbeddingsSingleton._instances:
            return EmbeddingsSingleton._instances[embeddings_name]

        created = EmbeddingsSingleton._create_instance(
            embeddings_name, *args, **kwargs
        )
        EmbeddingsSingleton._instances[embeddings_name] = created
        return EmbeddingsSingleton._instances[embeddings_name]

    @staticmethod
    def _create_instance(embeddings_name, *args, **kwargs):
        """Build a new embeddings object for ``embeddings_name``.

        ``args`` and ``kwargs`` are forwarded unchanged to the constructor
        selected for the name.

        Raises:
            ValueError: if ``embeddings_name`` has no registered constructor.
        """
        # Name -> constructor. Two spellings (slash and dash) map to the same
        # HuggingFace sentence-transformers class.
        constructors = {
            "openai_text-embedding-ada-002": OpenAIEmbeddings,
            "huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
            "huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings,
            "huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
            "cohere_medium": CohereEmbeddings,
        }

        constructor = constructors.get(embeddings_name)
        if constructor is None:
            raise ValueError(f"Invalid embeddings_name: {embeddings_name}")
        return constructor(*args, **kwargs)
1 year ago
class BaseVectorStore(ABC):
    """Abstract base class for vector stores.

    Subclasses must implement :meth:`search`. The base class supplies helpers
    for checking the Azure settings and for obtaining an embeddings object
    through ``EmbeddingsSingleton``.
    """

    def __init__(self):
        pass

    @abstractmethod
    def search(self, *args, **kwargs):
        """Search the store; must be overridden by concrete subclasses."""
        pass

    def is_azure_configured(self):
        """Report whether the Azure-related settings are all set.

        Returns the raw result of the ``and`` chain over
        ``settings.OPENAI_API_BASE``, ``settings.OPENAI_API_VERSION`` and
        ``settings.AZURE_DEPLOYMENT_NAME``: the first falsy value, or the last
        value when all three are truthy. It is not coerced to ``bool``.
        """
        return (
            settings.OPENAI_API_BASE
            and settings.OPENAI_API_VERSION
            and settings.AZURE_DEPLOYMENT_NAME
        )

    def _get_embeddings(self, embeddings_name, embeddings_key=None):
        """Return the embeddings object for ``embeddings_name``.

        The keyword arguments handed to ``EmbeddingsSingleton.get_instance``
        depend on the name:

        * ``"openai_text-embedding-ada-002"``: when Azure is configured, the
          ``OPENAI_API_TYPE`` environment variable is set to ``"azure"`` and
          ``model`` is taken from ``settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME``;
          otherwise ``openai_api_key`` is ``embeddings_key``.
        * ``"cohere_medium"``: ``cohere_api_key`` is ``embeddings_key``.
        * ``"huggingface_sentence-transformers/all-mpnet-base-v2"``: runs on
          the CPU device, and uses the local ``./model/all-mpnet-base-v2``
          directory as ``model_name`` when that path exists.
        * any other name: no keyword arguments.
        """
        init_kwargs = {}

        if embeddings_name == "openai_text-embedding-ada-002":
            if self.is_azure_configured():
                # Process-wide switch: the environment variable is mutated
                # before the deployment name is read and the instance fetched.
                os.environ["OPENAI_API_TYPE"] = "azure"
                init_kwargs["model"] = settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
            else:
                init_kwargs["openai_api_key"] = embeddings_key
        elif embeddings_name == "cohere_medium":
            init_kwargs["cohere_api_key"] = embeddings_key
        elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
            # Prefer a model directory shipped next to the application.
            if os.path.exists("./model/all-mpnet-base-v2"):
                init_kwargs["model_name"] = "./model/all-mpnet-base-v2"
            init_kwargs["model_kwargs"] = {"device": "cpu"}

        return EmbeddingsSingleton.get_instance(embeddings_name, **init_kwargs)