From 69287c519852be81313b97b0ae882c01d4facfc1 Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Tue, 18 Jun 2024 16:12:18 +0530 Subject: [PATCH] feat: err handling /stream --- application/api/answer/routes.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 6c5e3e9c..20c8e1cf 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -9,13 +9,11 @@ import traceback from pymongo import MongoClient from bson.objectid import ObjectId - from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator from application.error import bad_request - logger = logging.getLogger(__name__) mongo = MongoClient(settings.MONGO_URI) @@ -75,8 +73,10 @@ def run_async_chain(chain, question, chat_history): def get_data_from_api_key(api_key): data = api_key_collection.find_one({"key": api_key}) + + # # Raise custom exception if the API key is not found if data is None: - return bad_request(401, "Invalid API key") + raise Exception("API key is invalid", 401) return data @@ -128,10 +128,10 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm) "content": "Summarise following conversation in no more than 3 " "words, respond ONLY with the summary, use the same " "language as the system \n\nUser: " - + question - + "\n\n" - + "AI: " - + response, + +question + +"\n\n" + +"AI: " + +response, }, { "role": "user", @@ -200,6 +200,7 @@ def complete_stream(question, retriever, conversation_id, user_api_key): @answer.route("/stream", methods=["POST"]) def stream(): + try: data = request.get_json() # get parameter from url question question = data["question"] @@ -273,6 +274,22 @@ def stream(): ), mimetype="text/event-stream", ) + except Exception as e: + message = e.args[0] + status_code = 400 + # # Custom exceptions with two arguments, index 1 as status code + if(len(e.args) >= 2): + status_code = e.args[1] + + def error_stream_generate(): + data = json.dumps({"type": "error", "error":message}) + yield f"data: {data}\n\n" + + return Response( + error_stream_generate(), + status=status_code, + mimetype="text/event-stream", + ) @answer.route("/api/answer", methods=["POST"])