Merge pull request #251 from arc53/feature/streaming

Feature/streaming
Alex 1 year ago committed by GitHub
commit 6c95d8b13e

@@ -5,11 +5,12 @@ import json
 import os
 import traceback
+import openai
 import dotenv
 import requests
 from celery import Celery
 from celery.result import AsyncResult
-from flask import Flask, request, render_template, send_from_directory, jsonify
+from flask import Flask, request, render_template, send_from_directory, jsonify, Response
 from langchain import FAISS
 from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
 from langchain.chains import LLMChain, ConversationalRetrievalChain
@@ -109,7 +110,32 @@ def run_async_chain(chain, question, chat_history):
     result["answer"] = answer
     return result


+def get_vectorstore(data):
+    if "active_docs" in data:
+        if data["active_docs"].split("/")[0] == "local":
+            if data["active_docs"].split("/")[1] == "default":
+                vectorstore = ""
+            else:
+                vectorstore = "indexes/" + data["active_docs"]
+        else:
+            vectorstore = "vectors/" + data["active_docs"]
+        if data['active_docs'] == "default":
+            vectorstore = ""
+    else:
+        vectorstore = ""
+    return vectorstore
+
+
+def get_docsearch(vectorstore, embeddings_key):
+    if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
+        docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
+    elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
+        docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
+    elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
+        docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
+    elif settings.EMBEDDINGS_NAME == "cohere_medium":
+        docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
+    return docsearch
+
+
 @celery.task(bind=True)
@@ -123,6 +149,53 @@ def home():
     return render_template("index.html", api_key_set=api_key_set, llm_choice=settings.LLM_NAME,
                            embeddings_choice=settings.EMBEDDINGS_NAME)


+def complete_stream(question, docsearch, chat_history, api_key):
+    openai.api_key = api_key
+    docs = docsearch.similarity_search(question, k=2)
+    # join all page_content together with a newline
+    docs_together = "\n".join([doc.page_content for doc in docs])
+    # swap {summaries} in chat_combine_template with the summaries from the docs
+    p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
+    completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[
+        {"role": "system", "content": p_chat_combine},
+        {"role": "user", "content": question},
+    ], stream=True, max_tokens=1000, temperature=0)
+
+    for line in completion:
+        if 'content' in line['choices'][0]['delta']:
+            # check if the delta contains content
+            data = json.dumps({"answer": str(line['choices'][0]['delta']['content'])})
+            yield f"data: {data}\n\n"
+
+    # send data.type = "end" to indicate that the stream has ended as json
+    data = json.dumps({"type": "end"})
+    yield f"data: {data}\n\n"
+
+
+@app.route("/stream", methods=['POST', 'GET'])
+def stream():
+    # get parameter from url question
+    question = request.args.get('question')
+    history = request.args.get('history')
+    # check if active_docs is set
+
+    if not api_key_set:
+        api_key = request.args.get("api_key")
+    else:
+        api_key = settings.API_KEY
+    if not embeddings_key_set:
+        embeddings_key = request.args.get("embeddings_key")
+    else:
+        embeddings_key = settings.EMBEDDINGS_KEY
+    if "active_docs" in request.args:
+        vectorstore = get_vectorstore({"active_docs": request.args.get("active_docs")})
+    else:
+        vectorstore = ""
+    docsearch = get_docsearch(vectorstore, embeddings_key)
+
+    # question = "Hi"
+    return Response(complete_stream(question, docsearch,
+                                    chat_history=history, api_key=api_key), mimetype='text/event-stream')
+
+
 @app.route("/api/answer", methods=["POST"])
 def api_answer():
@@ -142,32 +215,11 @@ def api_answer():
     # use try and except to check for exception
     try:
         # check if the vectorstore is set
-        if "active_docs" in data:
-            if data["active_docs"].split("/")[0] == "local":
-                if data["active_docs"].split("/")[1] == "default":
-                    vectorstore = ""
-                else:
-                    vectorstore = "indexes/" + data["active_docs"]
-            else:
-                vectorstore = "vectors/" + data["active_docs"]
-            if data['active_docs'] == "default":
-                vectorstore = ""
-        else:
-            vectorstore = ""
-        print(vectorstore)
-        # vectorstore = "outputs/inputs/"
+        vectorstore = get_vectorstore(data)
         # loading the index and the store and the prompt template
         # Note if you have used other embeddings than OpenAI, you need to change the embeddings
-        if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
-            docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
-        elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
-            docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
-        elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
-            docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
-        elif settings.EMBEDDINGS_NAME == "cohere_medium":
-            docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
-        # create a prompt template
+        docsearch = get_docsearch(vectorstore, embeddings_key)
         q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
                                   template_format="jinja2")
         if settings.LLM_NAME == "openai_chat":
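
The new /stream route above responds with text/event-stream: one SSE message per streamed token carrying an "answer" field, followed by a final {"type": "end"} message. A minimal browser-side consumer sketch, assuming the backend from this change is reachable at http://localhost:5001 and that the key values below are placeholders:

// Minimal EventSource consumer for the /stream route added above.
// Host and key values are placeholders; adjust for the actual deployment.
const url = new URL('http://localhost:5001/stream');
url.searchParams.append('question', 'What does this project do?');
url.searchParams.append('api_key', 'sk-placeholder');
url.searchParams.append('embeddings_key', 'sk-placeholder');
url.searchParams.append('history', '[]');
url.searchParams.append('active_docs', 'default');

let answer = '';
const source = new EventSource(url.href);

source.onmessage = (event: MessageEvent) => {
  const data = JSON.parse(event.data);
  if (data.type === 'end') {
    // complete_stream yields {"type": "end"} once the completion finishes
    source.close();
    console.log('final answer:', answer);
  } else {
    // every other message carries one chunk under "answer"
    answer += data.answer;
  }
};

source.onerror = () => source.close();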

@@ -5,6 +5,7 @@ services:
     build: ./frontend
     environment:
       - VITE_API_HOST=http://localhost:5001
+      - VITE_API_STREAMING=$VITE_API_STREAMING
     ports:
       - "5173:5173"
     depends_on:
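
$VITE_API_STREAMING is substituted from the host environment when compose starts, and Vite exposes any VITE_-prefixed variable to the client bundle as a string. The slice change further below gates streaming on it; a one-line sketch of that read (comparing against the string 'true' follows that change, and treating any other value as disabled is the assumed convention):

// Reads the flag passed through docker-compose; assumes VITE_API_STREAMING=true enables streaming.
const apiStreamingEnabled: boolean =
  import.meta.env.VITE_API_STREAMING === 'true';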

@@ -46,7 +46,7 @@ export function fetchAnswerApi(
       if (response.ok) {
         return response.json();
       } else {
-        Promise.reject(response);
+        return Promise.reject(new Error(response.statusText));
       }
     })
     .then((data) => {
@@ -55,6 +55,51 @@
     });
 }

+export function fetchAnswerSteaming(
+  question: string,
+  apiKey: string,
+  selectedDocs: Doc,
+  onEvent: (event: MessageEvent) => void,
+): Promise<Answer> {
+  let namePath = selectedDocs.name;
+  if (selectedDocs.language === namePath) {
+    namePath = '.project';
+  }
+
+  let docPath = 'default';
+  if (selectedDocs.location === 'local') {
+    docPath = 'local' + '/' + selectedDocs.name + '/';
+  } else if (selectedDocs.location === 'remote') {
+    docPath =
+      selectedDocs.language +
+      '/' +
+      namePath +
+      '/' +
+      selectedDocs.version +
+      '/' +
+      selectedDocs.model +
+      '/';
+  }
+
+  return new Promise<Answer>((resolve, reject) => {
+    const url = new URL(apiHost + '/stream');
+    url.searchParams.append('question', question);
+    url.searchParams.append('api_key', apiKey);
+    url.searchParams.append('embeddings_key', apiKey);
+    url.searchParams.append('history', localStorage.getItem('chatHistory'));
+    url.searchParams.append('active_docs', docPath);
+
+    const eventSource = new EventSource(url.href);
+
+    eventSource.onmessage = onEvent;
+
+    eventSource.onerror = (error) => {
+      console.log('Connection failed.');
+      eventSource.close();
+    };
+  });
+}
+
 export function sendFeedback(
   prompt: string,
   response: string,
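
A usage sketch for the fetchAnswerSteaming helper added above. The document object and API key are placeholder values, the chunk handling mirrors the SSE format produced by /stream (per-chunk "answer" fields, then a final "end" message), and the returned promise is not awaited because all the work happens in the onEvent callback:

import { fetchAnswerSteaming } from './conversationApi';

// Placeholder selection; only the fields fetchAnswerSteaming reads are shown
// (name, language, location, version, model) and the values are made up.
const exampleDocs = {
  name: 'docsgpt',
  language: 'docsgpt',
  location: 'local',
  version: '',
  model: '',
} as unknown as Parameters<typeof fetchAnswerSteaming>[2]; // stands in for the app's Doc type

let streamed = '';
fetchAnswerSteaming('How do I run the app?', 'sk-placeholder', exampleDocs, (event) => {
  const data = JSON.parse(event.data);
  if (data.type === 'end') {
    console.log('final answer:', streamed);
  } else {
    streamed += data.answer; // append each streamed chunk
  }
});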

@@ -1,28 +1,64 @@
 import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
 import store from '../store';
-import { fetchAnswerApi } from './conversationApi';
-import { Answer, ConversationState, Query } from './conversationModels';
+import { fetchAnswerApi, fetchAnswerSteaming } from './conversationApi';
+import { Answer, ConversationState, Query, Status } from './conversationModels';

 const initialState: ConversationState = {
   queries: [],
   status: 'idle',
 };

-export const fetchAnswer = createAsyncThunk<
-  Answer,
-  { question: string },
-  { state: RootState }
->('fetchAnswer', async ({ question }, { getState }) => {
-  const state = getState();
-  const answer = await fetchAnswerApi(
-    question,
-    state.preference.apiKey,
-    state.preference.selectedDocs!,
-    state.conversation.queries,
-  );
-  return answer;
-});
+const API_STREAMING = import.meta.env.VITE_API_STREAMING === 'true';
+
+export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
+  'fetchAnswer',
+  async ({ question }, { dispatch, getState }) => {
+    const state = getState() as RootState;
+    if (state.preference) {
+      if (API_STREAMING) {
+        await fetchAnswerSteaming(
+          question,
+          state.preference.apiKey,
+          state.preference.selectedDocs!,
+          (event) => {
+            const data = JSON.parse(event.data);
+            // check if the 'end' event has been received
+            if (data.type === 'end') {
+              // set status to 'idle'
+              dispatch(conversationSlice.actions.setStatus('idle'));
+            } else {
+              const result = data.answer;
+              dispatch(
+                updateStreamingQuery({
+                  index: state.conversation.queries.length - 1,
+                  query: { response: result },
+                }),
+              );
+            }
+          },
+        );
+      } else {
+        const answer = await fetchAnswerApi(
+          question,
+          state.preference.apiKey,
+          state.preference.selectedDocs!,
+          state.conversation.queries,
+        );
+        if (answer) {
+          dispatch(
+            updateQuery({
+              index: state.conversation.queries.length - 1,
+              query: { response: answer.answer },
+            }),
+          );
+          dispatch(conversationSlice.actions.setStatus('idle'));
+        }
+      }
+    }
+    return { answer: '', query: question, result: '' };
+  },
+);

 export const conversationSlice = createSlice({
   name: 'conversation',
@@ -31,6 +67,21 @@ export const conversationSlice = createSlice({
     addQuery(state, action: PayloadAction<Query>) {
       state.queries.push(action.payload);
     },
+    updateStreamingQuery(
+      state,
+      action: PayloadAction<{ index: number; query: Partial<Query> }>,
+    ) {
+      const index = action.payload.index;
+      if (action.payload.query.response) {
+        state.queries[index].response =
+          (state.queries[index].response || '') + action.payload.query.response;
+      } else {
+        state.queries[index] = {
+          ...state.queries[index],
+          ...action.payload.query,
+        };
+      }
+    },
     updateQuery(
       state,
       action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -41,17 +92,15 @@ export const conversationSlice = createSlice({
         ...action.payload.query,
       };
     },
+    setStatus(state, action: PayloadAction<Status>) {
+      state.status = action.payload;
+    },
   },
   extraReducers(builder) {
     builder
       .addCase(fetchAnswer.pending, (state) => {
         state.status = 'loading';
       })
-      .addCase(fetchAnswer.fulfilled, (state, action) => {
-        state.status = 'idle';
-        state.queries[state.queries.length - 1].response =
-          action.payload.answer;
-      })
       .addCase(fetchAnswer.rejected, (state, action) => {
         state.status = 'failed';
         state.queries[state.queries.length - 1].error =
@@ -66,5 +115,6 @@ export const selectQueries = (state: RootState) => state.conversation.queries;
 export const selectStatus = (state: RootState) => state.conversation.status;

-export const { addQuery, updateQuery } = conversationSlice.actions;
+export const { addQuery, updateQuery, updateStreamingQuery } =
+  conversationSlice.actions;

 export default conversationSlice.reducer;
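
To make the new reducer behavior concrete: updateStreamingQuery concatenates each incoming chunk onto the stored response, and setStatus is what the thunk dispatches when the 'end' event arrives. A small sketch against an assumed module layout (the './conversationSlice' path, the store wiring, and the prompt field on Query are assumptions, not part of this change):

import { configureStore } from '@reduxjs/toolkit';
import conversationReducer, {
  addQuery,
  updateStreamingQuery,
} from './conversationSlice';

const store = configureStore({
  reducer: { conversation: conversationReducer },
});

// Start a query, then simulate two streamed chunks arriving from /stream.
// The Query fields are assumed; cast loosely since its exact shape lives in conversationModels.
const firstQuery = { prompt: 'Hello', response: '' } as unknown as Parameters<typeof addQuery>[0];
store.dispatch(addQuery(firstQuery));
store.dispatch(updateStreamingQuery({ index: 0, query: { response: 'Hi' } }));
store.dispatch(updateStreamingQuery({ index: 0, query: { response: ' there' } }));

// Logs 'Hi there': updateStreamingQuery appends rather than replaces.
console.log(store.getState().conversation.queries[0].response);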
