import asyncio
import os
from flask import Blueprint, request, Response
import json
import datetime
import logging
import traceback

from pymongo import MongoClient
from bson.objectid import ObjectId
from transformers import GPT2TokenizerFast

from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.error import bad_request

logger = logging.getLogger(__name__)

mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
vectors_collection = db["vectors"]
prompts_collection = db["prompts"]

answer = Blueprint('answer', __name__)

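# Map the configured LLM backend to the model name passed to the LLM wrappers.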
if settings.LLM_NAME == "gpt4":
    gpt_model = 'gpt-4'
elif settings.LLM_NAME == "anthropic":
    gpt_model = 'claude-2'
else:
    gpt_model = 'gpt-3.5-turbo'

# load the prompts
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
    chat_combine_template = f.read()

with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
    chat_reduce_template = f.read()

with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
    chat_combine_creative = f.read()

with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
    chat_combine_strict = f.read()

api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None


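# Await a chain's arun() coroutine with the question and chat history; used by run_async_chain below.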
async def async_generate(chain, question, chat_history):
    result = await chain.arun({"question": question, "chat_history": chat_history})
    return result


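# Count GPT-2 BPE tokens in a string. Note that the tokenizer is re-loaded on every
# call, which keeps the helper simple but is not the fastest option.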
def count_tokens(string):
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    return len(tokenizer(string)['input_ids'])


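# Run an async chain to completion on a fresh event loop and return {"answer": ...}.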
def run_async_chain(chain, question, chat_history):
|
|
|
|
loop = asyncio.new_event_loop()
|
|
|
|
asyncio.set_event_loop(loop)
|
|
|
|
result = {}
|
|
|
|
try:
|
|
|
|
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
|
|
|
|
finally:
|
|
|
|
loop.close()
|
|
|
|
result["answer"] = answer
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
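# Resolve the on-disk vector store path for the requested "active_docs" value:
# "default" maps to an empty path, "local/..." to indexes/<docs>, anything else
# to vectors/<docs>, all joined under the "application" directory.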
def get_vectorstore(data):
|
|
|
|
if "active_docs" in data:
|
2023-11-17 15:33:51 +00:00
|
|
|
if data["active_docs"].split("/")[0] == "default":
|
2023-09-26 12:00:17 +00:00
|
|
|
vectorstore = ""
|
2023-11-17 15:31:53 +00:00
|
|
|
elif data["active_docs"].split("/")[0] == "local":
|
|
|
|
vectorstore = "indexes/" + data["active_docs"]
|
2023-09-26 12:00:17 +00:00
|
|
|
else:
|
|
|
|
vectorstore = "vectors/" + data["active_docs"]
|
|
|
|
if data["active_docs"] == "default":
|
|
|
|
vectorstore = ""
|
|
|
|
else:
|
|
|
|
vectorstore = ""
|
|
|
|
vectorstore = os.path.join("application", vectorstore)
|
|
|
|
return vectorstore
|
|
|
|
|
|
|
|
|
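# Azure OpenAI is treated as configured only when the API base, API version and
# deployment name are all set.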
def is_azure_configured():
    return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME


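# Generator that streams an answer over SSE: select the prompt, retrieve context,
# assemble the message list, stream the LLM output, then persist the conversation.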
def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id):
    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)

    if prompt_id == 'default':
        prompt = chat_combine_template
    elif prompt_id == 'creative':
        prompt = chat_combine_creative
    elif prompt_id == 'strict':
        prompt = chat_combine_strict
    else:
        prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]

    docs = docsearch.search(question, k=2)
    if settings.LLM_NAME == "llama.cpp":
        docs = [docs[0]]
    # join all page_content together with a newline
    docs_together = "\n".join([doc.page_content for doc in docs])
    p_chat_combine = prompt.replace("{summaries}", docs_together)
    messages_combine = [{"role": "system", "content": p_chat_combine}]
    source_log_docs = []
    for doc in docs:
        if doc.metadata:
            source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
        else:
            source_log_docs.append({"title": doc.page_content, "text": doc.page_content})

    if len(chat_history) > 1:
        tokens_current_history = 0
        # count tokens in history
        chat_history.reverse()
        for i in chat_history:
            if "prompt" in i and "response" in i:
                tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
                if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                    tokens_current_history += tokens_batch
                    messages_combine.append({"role": "user", "content": i["prompt"]})
                    messages_combine.append({"role": "system", "content": i["response"]})
    messages_combine.append({"role": "user", "content": question})

    response_full = ""
    completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
                                messages=messages_combine)
    for line in completion:
        data = json.dumps({"answer": str(line)})
        response_full += str(line)
        yield f"data: {data}\n\n"

    # save conversation to database
    if conversation_id is not None:
        conversations_collection.update_one(
            {"_id": ObjectId(conversation_id)},
            {"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}},
        )
    else:
        # create a new conversation and ask the LLM for a short (max 3 words) summary to use as its name
        messages_summary = [{"role": "assistant", "content": "Summarise following conversation in no more than 3 "
                                                             "words, respond ONLY with the summary, use the same "
                                                             "language as the system \n\nUser: " + question + "\n\n" +
                                                             "AI: " +
                                                             response_full},
                            {"role": "user", "content": "Summarise following conversation in no more than 3 words, "
                                                        "respond ONLY with the summary, use the same language as the "
                                                        "system"}]

        completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
                             messages=messages_summary, max_tokens=30)
        conversation_id = conversations_collection.insert_one(
            {"user": "local",
             "date": datetime.datetime.utcnow(),
             "name": completion,
             "queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]}
        ).inserted_id

    # emit the conversation id and an end-of-stream marker as JSON SSE events
    data = json.dumps({"type": "id", "id": str(conversation_id)})
    yield f"data: {data}\n\n"
    data = json.dumps({"type": "end"})
    yield f"data: {data}\n\n"


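# POST /stream: answer a question as a server-sent event stream. Each token is sent as
# data: {"answer": "..."}, followed by {"type": "id", "id": ...} and {"type": "end"}.
# Illustrative request body (field values are examples only; api_key / embeddings_key
# are also expected when not configured server-side):
#   {"question": "What is DocsGPT?", "history": "[]", "conversation_id": null,
#    "active_docs": "local/my-docs/", "prompt_id": "default"}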
@answer.route("/stream", methods=["POST"])
|
|
|
|
def stream():
|
|
|
|
data = request.get_json()
|
|
|
|
# get parameter from url question
|
|
|
|
question = data["question"]
|
|
|
|
history = data["history"]
|
|
|
|
# history to json object from string
|
|
|
|
history = json.loads(history)
|
|
|
|
conversation_id = data["conversation_id"]
|
2023-11-14 01:16:06 +00:00
|
|
|
if 'prompt_id' in data:
|
|
|
|
prompt_id = data["prompt_id"]
|
|
|
|
else:
|
|
|
|
prompt_id = 'default'
|
2023-09-26 09:03:22 +00:00
|
|
|
|
|
|
|
# check if active_docs is set
|
|
|
|
|
|
|
|
if not api_key_set:
|
|
|
|
api_key = data["api_key"]
|
|
|
|
else:
|
|
|
|
api_key = settings.API_KEY
|
|
|
|
if not embeddings_key_set:
|
|
|
|
embeddings_key = data["embeddings_key"]
|
|
|
|
else:
|
|
|
|
embeddings_key = settings.EMBEDDINGS_KEY
|
|
|
|
if "active_docs" in data:
|
|
|
|
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
|
|
|
|
else:
|
|
|
|
vectorstore = ""
|
2023-09-29 16:17:48 +00:00
|
|
|
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)
|
2023-09-26 09:03:22 +00:00
|
|
|
|
|
|
|
return Response(
|
|
|
|
complete_stream(question, docsearch,
|
|
|
|
chat_history=history, api_key=api_key,
|
2023-11-14 01:16:06 +00:00
|
|
|
prompt_id=prompt_id,
|
2023-09-26 09:03:22 +00:00
|
|
|
conversation_id=conversation_id), mimetype="text/event-stream"
|
2023-09-26 12:00:17 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
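# POST /api/answer: same retrieval-and-generation pipeline as /stream, but returns the
# complete answer in one JSON object, e.g. {"answer": ..., "sources": [...], "conversation_id": ...}.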
@answer.route("/api/answer", methods=["POST"])
|
|
|
|
def api_answer():
|
|
|
|
data = request.get_json()
|
|
|
|
question = data["question"]
|
|
|
|
history = data["history"]
|
|
|
|
if "conversation_id" not in data:
|
|
|
|
conversation_id = None
|
|
|
|
else:
|
|
|
|
conversation_id = data["conversation_id"]
|
|
|
|
print("-" * 5)
|
|
|
|
if not api_key_set:
|
|
|
|
api_key = data["api_key"]
|
|
|
|
else:
|
|
|
|
api_key = settings.API_KEY
|
|
|
|
if not embeddings_key_set:
|
|
|
|
embeddings_key = data["embeddings_key"]
|
|
|
|
else:
|
|
|
|
embeddings_key = settings.EMBEDDINGS_KEY
|
2023-11-22 23:55:41 +00:00
|
|
|
if 'prompt_id' in data:
|
|
|
|
prompt_id = data["prompt_id"]
|
|
|
|
else:
|
|
|
|
prompt_id = 'default'
|
|
|
|
|
|
|
|
if prompt_id == 'default':
|
|
|
|
prompt = chat_combine_template
|
|
|
|
elif prompt_id == 'creative':
|
|
|
|
prompt = chat_combine_creative
|
|
|
|
elif prompt_id == 'strict':
|
|
|
|
prompt = chat_combine_strict
|
|
|
|
else:
|
|
|
|
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
2023-09-26 12:00:17 +00:00
|
|
|
|
|
|
|
    # catch any failure in the pipeline and return it as a 500 error response
    try:
        # resolve the vector store path for the requested documents
        vectorstore = get_vectorstore(data)
        # load the index/store; note that if you used embeddings other than OpenAI,
        # the embeddings settings need to be changed accordingly
        docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)

        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)

        docs = docsearch.search(question, k=2)
        # join all page_content together with a newline
        docs_together = "\n".join([doc.page_content for doc in docs])
        p_chat_combine = prompt.replace("{summaries}", docs_together)
        messages_combine = [{"role": "system", "content": p_chat_combine}]
        source_log_docs = []
        for doc in docs:
            if doc.metadata:
                source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
            else:
                source_log_docs.append({"title": doc.page_content, "text": doc.page_content})

        if len(history) > 1:
            tokens_current_history = 0
            # count tokens in history
            history.reverse()
            for i in history:
                if "prompt" in i and "response" in i:
                    tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
                    if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                        tokens_current_history += tokens_batch
                        messages_combine.append({"role": "user", "content": i["prompt"]})
                        messages_combine.append({"role": "system", "content": i["response"]})
        messages_combine.append({"role": "user", "content": question})

        completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
                             messages=messages_combine)

        result = {"answer": completion, "sources": source_log_docs}
        logger.debug(result)

        # append the query to an existing conversation, or create a new one
        if conversation_id is not None:
            conversations_collection.update_one(
                {"_id": ObjectId(conversation_id)},
                {"$push": {"queries": {"prompt": question,
                                       "response": result["answer"], "sources": result['sources']}}},
            )
        else:
            # create a new conversation and ask the LLM for a short (max 3 words) summary to use as its name
            messages_summary = [
                {"role": "assistant", "content": "Summarise following conversation in no more than 3 words, "
                                                 "respond ONLY with the summary, use the same language as the system \n\n"
                                                 "User: " + question + "\n\n" + "AI: " + result["answer"]},
                {"role": "user", "content": "Summarise following conversation in no more than 3 words, "
                                            "respond ONLY with the summary, use the same language as the system"}
            ]

            completion = llm.gen(
                model=gpt_model,
                engine=settings.AZURE_DEPLOYMENT_NAME,
                messages=messages_summary,
                max_tokens=30
            )
            conversation_id = conversations_collection.insert_one(
                {"user": "local",
                 "date": datetime.datetime.utcnow(),
                 "name": completion,
                 "queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]}
            ).inserted_id

        result["conversation_id"] = str(conversation_id)

        # mock result
        # result = {
        #     "answer": "The answer is 42",
        #     "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"]
        # }
        return result
    except Exception as e:
        # log the full traceback and return the error to the client
        traceback.print_exc()
        print(str(e))
        return bad_request(500, str(e))


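# POST /api/search: run only the vector search and return the matching source documents
# (no LLM call), as a list of {"title": ..., "text": ...} objects.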
@answer.route("/api/search", methods=["POST"])
|
|
|
|
def api_search():
|
|
|
|
data = request.get_json()
|
|
|
|
# get parameter from url question
|
|
|
|
question = data["question"]
|
|
|
|
|
|
|
|
if not embeddings_key_set:
|
|
|
|
embeddings_key = data["embeddings_key"]
|
|
|
|
else:
|
|
|
|
embeddings_key = settings.EMBEDDINGS_KEY
|
|
|
|
if "active_docs" in data:
|
|
|
|
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
|
|
|
|
else:
|
|
|
|
vectorstore = ""
|
|
|
|
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)
|
|
|
|
|
|
|
|
docs = docsearch.search(question, k=2)
|
|
|
|
|
|
|
|
source_log_docs = []
|
|
|
|
for doc in docs:
|
|
|
|
if doc.metadata:
|
|
|
|
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
|
|
|
|
else:
|
|
|
|
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
|
2024-01-12 14:39:17 +00:00
|
|
|
#yield f"data:{data}\n\n"
|
2024-01-12 14:35:23 +00:00
|
|
|
return source_log_docs
|
|
|
|
|