DocsGPT/application/vectorstore/base.py

57 lines
2.1 KiB
Python
Raw Normal View History

2023-09-27 15:25:57 +00:00
from abc import ABC, abstractmethod
import os
from langchain_community.embeddings import (
2023-10-01 16:20:47 +00:00
HuggingFaceEmbeddings,
2023-09-27 15:25:57 +00:00
CohereEmbeddings,
HuggingFaceInstructEmbeddings,
)
from langchain_openai import OpenAIEmbeddings
2023-09-27 15:25:57 +00:00
from application.core.settings import settings
class BaseVectorStore(ABC):
    """Abstract base class for vector stores.

    Subclasses must implement :meth:`search`. The base class also provides
    :meth:`is_azure_configured` and :meth:`_get_embeddings`, which builds the
    embeddings object for one of a fixed set of supported model names.
    """

    def __init__(self):
        pass

    @abstractmethod
    def search(self, *args, **kwargs):
        """Search the store; arguments and return value are defined by subclasses."""

    def is_azure_configured(self):
        """Return True if the Azure OpenAI settings used here are all set.

        Checks ``settings.OPENAI_API_BASE``, ``settings.OPENAI_API_VERSION`` and
        ``settings.AZURE_DEPLOYMENT_NAME``; any falsy value (e.g. ``None`` or an
        empty string) counts as "not configured".
        """
        # bool() so callers get a real predicate result instead of whichever
        # setting value the ``and`` chain happened to stop on.
        return bool(
            settings.OPENAI_API_BASE
            and settings.OPENAI_API_VERSION
            and settings.AZURE_DEPLOYMENT_NAME
        )

    def _get_embeddings(self, embeddings_name, embeddings_key=None):
        """Construct the embeddings object registered under ``embeddings_name``.

        Args:
            embeddings_name: One of ``"openai_text-embedding-ada-002"``,
                ``"huggingface_sentence-transformers/all-mpnet-base-v2"``,
                ``"huggingface_hkunlp/instructor-large"`` or ``"cohere_medium"``.
            embeddings_key: API key passed to the non-Azure OpenAI constructor
                and to the Cohere constructor; ignored by the other backends
                (and by the Azure OpenAI path).

        Returns:
            A newly constructed embeddings instance.

        Raises:
            ValueError: If ``embeddings_name`` is not one of the supported names.
        """

        def build_openai():
            # Builds OpenAIEmbeddings, switching to the Azure deployment when
            # the Azure settings are present.
            if self.is_azure_configured():
                # os.environ is process-wide: this affects every later reader of
                # OPENAI_API_TYPE in this process, not just this instance.
                # NOTE(review): recent langchain_openai releases appear to reject
                # OPENAI_API_TYPE="azure" in OpenAIEmbeddings and point users to
                # AzureOpenAIEmbeddings — confirm against the installed version.
                os.environ["OPENAI_API_TYPE"] = "azure"
                return OpenAIEmbeddings(
                    model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
                )
            return OpenAIEmbeddings(openai_api_key=embeddings_key)

        # Single source of truth for the supported names. Values are zero-arg
        # builders, so a backend class is only looked up and constructed when
        # its name is actually requested.
        builders = {
            "openai_text-embedding-ada-002": build_openai,
            # device="cpu" is forwarded via model_kwargs. To load a local copy,
            # model_name="./model/all-mpnet-base-v2" could also be passed here.
            "huggingface_sentence-transformers/all-mpnet-base-v2": (
                lambda: HuggingFaceEmbeddings(model_kwargs={"device": "cpu"})
            ),
            "huggingface_hkunlp/instructor-large": (
                lambda: HuggingFaceInstructEmbeddings()
            ),
            "cohere_medium": (
                lambda: CohereEmbeddings(cohere_api_key=embeddings_key)
            ),
        }

        try:
            build = builders[embeddings_name]
        except KeyError:
            # The KeyError is an implementation detail; don't chain it.
            raise ValueError(f"Invalid embeddings_name: {embeddings_name}") from None
        return build()