diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py index d8f74705..ec10519f 100644 --- a/application/vectorstore/base.py +++ b/application/vectorstore/base.py @@ -8,6 +8,30 @@ from langchain_community.embeddings import ( from langchain_openai import OpenAIEmbeddings from application.core.settings import settings +class EmbeddingsSingleton: + _instances = {} + + @staticmethod + def get_instance(embeddings_name, *args, **kwargs): + if embeddings_name not in EmbeddingsSingleton._instances: + EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs) + return EmbeddingsSingleton._instances[embeddings_name] + + @staticmethod + def _create_instance(embeddings_name, *args, **kwargs): + embeddings_factory = { + "openai_text-embedding-ada-002": OpenAIEmbeddings, + "huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings, + "huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings, + "huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings, + "cohere_medium": CohereEmbeddings + } + + if embeddings_name not in embeddings_factory: + raise ValueError(f"Invalid embeddings_name: {embeddings_name}") + + return embeddings_factory[embeddings_name](*args, **kwargs) + class BaseVectorStore(ABC): def __init__(self): pass @@ -20,42 +44,36 @@ class BaseVectorStore(ABC): return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME def _get_embeddings(self, embeddings_name, embeddings_key=None): - embeddings_factory = { - "openai_text-embedding-ada-002": OpenAIEmbeddings, - "huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings, - "huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings, - "cohere_medium": CohereEmbeddings - } - - if embeddings_name not in embeddings_factory: - raise ValueError(f"Invalid embeddings_name: {embeddings_name}") - if embeddings_name == "openai_text-embedding-ada-002": if self.is_azure_configured(): os.environ["OPENAI_API_TYPE"] = "azure" - embedding_instance = embeddings_factory[embeddings_name]( + embedding_instance = EmbeddingsSingleton.get_instance( + embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME ) else: - embedding_instance = embeddings_factory[embeddings_name]( + embedding_instance = EmbeddingsSingleton.get_instance( + embeddings_name, openai_api_key=embeddings_key ) elif embeddings_name == "cohere_medium": - embedding_instance = embeddings_factory[embeddings_name]( + embedding_instance = EmbeddingsSingleton.get_instance( + embeddings_name, cohere_api_key=embeddings_key ) elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": if os.path.exists("./model/all-mpnet-base-v2"): - embedding_instance = embeddings_factory[embeddings_name]( + embedding_instance = EmbeddingsSingleton.get_instance( + embeddings_name, model_name="./model/all-mpnet-base-v2", - model_kwargs={"device": "cpu"}, + model_kwargs={"device": "cpu"} ) else: - embedding_instance = embeddings_factory[embeddings_name]( - model_kwargs={"device": "cpu"}, + embedding_instance = EmbeddingsSingleton.get_instance( + embeddings_name, + model_kwargs={"device": "cpu"} ) else: - embedding_instance = embeddings_factory[embeddings_name]() - - return embedding_instance + embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name) + return embedding_instance \ No newline at end of file diff --git a/frontend/index.html b/frontend/index.html index 3717e3e0..5af1721b 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -1,13 +1,17 @@ - - - - DocsGPT 🦖 - - - -
- - - + + + + + + DocsGPT 🦖 + + + + +
+ + + + \ No newline at end of file diff --git a/frontend/src/Hero.tsx b/frontend/src/Hero.tsx index 69bf23ac..904dd279 100644 --- a/frontend/src/Hero.tsx +++ b/frontend/src/Hero.tsx @@ -4,7 +4,13 @@ import { useTranslation } from 'react-i18next'; export default function Hero({ handleQuestion, }: { - handleQuestion: (question: string) => void; + handleQuestion: ({ + question, + isRetry, + }: { + question: string; + isRetry?: boolean; + }) => void; }) { const { t } = useTranslation(); const demos = t('demo', { returnObjects: true }) as Array<{ @@ -23,14 +29,14 @@ export default function Hero({
-
+
{demos?.map( (demo: { header: string; query: string }, key: number) => demo.header && demo.query && ( + ); + responseView = ( + + ); } return responseView; }; @@ -137,13 +195,13 @@ export default function Conversation() {
{queries.length > 0 && !hasScrolledToLast && (
)} + {queries.length === 0 && }
-
-
+ +
+
{ if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); - if (inputRef.current?.textContent && status !== 'loading') { - handleQuestion(inputRef.current.textContent); - inputRef.current.textContent = ''; - } + handleQuestionSubmission(); } }} >
{status === 'loading' ? ( ) : (
{ - if (inputRef.current?.textContent) { - handleQuestion(inputRef.current.textContent); - inputRef.current.textContent = ''; - } - }} + onClick={handleQuestionSubmission} src={isDarkTheme ? SendDark : Send} >
)}
-

+ +

{t('tagline')}

diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index aa9c4e3d..3a3842b2 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -23,9 +23,10 @@ const ConversationBubble = forwardRef< feedback?: FEEDBACK; handleFeedback?: (feedback: FEEDBACK) => void; sources?: { title: string; text: string; source: string }[]; + retryBtn?: React.ReactElement; } >(function ConversationBubble( - { message, type, className, feedback, handleFeedback, sources }, + { message, type, className, feedback, handleFeedback, sources, retryBtn }, ref, ) { const [openSource, setOpenSource] = useState(null); @@ -69,12 +70,17 @@ const ConversationBubble = forwardRef<
{type === 'ERROR' && ( - alert + <> + alert +
+ {retryBtn} +
+ )} { label: 'Japanese', value: 'jp', }, + { + label: 'Mandarin', + value: 'zh', + }, ]; const chunks = ['0', '2', '4', '6', '8', '10']; const token_limits = new Map([