From ed081235503d881f67ac3edc8ab9aea953e39aca Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 22 Mar 2024 14:50:56 +0000 Subject: [PATCH] Add support for setting the number of chunks processed per query --- application/api/answer/routes.py | 35 +++++++++++++++---- frontend/src/Setting.tsx | 17 ++++++++- .../src/conversation/ConversationBubble.tsx | 5 ++- frontend/src/conversation/conversationApi.ts | 6 ++++ .../src/conversation/conversationSlice.ts | 3 ++ frontend/src/preferences/preferenceSlice.ts | 17 +++++++++ frontend/src/store.ts | 2 ++ 7 files changed, 76 insertions(+), 9 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 4c86cf4..9485502 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -98,7 +98,7 @@ def is_azure_configured(): return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME -def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id): +def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id, chunks=2): llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key) if prompt_id == 'default': @@ -109,8 +109,11 @@ def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conve prompt = chat_combine_strict else: prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] - - docs = docsearch.search(question, k=2) + + if chunks == 0: + docs = [] + else: + docs = docsearch.search(question, k=chunks) if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] # join all page_content together with a newline @@ -193,6 +196,10 @@ def stream(): prompt_id = data["prompt_id"] else: prompt_id = 'default' + if 'chunks' in data: + chunks = int(data["chunks"]) + else: + chunks = 2 # check if active_docs is set @@ -214,7 +221,8 @@ def stream(): complete_stream(question, docsearch, chat_history=history, api_key=api_key, prompt_id=prompt_id, - conversation_id=conversation_id), mimetype="text/event-stream" + conversation_id=conversation_id, + chunks=chunks), mimetype="text/event-stream" ) @@ -240,6 +248,10 @@ def api_answer(): prompt_id = data["prompt_id"] else: prompt_id = 'default' + if 'chunks' in data: + chunks = int(data["chunks"]) + else: + chunks = 2 if prompt_id == 'default': prompt = chat_combine_template @@ -263,7 +275,10 @@ def api_answer(): - docs = docsearch.search(question, k=2) + if chunks == 0: + docs = [] + else: + docs = docsearch.search(question, k=chunks) # join all page_content together with a newline docs_together = "\n".join([doc.page_content for doc in docs]) p_chat_combine = prompt.replace("{summaries}", docs_together) @@ -362,9 +377,15 @@ def api_search(): vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) else: vectorstore = "" + if 'chunks' in data: + chunks = int(data["chunks"]) + else: + chunks = 2 docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key) - - docs = docsearch.search(question, k=2) + if chunks == 0: + docs = [] + else: + docs = docsearch.search(question, k=chunks) source_log_docs = [] for doc in docs: diff --git a/frontend/src/Setting.tsx b/frontend/src/Setting.tsx index 5642379..46999fd 100644 --- a/frontend/src/Setting.tsx +++ b/frontend/src/Setting.tsx @@ -8,6 +8,8 @@ import { setPrompt, selectSourceDocs, setSourceDocs, + setChunks, + selectChunks, } from './preferences/preferenceSlice'; import { Doc } from './preferences/preferenceApi'; import { useDarkTheme } from './hooks'; @@ -193,10 +195,13 @@ const Setting: React.FC = () => { const General: React.FC = () => { const themes = ['Light', 'Dark']; const languages = ['English']; + const chunks = ['0', '2', '4', '6', '8', '10']; + const selectedChunks = useSelector(selectChunks); const [isDarkTheme, toggleTheme] = useDarkTheme(); const [selectedTheme, setSelectedTheme] = useState( isDarkTheme ? 'Dark' : 'Light', ); + const dispatch = useDispatch(); const [selectedLanguage, setSelectedLanguage] = useState(languages[0]); return (
@@ -211,7 +216,7 @@ const General: React.FC = () => { }} />
-
+

Select Language

