diff --git a/application/app.py b/application/app.py
index 95b3e4e..d0ecb3b 100644
--- a/application/app.py
+++ b/application/app.py
@@ -5,11 +5,12 @@
 import json
 import os
 import traceback
+import openai
 import dotenv
 import requests
 from celery import Celery
 from celery.result import AsyncResult
-from flask import Flask, request, render_template, send_from_directory, jsonify
+from flask import Flask, request, render_template, send_from_directory, jsonify, Response
 from langchain import FAISS
 from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
 from langchain.chains import LLMChain, ConversationalRetrievalChain
@@ -109,7 +110,32 @@ def run_async_chain(chain, question, chat_history):
     result["answer"] = answer
     return result
 
-
+
+def get_vectorstore(data):
+    if "active_docs" in data:
+        if data["active_docs"].split("/")[0] == "local":
+            if data["active_docs"].split("/")[1] == "default":
+                vectorstore = ""
+            else:
+                vectorstore = "indexes/" + data["active_docs"]
+        else:
+            vectorstore = "vectors/" + data["active_docs"]
+        if data['active_docs'] == "default":
+            vectorstore = ""
+    else:
+        vectorstore = ""
+    return vectorstore
+
+
+def get_docsearch(vectorstore, embeddings_key):
+    if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
+        docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
+    elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
+        docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
+    elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
+        docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
+    elif settings.EMBEDDINGS_NAME == "cohere_medium":
+        docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
+    return docsearch
 
 
 @celery.task(bind=True)
@@ -123,6 +149,53 @@ def home():
     return render_template("index.html", api_key_set=api_key_set, llm_choice=settings.LLM_NAME,
                            embeddings_choice=settings.EMBEDDINGS_NAME)
 
 
+def complete_stream(question, docsearch, chat_history, api_key):
+    openai.api_key = api_key
+    docs = docsearch.similarity_search(question, k=2)
+    # join all page_content together with a newline
+    docs_together = "\n".join([doc.page_content for doc in docs])
+
+    # swap {summaries} in chat_combine_template with the summaries from the docs
+    p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
+    completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[
+        {"role": "system", "content": p_chat_combine},
+        {"role": "user", "content": question},
+    ], stream=True, max_tokens=1000, temperature=0)
+
+    for line in completion:
+        if 'content' in line['choices'][0]['delta']:
+            # check if the delta contains content
+            data = json.dumps({"answer": str(line['choices'][0]['delta']['content'])})
+            yield f"data: {data}\n\n"
+    # send data.type = "end" to indicate that the stream has ended as json
+    data = json.dumps({"type": "end"})
+    yield f"data: {data}\n\n"
+
+
+@app.route("/stream", methods=['POST', 'GET'])
+def stream():
+    # get parameter from url question
+    question = request.args.get('question')
+    history = request.args.get('history')
+    # check if active_docs is set
+
+    if not api_key_set:
+        api_key = request.args.get("api_key")
+    else:
+        api_key = settings.API_KEY
+    if not embeddings_key_set:
+        embeddings_key = request.args.get("embeddings_key")
+    else:
+        embeddings_key = settings.EMBEDDINGS_KEY
+    if "active_docs" in request.args:
+        vectorstore = get_vectorstore({"active_docs": request.args.get("active_docs")})
+    else:
+        vectorstore = ""
+    docsearch = get_docsearch(vectorstore, embeddings_key)
+
+
+    #question = "Hi"
+    return Response(complete_stream(question, docsearch,
+                                    chat_history= history, api_key=api_key), mimetype='text/event-stream')
+
+
 @app.route("/api/answer", methods=["POST"])
 def api_answer():
@@ -142,32 +215,11 @@ def api_answer():
     # use try and except to check for exception
     try:
         # check if the vectorstore is set
-        if "active_docs" in data:
-            if data["active_docs"].split("/")[0] == "local":
-                if data["active_docs"].split("/")[1] == "default":
-                    vectorstore = ""
-                else:
-                    vectorstore = "indexes/" + data["active_docs"]
-            else:
-                vectorstore = "vectors/" + data["active_docs"]
-            if data['active_docs'] == "default":
-                vectorstore = ""
-        else:
-            vectorstore = ""
-        print(vectorstore)
-        # vectorstore = "outputs/inputs/"
+        vectorstore = get_vectorstore(data)
         # loading the index and the store and the prompt template
         # Note if you have used other embeddings than OpenAI, you need to change the embeddings
-        if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
-            docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
-        elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
-            docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
-        elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
-            docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
-        elif settings.EMBEDDINGS_NAME == "cohere_medium":
-            docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
-
-        # create a prompt template
+        docsearch = get_docsearch(vectorstore, embeddings_key)
+
         q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
                                   template_format="jinja2")
         if settings.LLM_NAME == "openai_chat":
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 703ed4e..c06b61b 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -5,6 +5,7 @@ services:
     build: ./frontend
     environment:
      - VITE_API_HOST=http://localhost:5001
+     - VITE_API_STREAMING=$VITE_API_STREAMING
    ports:
      - "5173:5173"
    depends_on:
diff --git a/frontend/src/conversation/conversationApi.ts b/frontend/src/conversation/conversationApi.ts
index 4d5bdfb..29a3956 100644
--- a/frontend/src/conversation/conversationApi.ts
+++ b/frontend/src/conversation/conversationApi.ts
@@ -46,7 +46,7 @@
       if (response.ok) {
         return response.json();
       } else {
-        Promise.reject(response);
+        return Promise.reject(new Error(response.statusText));
       }
     })
     .then((data) => {
@@ -55,6 +55,51 @@
     });
 }
 
+export function fetchAnswerSteaming(
+  question: string,
+  apiKey: string,
+  selectedDocs: Doc,
+  onEvent: (event: MessageEvent) => void,
+): Promise<Answer> {
+  let namePath = selectedDocs.name;
+  if (selectedDocs.language === namePath) {
+    namePath = '.project';
+  }
+
+  let docPath = 'default';
+  if (selectedDocs.location === 'local') {
+    docPath = 'local' + '/' + selectedDocs.name + '/';
+  } else if (selectedDocs.location === 'remote') {
+    docPath =
+      selectedDocs.language +
+      '/' +
+      namePath +
+      '/' +
+      selectedDocs.version +
+      '/' +
+      selectedDocs.model +
+      '/';
+  }
+
+  return new Promise<Answer>((resolve, reject) => {
+    const url = new URL(apiHost + '/stream');
+    url.searchParams.append('question', question);
+    url.searchParams.append('api_key', apiKey);
+    url.searchParams.append('embeddings_key', apiKey);
+    url.searchParams.append('history', localStorage.getItem('chatHistory'));
+    url.searchParams.append('active_docs', docPath);
+
+    const eventSource = new EventSource(url.href);
+
+    eventSource.onmessage = onEvent;
+
+    eventSource.onerror = (error) => {
+      console.log('Connection failed.');
+      eventSource.close();
+    };
+  });
+}
+
 export function sendFeedback(
   prompt: string,
   response: string,
diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts
index a822c9b..37a7b0e 100644
--- a/frontend/src/conversation/conversationSlice.ts
+++ b/frontend/src/conversation/conversationSlice.ts
@@ -1,28 +1,64 @@
 import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
 import store from '../store';
-import { fetchAnswerApi } from './conversationApi';
-import { Answer, ConversationState, Query } from './conversationModels';
+import { fetchAnswerApi, fetchAnswerSteaming } from './conversationApi';
+import { Answer, ConversationState, Query, Status } from './conversationModels';
 
 const initialState: ConversationState = {
   queries: [],
   status: 'idle',
 };
 
-export const fetchAnswer = createAsyncThunk<
-  Answer,
-  { question: string },
-  { state: RootState }
->('fetchAnswer', async ({ question }, { getState }) => {
-  const state = getState();
+const API_STREAMING = import.meta.env.VITE_API_STREAMING === 'true';
 
-  const answer = await fetchAnswerApi(
-    question,
-    state.preference.apiKey,
-    state.preference.selectedDocs!,
-    state.conversation.queries,
-  );
-  return answer;
-});
+export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
+  'fetchAnswer',
+  async ({ question }, { dispatch, getState }) => {
+    const state = getState() as RootState;
+    if (state.preference) {
+      if (API_STREAMING) {
+        await fetchAnswerSteaming(
+          question,
+          state.preference.apiKey,
+          state.preference.selectedDocs!,
+          (event) => {
+            const data = JSON.parse(event.data);
+
+            // check if the 'end' event has been received
+            if (data.type === 'end') {
+              // set status to 'idle'
+              dispatch(conversationSlice.actions.setStatus('idle'));
+            } else {
+              const result = data.answer;
+              dispatch(
+                updateStreamingQuery({
+                  index: state.conversation.queries.length - 1,
+                  query: { response: result },
+                }),
+              );
+            }
+          },
+        );
+      } else {
+        const answer = await fetchAnswerApi(
+          question,
+          state.preference.apiKey,
+          state.preference.selectedDocs!,
+          state.conversation.queries,
+        );
+        if (answer) {
+          dispatch(
+            updateQuery({
+              index: state.conversation.queries.length - 1,
+              query: { response: answer.answer },
+            }),
+          );
+          dispatch(conversationSlice.actions.setStatus('idle'));
+        }
+      }
+    }
+    return { answer: '', query: question, result: '' };
+  },
+);
 
 export const conversationSlice = createSlice({
   name: 'conversation',
@@ -31,6 +67,21 @@
     addQuery(state, action: PayloadAction<Query>) {
       state.queries.push(action.payload);
     },
+    updateStreamingQuery(
+      state,
+      action: PayloadAction<{ index: number; query: Partial<Query> }>,
+    ) {
+      const index = action.payload.index;
+      if (action.payload.query.response) {
+        state.queries[index].response =
+          (state.queries[index].response || '') + action.payload.query.response;
+      } else {
+        state.queries[index] = {
+          ...state.queries[index],
+          ...action.payload.query,
+        };
+      }
+    },
     updateQuery(
       state,
       action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -41,17 +92,15 @@
         ...action.payload.query,
       };
     },
+    setStatus(state, action: PayloadAction<Status>) {
+      state.status = action.payload;
+    },
   },
   extraReducers(builder) {
     builder
       .addCase(fetchAnswer.pending, (state) => {
         state.status = 'loading';
       })
-      .addCase(fetchAnswer.fulfilled, (state, action) => {
-        state.status = 'idle';
-        state.queries[state.queries.length - 1].response =
-          action.payload.answer;
-      })
       .addCase(fetchAnswer.rejected, (state, action) => {
         state.status = 'failed';
         state.queries[state.queries.length - 1].error =
@@ -66,5 +115,6 @@
 export const selectQueries = (state: RootState) => state.conversation.queries;
 export const selectStatus = (state: RootState) => state.conversation.status;
 
-export const { addQuery, updateQuery } = conversationSlice.actions;
+export const { addQuery, updateQuery, updateStreamingQuery } =
+  conversationSlice.actions;
 export default conversationSlice.reducer;
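
For a quick manual check of the new /stream endpoint outside the React frontend, a small SSE client is enough. The sketch below is not part of the patch: it assumes the Flask backend is reachable at http://localhost:5001 (the address the frontend is pointed at via VITE_API_HOST in docker-compose.yaml) and that the placeholder keys are replaced with real ones. It only relies on requests, which application/app.py already imports.

import json

import requests

# Hypothetical smoke test for the /stream endpoint; host, port and key values are placeholders.
params = {
    "question": "What does this project do?",
    "api_key": "YOUR_OPENAI_API_KEY",        # consumed by complete_stream via openai.api_key
    "embeddings_key": "YOUR_OPENAI_API_KEY", # passed to get_docsearch to load the FAISS index
    "history": "[]",
    "active_docs": "default",
}

with requests.get("http://localhost:5001/stream", params=params, stream=True) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        # SSE frames arrive as "data: {...}" lines separated by blank lines
        if not raw or not raw.startswith("data: "):
            continue
        event = json.loads(raw[len("data: "):])
        if event.get("type") == "end":
            break  # the generator signals completion with {"type": "end"}
        print(event["answer"], end="", flush=True)
print()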