diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 20c8e1cf..3e3b5d6b 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -173,30 +173,33 @@ def get_prompt(prompt_id): def complete_stream(question, retriever, conversation_id, user_api_key): - response_full = "" - source_log_docs = [] - answer = retriever.gen() - for line in answer: - if "answer" in line: - response_full += str(line["answer"]) - data = json.dumps(line) + try: + response_full = "" + source_log_docs = [] + answer = retriever.gen() + for line in answer: + if "answer" in line: + response_full += str(line["answer"]) + data = json.dumps(line) + yield f"data: {data}\n\n" + elif "source" in line: + source_log_docs.append(line["source"]) + + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key + ) + conversation_id = save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) + + # send data.type = "end" to indicate that the stream has ended as json + data = json.dumps({"type": "id", "id": str(conversation_id)}) yield f"data: {data}\n\n" - elif "source" in line: - source_log_docs.append(line["source"]) - - llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key - ) - conversation_id = save_conversation( - conversation_id, question, response_full, source_log_docs, llm - ) - - # send data.type = "end" to indicate that the stream has ended as json - data = json.dumps({"type": "id", "id": str(conversation_id)}) - yield f"data: {data}\n\n" - data = json.dumps({"type": "end"}) - yield f"data: {data}\n\n" - + data = json.dumps({"type": "end"}) + yield f"data: {data}\n\n" + except Exception: + data = json.dumps({"type": "error","error":"Please try again later. We apologize for any inconvenience."})
+ yield f"data: {data}\n\n" @answer.route("/stream", methods=["POST"]) def stream(): @@ -274,23 +277,29 @@ def stream(): ), mimetype="text/event-stream", ) + + except ValueError: + message = "Malformed request body" + return Response( + error_stream_generate(message), + status=400, + mimetype="text/event-stream", + ) except Exception as e: + print("err",str(e)) message = e.args[0] status_code = 400 # # Custom exceptions with two arguments, index 1 as status code if(len(e.args) >= 2): status_code = e.args[1] - - def error_stream_generate(): - data = json.dumps({"type": "error", "error":message}) - yield f"data: {data}\n\n" - return Response( - error_stream_generate(), + error_stream_generate(message), status=status_code, mimetype="text/event-stream", ) - +def error_stream_generate(err_response): + data = json.dumps({"type": "error", "error":err_response}) + yield f"data: {data}\n\n" @answer.route("/api/answer", methods=["POST"]) def api_answer(): diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 73880698..a5e1189c 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -68,6 +68,15 @@ export const fetchAnswer = createAsyncThunk( query: { conversationId: data.id }, }), ); + } else if (data.type === 'error') { + // set status to 'failed' + dispatch(conversationSlice.actions.setStatus('failed')); + dispatch( + conversationSlice.actions.raiseError({ + index: state.conversation.queries.length - 1, + message: data.error, + }), + ); } else { const result = data.answer; dispatch( @@ -191,6 +200,13 @@ export const conversationSlice = createSlice({ setStatus(state, action: PayloadAction) { state.status = action.payload; }, + raiseError( + state, + action: PayloadAction<{ index: number; message: string }>, + ) { + const { index, message } = action.payload; + state.queries[index].error = message; + },
extraReducers(builder) { builder