From 220d137e668f3f2d4ee8770e0b701a4ba7c2ad20 Mon Sep 17 00:00:00 2001
From: Siddhant Rai
Date: Sun, 26 May 2024 23:13:01 +0530
Subject: [PATCH] feat: dropdown to adjust conversational history limits

---
 application/api/answer/routes.py               | 24 ++++++++++---
 application/core/settings.py                   |  3 +-
 application/retriever/brave_search.py          | 16 ++++++---
 application/retriever/classic_rag.py           | 16 ++++++---
 application/retriever/duckduck_search.py       | 16 ++++++---
 frontend/src/components/Dropdown.tsx           | 28 ++++++++++++---
 frontend/src/conversation/conversationApi.ts   |  7 ++++
 .../src/conversation/conversationSlice.ts      |  3 ++
 frontend/src/preferences/preferenceSlice.ts    | 22 +++++++++++-
 frontend/src/settings/General.tsx              | 36 +++++++++++++++++++
 frontend/src/store.ts                          |  8 +++--
 11 files changed, 152 insertions(+), 27 deletions(-)

diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index 9304f20..04acc43 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -78,7 +78,7 @@ def get_data_from_api_key(api_key):
     if data is None:
         return bad_request(401, "Invalid API key")
     return data
-    
+
 
 def get_vectorstore(data):
     if "active_docs" in data:
@@ -95,6 +95,7 @@ def get_vectorstore(data):
         vectorstore = os.path.join("application", vectorstore)
     return vectorstore
 
+
 def is_azure_configured():
     return (
         settings.OPENAI_API_BASE
@@ -221,7 +222,10 @@ def stream():
         chunks = int(data["chunks"])
     else:
         chunks = 2
-    
+    if "token_limit" in data:
+        token_limit = data["token_limit"]
+    else:
+        token_limit = settings.DEFAULT_MAX_HISTORY
 
     # check if active_docs or api_key is set
 
@@ -255,6 +259,7 @@ def stream():
         chat_history=history,
         prompt=prompt,
         chunks=chunks,
+        token_limit=token_limit,
         gpt_model=gpt_model,
         user_api_key=user_api_key,
     )
@@ -291,6 +296,10 @@ def api_answer():
         chunks = int(data["chunks"])
     else:
         chunks = 2
+    if "token_limit" in data:
+        token_limit = data["token_limit"]
+    else:
+        token_limit = settings.DEFAULT_MAX_HISTORY
 
     # use try and except to check for exception
     try:
@@ -314,7 +323,7 @@ def api_answer():
             retriever_name = source["active_docs"]
 
         prompt = get_prompt(prompt_id)
-        
+
         retriever = RetrieverCreator.create_retriever(
             retriever_name,
             question=question,
@@ -322,6 +331,7 @@ def api_answer():
             chat_history=history,
             prompt=prompt,
             chunks=chunks,
+            token_limit=token_limit,
             gpt_model=gpt_model,
             user_api_key=user_api_key,
         )
@@ -370,7 +380,6 @@ def api_search():
     else:
         source = {}
         user_api_key = None
-    
 
     if (
         source["active_docs"].split("/")[0] == "default"
@@ -379,6 +388,10 @@ def api_search():
         retriever_name = "classic"
     else:
        retriever_name = source["active_docs"]
+    if "token_limit" in data:
+        token_limit = data["token_limit"]
+    else:
+        token_limit = settings.DEFAULT_MAX_HISTORY
 
     retriever = RetrieverCreator.create_retriever(
         retriever_name,
@@ -387,8 +400,9 @@ def api_search():
         chat_history=[],
         prompt="default",
         chunks=chunks,
+        token_limit=token_limit,
         gpt_model=gpt_model,
         user_api_key=user_api_key,
     )
     docs = retriever.search()
-    return docs
\ No newline at end of file
+    return docs
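Note on the handlers above: `token_limit` is read straight from the JSON body, with `settings.DEFAULT_MAX_HISTORY` as the fallback, so any HTTP client can opt in without going through the new dropdown. A minimal sketch of exercising `/stream` with the new field — the host/port and the payload keys other than `token_limit` and `chunks` are assumptions for illustration, not taken from this patch:

```python
# Hypothetical client for the /stream endpoint above.
# Host, port, and payload values are placeholders.
import requests

payload = {
    "question": "How do I configure the vector store?",
    "history": [],
    "conversation_id": None,
    "prompt_id": "default",
    "chunks": 2,
    "token_limit": 2000,  # omit to fall back to settings.DEFAULT_MAX_HISTORY
}

with requests.post("http://localhost:7091/stream", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())  # raw server-sent event lines
```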
diff --git a/application/core/settings.py b/application/core/settings.py
index 26c27ed..6ae5475 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -15,7 +15,8 @@ class Settings(BaseSettings):
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
     MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
     MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
-    TOKENS_MAX_HISTORY: int = 150
+    DEFAULT_MAX_HISTORY: int = 150
+    MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
     UPLOAD_FOLDER: str = "inputs"
     VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant"
     RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
index 47ca0e7..70dbbf2 100644
--- a/application/retriever/brave_search.py
+++ b/application/retriever/brave_search.py
@@ -15,6 +15,7 @@ class BraveRetSearch(BaseRetriever):
         chat_history,
         prompt,
         chunks=2,
+        token_limit=150,
         gpt_model="docsgpt",
         user_api_key=None,
     ):
@@ -24,6 +25,16 @@ class BraveRetSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
+        self.token_limit = (
+            token_limit
+            if token_limit
+            < settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+            else settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+        )
         self.user_api_key = user_api_key
 
     def _get_data(self):
@@ -70,10 +81,7 @@ class BraveRetSearch(BaseRetriever):
                     tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                         i["response"]
                     )
-                    if (
-                        tokens_current_history + tokens_batch
-                        < settings.TOKENS_MAX_HISTORY
-                    ):
+                    if tokens_current_history + tokens_batch < self.token_limit:
                         tokens_current_history += tokens_batch
                         messages_combine.append(
                             {"role": "user", "content": i["prompt"]}
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 1bce6f8..3eb0f20 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -16,6 +16,7 @@ class ClassicRAG(BaseRetriever):
         chat_history,
         prompt,
         chunks=2,
+        token_limit=150,
         gpt_model="docsgpt",
         user_api_key=None,
     ):
@@ -25,6 +26,16 @@ class ClassicRAG(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
+        self.token_limit = (
+            token_limit
+            if token_limit
+            < settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+            else settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+        )
         self.user_api_key = user_api_key
 
     def _get_vectorstore(self, source):
@@ -85,10 +96,7 @@ class ClassicRAG(BaseRetriever):
                     tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                         i["response"]
                     )
-                    if (
-                        tokens_current_history + tokens_batch
-                        < settings.TOKENS_MAX_HISTORY
-                    ):
+                    if tokens_current_history + tokens_batch < self.token_limit:
                         tokens_current_history += tokens_batch
                         messages_combine.append(
                             {"role": "user", "content": i["prompt"]}
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index 9189298..bee74e2 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -15,6 +15,7 @@ class DuckDuckSearch(BaseRetriever):
         chat_history,
         prompt,
         chunks=2,
+        token_limit=150,
         gpt_model="docsgpt",
         user_api_key=None,
     ):
@@ -24,6 +25,16 @@ class DuckDuckSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
+        self.token_limit = (
+            token_limit
+            if token_limit
+            < settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+            else settings.MODEL_TOKEN_LIMITS.get(
+                self.gpt_model, settings.DEFAULT_MAX_HISTORY
+            )
+        )
         self.user_api_key = user_api_key
 
     def _parse_lang_string(self, input_string):
@@ -87,10 +98,7 @@ class DuckDuckSearch(BaseRetriever):
                     tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                         i["response"]
                     )
-                    if (
-                        tokens_current_history + tokens_batch
-                        < settings.TOKENS_MAX_HISTORY
-                    ):
+                    if tokens_current_history + tokens_batch < self.token_limit:
                         tokens_current_history += tokens_batch
                         messages_combine.append(
                             {"role": "user", "content": i["prompt"]}
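A readability note on the clamp this patch copies into all three retrievers: `x if x < cap else cap` always evaluates to the smaller of the two, so it is just `min(x, cap)`. The intent — never exceed the model's known context budget, and fall back to `DEFAULT_MAX_HISTORY` for models missing from the map — can be stated in two lines. A behavior-identical sketch using only names from this patch:

```python
from application.core.settings import settings

def clamp_token_limit(requested: int, gpt_model: str):
    # Equivalent to the ternary in the three retrievers: clamp the
    # requested history budget to the per-model cap, defaulting to
    # DEFAULT_MAX_HISTORY for models not listed in MODEL_TOKEN_LIMITS.
    cap = settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
    return min(requested, cap)
```

Factoring this into a shared helper would also remove the triplicated block from the retriever constructors.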
diff --git a/frontend/src/components/Dropdown.tsx b/frontend/src/components/Dropdown.tsx
index 670a2d6..89fe3df 100644
--- a/frontend/src/components/Dropdown.tsx
+++ b/frontend/src/components/Dropdown.tsx
@@ -20,12 +20,18 @@ function Dropdown({
   options:
     | string[]
     | { name: string; id: string; type: string }[]
-    | { label: string; value: string }[];
-  selectedValue: string | { label: string; value: string } | null;
+    | { label: string; value: string }[]
+    | { value: number; description: string }[];
+  selectedValue:
+    | string
+    | { label: string; value: string }
+    | { value: number; description: string }
+    | null;
   onSelect:
     | ((value: string) => void)
     | ((value: { name: string; id: string; type: string }) => void)
-    | ((value: { label: string; value: string }) => void);
+    | ((value: { label: string; value: string }) => void)
+    | ((value: { value: number; description: string }) => void);
   size?: string;
   rounded?: 'xl' | '3xl';
   border?: 'border' | 'border-2';
@@ -64,8 +70,14 @@ function Dropdown({
           !selectedValue && 'text-silver dark:text-gray-400'
         }`}
       >
-        {selectedValue
+        {selectedValue && 'label' in selectedValue
           ? selectedValue.label
+          : selectedValue && 'description' in selectedValue
+            ? `${
+                selectedValue.value < 1e9
+                  ? selectedValue.value + ` (${selectedValue.description})`
+                  : selectedValue.description
+              }`
           : placeholder
             ? placeholder
             : 'From URL'}
@@ -99,7 +111,13 @@ function Dropdown({
                     ? option
                     : option.name
                       ? option.name
-                      : option.label}
+                      : option.label
+                        ? option.label
+                        : `${
+                            option.value < 1e9
+                              ? option.value + ` (${option.description})`
+                              : option.description
+                          }`}
                 </span>
                 {showEdit && onEdit && (
                   <img
diff --git a/frontend/src/conversation/conversationApi.ts b/frontend/src/conversation/conversationApi.ts
--- a/frontend/src/conversation/conversationApi.ts
+++ b/frontend/src/conversation/conversationApi.ts
@@ ... @@ export function fetchAnswerSteaming(
   chunks: string,
+  token_limit: number,
   onEvent: (event: MessageEvent) => void,
 ): Promise<Answer> {
   const docPath = getDocPath(selectedDocs);
@@ -119,6 +123,7 @@ export function fetchAnswerSteaming(
     conversation_id: conversationId,
     prompt_id: promptId,
     chunks: chunks,
+    token_limit: token_limit,
   };
   fetch(apiHost + '/stream', {
     method: 'POST',
@@ -181,6 +186,7 @@ export function searchEndpoint(
   conversation_id: string | null,
   history: Array<any> = [],
   chunks: string,
+  token_limit: number,
 ) {
   const docPath = getDocPath(selectedDocs);
@@ ... @@ export function searchEndpoint(
   const payload = {
     question,
     conversation_id,
     history,
     chunks: chunks,
+    token_limit: token_limit,
   };
   return fetch(`${apiHost}/api/search`, {
     method: 'POST',
diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts
index 72cf660..5aa7a0f 100644
--- a/frontend/src/conversation/conversationSlice.ts
+++ b/frontend/src/conversation/conversationSlice.ts
@@ -28,6 +28,7 @@ export const fetchAnswer = createAsyncThunk(
       state.conversation.conversationId,
       state.preference.prompt.id,
       state.preference.chunks,
+      state.preference.token_limit,
       (event) => {
         const data = JSON.parse(event.data);
 
@@ -51,6 +52,7 @@ export const fetchAnswer = createAsyncThunk(
         state.conversation.conversationId,
         state.conversation.queries,
         state.preference.chunks,
+        state.preference.token_limit,
       ).then((sources) => {
         //dispatch streaming sources
         dispatch(
@@ -86,6 +88,7 @@ export const fetchAnswer = createAsyncThunk(
       state.conversation.conversationId,
       state.preference.prompt.id,
      state.preference.chunks,
+      state.preference.token_limit,
     );
     if (answer) {
       let sourcesPrepped = [];
diff --git a/frontend/src/preferences/preferenceSlice.ts b/frontend/src/preferences/preferenceSlice.ts
index ca68df7..370f260 100644
--- a/frontend/src/preferences/preferenceSlice.ts
+++ b/frontend/src/preferences/preferenceSlice.ts
@@ -11,8 +11,9 @@ import { ActiveState } from '../models/misc';
 interface Preference {
   apiKey: string;
   prompt: { name: string; id: string; type: string };
-  selectedDocs: Doc | null;
   chunks: string;
+  token_limit: number;
+  selectedDocs: Doc | null;
   sourceDocs: Doc[] | null;
   conversations: { name: string; id: string }[] | null;
   modalState: ActiveState;
@@ -22,6 +23,7 @@ const initialState: Preference = {
   apiKey: 'xxx',
   prompt: { name: 'default', id: 'default', type: 'public' },
   chunks: '2',
+  token_limit: 2000,
   selectedDocs: {
     name: 'default',
     language: 'default',
@@ -60,6 +62,9 @@ export const prefSlice = createSlice({
     setChunks: (state, action) => {
       state.chunks = action.payload;
     },
+    setTokenLimit: (state, action) => {
+      state.token_limit = action.payload;
+    },
     setModalStateDeleteConv: (state, action: PayloadAction<ActiveState>) => {
       state.modalState = action.payload;
     },
@@ -73,6 +78,7 @@ export const {
   setConversations,
   setPrompt,
   setChunks,
+  setTokenLimit,
   setModalStateDeleteConv,
 } = prefSlice.actions;
 export default prefSlice.reducer;
@@ -115,6 +121,18 @@ prefListenerMiddleware.startListening({
   },
 });
 
+prefListenerMiddleware.startListening({
+  matcher: isAnyOf(setTokenLimit),
+  effect: (action, listenerApi) => {
+    localStorage.setItem(
+      'DocsGPTTokenLimit',
+      JSON.stringify(
+        (listenerApi.getState() as RootState).preference.token_limit,
+      ),
+    );
+  },
+});
+
 export const selectApiKey = (state: RootState) => state.preference.apiKey;
 export const selectApiKeyStatus = (state: RootState) =>
   !!state.preference.apiKey;
@@ -132,3 +150,5 @@ export const selectConversationId = (state: RootState) =>
   state.conversation.conversationId;
 export const selectPrompt = (state: RootState) => state.preference.prompt;
 export const selectChunks = (state: RootState) => state.preference.chunks;
+export const selectTokenLimit = (state: RootState) =>
+  state.preference.token_limit;
diff --git a/frontend/src/settings/General.tsx b/frontend/src/settings/General.tsx
index 8801376..c098af1 100644
--- a/frontend/src/settings/General.tsx
+++ b/frontend/src/settings/General.tsx
@@ -8,6 +8,8 @@ import {
   setPrompt,
   setChunks,
   selectChunks,
+  setTokenLimit,
+  selectTokenLimit,
   setModalStateDeleteConv,
 } from '../preferences/preferenceSlice';
 
@@ -17,10 +19,19 @@ const General: React.FC = () => {
   const themes = ['Light', 'Dark'];
   const languages = ['English'];
   const chunks = ['0', '2', '4', '6', '8', '10'];
+  const token_limits = new Map([
+    [0, 'None'],
+    [100, 'Low'],
+    [1000, 'Medium'],
+    [2000, 'Default'],
+    [4000, 'High'],
+    [1e9, 'Unlimited'],
+  ]);
   const [prompts, setPrompts] = React.useState<
     { name: string; id: string; type: string }[]
   >([]);
   const selectedChunks = useSelector(selectChunks);
+  const selectedTokenLimit = useSelector(selectTokenLimit);
   const [isDarkTheme, toggleTheme] = useDarkTheme();
   const [selectedTheme, setSelectedTheme] = React.useState(
     isDarkTheme ? 'Dark' : 'Light',
@@ -87,6 +98,31 @@ const General: React.FC = () => {
             border="border"
           />
         </div>
+        <div>
+          <p>
+            Conversational history
+          </p>
+          <Dropdown
+            options={Array.from(token_limits, ([value, desc]) => ({
+              value: value,
+              description: desc,
+            }))}
+            selectedValue={{
+              value: selectedTokenLimit,
+              description: token_limits.get(selectedTokenLimit) as string,
+            }}
+            onSelect={({
+              value,
+              description,
+            }: {
+              value: number;
+              description: string;
+            }) => dispatch(setTokenLimit(value))}
+            size="w-56"
+            rounded="3xl"
+            border="border"
+          />
+        </div>
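One consequence worth keeping in mind when reading the dropdown values above: the retrievers clamp the requested budget against `MODEL_TOKEN_LIMITS`, and any model not in that map — including the default `gpt_model="docsgpt"` — falls back to `DEFAULT_MAX_HISTORY` (150). For such models every dropdown option above "Low" behaves identically, and the frontend's initial `token_limit` of 2000 likewise collapses to 150. A small self-contained check of that arithmetic, using the values from settings.py and General.tsx in this patch:

```python
# Worked example of the effective history budget after the retrievers' clamp.
DEFAULT_MAX_HISTORY = 150
MODEL_TOKEN_LIMITS = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}

def effective_limit(requested, model):
    # Same clamp each retriever applies in __init__.
    return min(requested, MODEL_TOKEN_LIMITS.get(model, DEFAULT_MAX_HISTORY))

assert effective_limit(2000, "gpt-3.5-turbo") == 2000  # UI "Default", under the 4096 cap
assert effective_limit(10**9, "claude-2") == 1e5       # UI "Unlimited" still clamps
assert effective_limit(2000, "docsgpt") == 150         # unlisted model -> DEFAULT_MAX_HISTORY
```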