mirror of
https://github.com/arc53/DocsGPT
synced 2024-11-19 21:25:39 +00:00
feat: err handling /stream
This commit is contained in:
parent
e6b3984f78
commit
69287c5198
@ -9,13 +9,11 @@ import traceback
|
|||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
|
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.retriever.retriever_creator import RetrieverCreator
|
from application.retriever.retriever_creator import RetrieverCreator
|
||||||
from application.error import bad_request
|
from application.error import bad_request
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
mongo = MongoClient(settings.MONGO_URI)
|
mongo = MongoClient(settings.MONGO_URI)
|
||||||
@ -75,8 +73,10 @@ def run_async_chain(chain, question, chat_history):
|
|||||||
|
|
||||||
def get_data_from_api_key(api_key):
|
def get_data_from_api_key(api_key):
|
||||||
data = api_key_collection.find_one({"key": api_key})
|
data = api_key_collection.find_one({"key": api_key})
|
||||||
|
|
||||||
|
# # Raise custom exception if the API key is not found
|
||||||
if data is None:
|
if data is None:
|
||||||
return bad_request(401, "Invalid API key")
|
raise Exception("API key is invalid", 401)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@ -128,10 +128,10 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
|
|||||||
"content": "Summarise following conversation in no more than 3 "
|
"content": "Summarise following conversation in no more than 3 "
|
||||||
"words, respond ONLY with the summary, use the same "
|
"words, respond ONLY with the summary, use the same "
|
||||||
"language as the system \n\nUser: "
|
"language as the system \n\nUser: "
|
||||||
+ question
|
+question
|
||||||
+ "\n\n"
|
+"\n\n"
|
||||||
+ "AI: "
|
+"AI: "
|
||||||
+ response,
|
+response,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
@ -200,6 +200,7 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
|
|||||||
|
|
||||||
@answer.route("/stream", methods=["POST"])
|
@answer.route("/stream", methods=["POST"])
|
||||||
def stream():
|
def stream():
|
||||||
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
# get parameter from url question
|
# get parameter from url question
|
||||||
question = data["question"]
|
question = data["question"]
|
||||||
@ -273,6 +274,22 @@ def stream():
|
|||||||
),
|
),
|
||||||
mimetype="text/event-stream",
|
mimetype="text/event-stream",
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
message = e.args[0]
|
||||||
|
status_code = 400
|
||||||
|
# # Custom exceptions with two arguments, index 1 as status code
|
||||||
|
if(len(e.args) >= 2):
|
||||||
|
status_code = e.args[1]
|
||||||
|
|
||||||
|
def error_stream_generate():
|
||||||
|
data = json.dumps({"type": "error", "error":message})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
error_stream_generate(),
|
||||||
|
status=status_code,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@answer.route("/api/answer", methods=["POST"])
|
@answer.route("/api/answer", methods=["POST"])
|
||||||
|
Loading…
Reference in New Issue
Block a user