@@ -221,6 +226,16 @@ const General: React.FC = () => { onSelect={setSelectedLanguage} />
+
+

+ Chunks processed per query +

+ dispatch(setChunks(value))} + /> +
); }; diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index b95b413..e8caf2f 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -160,7 +160,10 @@ const ConversationBubble = forwardRef< > {message} - {DisableSourceFE || type === 'ERROR' ? null : ( + {DisableSourceFE || + type === 'ERROR' || + !sources || + sources.length === 0 ? null : ( <>
diff --git a/frontend/src/conversation/conversationApi.ts b/frontend/src/conversation/conversationApi.ts index 8293df1..d8d7693 100644 --- a/frontend/src/conversation/conversationApi.ts +++ b/frontend/src/conversation/conversationApi.ts @@ -11,6 +11,7 @@ export function fetchAnswerApi( history: Array = [], conversationId: string | null, promptId: string | null, + chunks: string, ): Promise< | { result: any; @@ -65,6 +66,7 @@ export function fetchAnswerApi( active_docs: docPath, conversation_id: conversationId, prompt_id: promptId, + chunks: chunks, }), signal, }) @@ -95,6 +97,7 @@ export function fetchAnswerSteaming( history: Array = [], conversationId: string | null, promptId: string | null, + chunks: string, onEvent: (event: MessageEvent) => void, ): Promise { let namePath = selectedDocs.name; @@ -130,6 +133,7 @@ export function fetchAnswerSteaming( history: JSON.stringify(history), conversation_id: conversationId, prompt_id: promptId, + chunks: chunks, }; fetch(apiHost + '/stream', { method: 'POST', @@ -192,6 +196,7 @@ export function searchEndpoint( selectedDocs: Doc, conversation_id: string | null, history: Array = [], + chunks: string, ) { /* "active_docs": "default", @@ -223,6 +228,7 @@ export function searchEndpoint( active_docs: docPath, conversation_id, history, + chunks: chunks, }; return fetch(`${apiHost}/api/search`, { method: 'POST', diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 35aadd9..85fc351 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -28,6 +28,7 @@ export const fetchAnswer = createAsyncThunk( state.conversation.queries, state.conversation.conversationId, state.preference.prompt.id, + state.preference.chunks, (event) => { const data = JSON.parse(event.data); @@ -51,6 +52,7 @@ export const fetchAnswer = createAsyncThunk( state.preference.selectedDocs!, state.conversation.conversationId, state.conversation.queries, + state.preference.chunks, ).then((sources) => { //dispatch streaming sources dispatch( @@ -86,6 +88,7 @@ export const fetchAnswer = createAsyncThunk( state.conversation.queries, state.conversation.conversationId, state.preference.prompt.id, + state.preference.chunks, ); if (answer) { let sourcesPrepped = []; diff --git a/frontend/src/preferences/preferenceSlice.ts b/frontend/src/preferences/preferenceSlice.ts index 0aa8b3b..dc72fae 100644 --- a/frontend/src/preferences/preferenceSlice.ts +++ b/frontend/src/preferences/preferenceSlice.ts @@ -10,6 +10,7 @@ interface Preference { apiKey: string; prompt: { name: string; id: string; type: string }; selectedDocs: Doc | null; + chunks: string; sourceDocs: Doc[] | null; conversations: { name: string; id: string }[] | null; } @@ -17,6 +18,7 @@ interface Preference { const initialState: Preference = { apiKey: 'xxx', prompt: { name: 'default', id: 'default', type: 'public' }, + chunks: '2', selectedDocs: { name: 'default', language: 'default', @@ -51,6 +53,9 @@ export const prefSlice = createSlice({ setPrompt: (state, action) => { state.prompt = action.payload; }, + setChunks: (state, action) => { + state.chunks = action.payload; + }, }, }); @@ -60,6 +65,7 @@ export const { setSourceDocs, setConversations, setPrompt, + setChunks, } = prefSlice.actions; export default prefSlice.reducer; @@ -91,6 +97,16 @@ prefListenerMiddleware.startListening({ }, }); +prefListenerMiddleware.startListening({ + matcher: isAnyOf(setChunks), + effect: (action, listenerApi) => { + localStorage.setItem( + 'DocsGPTChunks', + JSON.stringify((listenerApi.getState() as RootState).preference.chunks), + ); + }, +}); + export const selectApiKey = (state: RootState) => state.preference.apiKey; export const selectApiKeyStatus = (state: RootState) => !!state.preference.apiKey; @@ -105,3 +121,4 @@ export const selectConversations = (state: RootState) => export const selectConversationId = (state: RootState) => state.conversation.conversationId; export const selectPrompt = (state: RootState) => state.preference.prompt; +export const selectChunks = (state: RootState) => state.preference.chunks; diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 234cc8e..232675a 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -8,11 +8,13 @@ import { const key = localStorage.getItem('DocsGPTApiKey'); const prompt = localStorage.getItem('DocsGPTPrompt'); const doc = localStorage.getItem('DocsGPTRecentDocs'); +const chunks = localStorage.getItem('DocsGPTChunks'); const store = configureStore({ preloadedState: { preference: { apiKey: key ?? '', + chunks: JSON.parse(chunks ?? '2'), selectedDocs: doc !== null ? JSON.parse(doc) : null, prompt: prompt !== null