diff --git a/application/app.py b/application/app.py
index 0097f33..99ab8e0 100644
--- a/application/app.py
+++ b/application/app.py
@@ -33,12 +33,14 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
AIMessagePromptTemplate,
)
+from langchain.schema import HumanMessage, AIMessage
from pymongo import MongoClient
from werkzeug.utils import secure_filename
from core.settings import settings
from error import bad_request
from worker import ingest_worker
+from bson.objectid import ObjectId
# os.environ["LANGCHAIN_HANDLER"] = "langchain"
@@ -94,6 +96,7 @@ celery.config_from_object("celeryconfig")
mongo = MongoClient(app.config["MONGO_URI"])
db = mongo["docsgpt"]
vectors_collection = db["vectors"]
+conversations_collection = db["conversations"]
async def async_generate(chain, question, chat_history):
@@ -159,7 +162,7 @@ def home():
)
-def complete_stream(question, docsearch, chat_history, api_key):
+def complete_stream(question, docsearch, chat_history, api_key, conversation_id):
openai.api_key = api_key
if is_azure_configured():
logger.debug("in Azure")
@@ -180,11 +183,14 @@ def complete_stream(question, docsearch, chat_history, api_key):
docs_together = "\n".join([doc.page_content for doc in docs])
p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
+ source_log_docs = []
for doc in docs:
if doc.metadata:
data = json.dumps({"type": "source", "doc": doc.page_content, "metadata": doc.metadata})
+ source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
data = json.dumps({"type": "source", "doc": doc.page_content})
+ source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
yield f"data:{data}\n\n"
if len(chat_history) > 1:
@@ -201,13 +207,43 @@ def complete_stream(question, docsearch, chat_history, api_key):
messages_combine.append({"role": "user", "content": question})
completion = openai.ChatCompletion.create(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_combine, stream=True, max_tokens=500, temperature=0)
-
+ reponse_full = ""
for line in completion:
if "content" in line["choices"][0]["delta"]:
# check if the delta contains content
data = json.dumps({"answer": str(line["choices"][0]["delta"]["content"])})
+ reponse_full += str(line["choices"][0]["delta"]["content"])
yield f"data: {data}\n\n"
+ # save conversation to database
+ if conversation_id is not None:
+ conversations_collection.update_one(
+ {"_id": ObjectId(conversation_id)},
+ {"$push": {"queries": {"prompt": question, "response": reponse_full, "sources": source_log_docs}}},
+ )
+
+ else:
+ # create new conversation
+ # generate summary
+ messages_summary = [{"role": "assistant", "content": "Summarise following conversation in no more than 3 "
+ "words, respond ONLY with the summary, use the same "
+ "language as the system \n\nUser: " + question + "\n\n" +
+ "AI: " +
+ reponse_full},
+ {"role": "user", "content": "Summarise following conversation in no more than 3 words, "
+ "respond ONLY with the summary, use the same language as the "
+ "system"}]
+ completion = openai.ChatCompletion.create(model='gpt-3.5-turbo', engine=settings.AZURE_DEPLOYMENT_NAME,
+ messages=messages_summary, max_tokens=30, temperature=0)
+ conversation_id = conversations_collection.insert_one(
+ {"user": "local",
+ "date": datetime.datetime.utcnow(),
+ "name": completion["choices"][0]["message"]["content"],
+ "queries": [{"prompt": question, "response": reponse_full, "sources": source_log_docs}]}
+ ).inserted_id
+
# send data.type = "end" to indicate that the stream has ended as json
+ data = json.dumps({"type": "id", "id": str(conversation_id)})
+ yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
@@ -220,6 +256,7 @@ def stream():
history = data["history"]
# history to json object from string
history = json.loads(history)
+ conversation_id = data["conversation_id"]
# check if active_docs is set
@@ -239,7 +276,9 @@ def stream():
# question = "Hi"
return Response(
- complete_stream(question, docsearch, chat_history=history, api_key=api_key), mimetype="text/event-stream"
+ complete_stream(question, docsearch,
+ chat_history=history, api_key=api_key,
+ conversation_id=conversation_id), mimetype="text/event-stream"
)
@@ -252,6 +291,10 @@ def api_answer():
data = request.get_json()
question = data["question"]
history = data["history"]
+ if "conversation_id" not in data:
+ conversation_id = None
+ else:
+ conversation_id = data["conversation_id"]
print("-" * 5)
if not api_key_set:
api_key = data["api_key"]
@@ -364,6 +407,38 @@ def api_answer():
sources_doc.append({'title': doc.page_content, 'text': doc.page_content})
result['sources'] = sources_doc
+ # generate conversationId
+ if conversation_id is not None:
+ conversations_collection.update_one(
+ {"_id": ObjectId(conversation_id)},
+ {"$push": {"queries": {"prompt": question,
+ "response": result["answer"], "sources": result['sources']}}},
+ )
+
+ else:
+ # create new conversation
+ # generate summary
+ messages_summary = [AIMessage(content="Summarise following conversation in no more than 3 " +
+ "words, respond ONLY with the summary, use the same " +
+ "language as the system \n\nUser: " + question + "\n\nAI: " +
+ result["answer"]),
+ HumanMessage(content="Summarise following conversation in no more than 3 words, " +
+ "respond ONLY with the summary, use the same language as the " +
+ "system")]
+
+
+ # completion = openai.ChatCompletion.create(model='gpt-3.5-turbo', engine=settings.AZURE_DEPLOYMENT_NAME,
+ # messages=messages_summary, max_tokens=30, temperature=0)
+ completion = llm.predict_messages(messages_summary)
+ conversation_id = conversations_collection.insert_one(
+ {"user": "local",
+ "date": datetime.datetime.utcnow(),
+ "name": completion.content,
+ "queries": [{"prompt": question, "response": result["answer"], "sources": result['sources']}]}
+ ).inserted_id
+
+ result["conversation_id"] = str(conversation_id)
+
# mock result
# result = {
# "answer": "The answer is 42",
@@ -591,6 +666,39 @@ def delete_old():
return {"status": "ok"}
+@app.route("/api/get_conversations", methods=["get"])
+def get_conversations():
+ # provides a list of conversations
+ conversations = conversations_collection.find().sort("date", -1)
+ list_conversations = []
+ for conversation in conversations:
+ list_conversations.append({"id": str(conversation["_id"]), "name": conversation["name"]})
+
+ #list_conversations = [{"id": "default", "name": "default"}, {"id": "jeff", "name": "jeff"}]
+
+ return jsonify(list_conversations)
+
+@app.route("/api/get_single_conversation", methods=["get"])
+def get_single_conversation():
+ # provides data for a conversation
+ conversation_id = request.args.get("id")
+ conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)})
+ return jsonify(conversation['queries'])
+
+@app.route("/api/delete_conversation", methods=["POST"])
+def delete_conversation():
+ # deletes a conversation from the database
+ conversation_id = request.args.get("id")
+ # write to mongodb
+ conversations_collection.delete_one(
+ {
+ "_id": ObjectId(conversation_id),
+ }
+ )
+
+ return {"status": "ok"}
+
+
# handling CORS
@app.after_request
def after_request(response):
diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx
index 697f0d9..b114ee9 100644
--- a/frontend/src/Navigation.tsx
+++ b/frontend/src/Navigation.tsx
@@ -19,10 +19,17 @@ import {
selectSelectedDocsStatus,
selectSourceDocs,
setSelectedDocs,
+ selectConversations,
+ setConversations,
+ selectConversationId,
} from './preferences/preferenceSlice';
+import {
+ setConversation,
+ updateConversationId,
+} from './conversation/conversationSlice';
import { useOutsideAlerter } from './hooks';
import Upload from './upload/Upload';
-import { Doc } from './preferences/preferenceApi';
+import { Doc, getConversations } from './preferences/preferenceApi';
export default function Navigation({
navState,
@@ -34,6 +41,8 @@ export default function Navigation({
const dispatch = useDispatch();
const docs = useSelector(selectSourceDocs);
const selectedDocs = useSelector(selectSelectedDocs);
+ const conversations = useSelector(selectConversations);
+ const conversationId = useSelector(selectConversationId);
const [isDocsListOpen, setIsDocsListOpen] = useState(false);
@@ -51,6 +60,33 @@ export default function Navigation({
const navRef = useRef(null);
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
+ useEffect(() => {
+ if (!conversations) {
+ getConversations()
+ .then((fetchedConversations) => {
+ dispatch(setConversations(fetchedConversations));
+ })
+ .catch((error) => {
+ console.error('Failed to fetch conversations: ', error);
+ });
+ }
+ }, [conversations, dispatch]);
+
+ const handleDeleteConversation = (id: string) => {
+ fetch(`${apiHost}/api/delete_conversation?id=${id}`, {
+ method: 'POST',
+ })
+ .then(() => {
+ // remove the image element from the DOM
+ const imageElement = document.querySelector(
+ `#img-${id}`,
+ ) as HTMLElement;
+ const parentElement = imageElement.parentNode as HTMLElement;
+ parentElement.parentNode?.removeChild(parentElement);
+ })
+ .catch((error) => console.error(error));
+ };
+
const handleDeleteClick = (index: number, doc: Doc) => {
const docPath = 'indexes/' + 'local' + '/' + doc.name;
@@ -67,6 +103,22 @@ export default function Navigation({
})
.catch((error) => console.error(error));
};
+
+ const handleConversationClick = (index: string) => {
+ // fetch the conversation from the server and setConversation in the store
+ fetch(`${apiHost}/api/get_single_conversation?id=${index}`, {
+ method: 'GET',
+ })
+ .then((response) => response.json())
+ .then((data) => {
+ dispatch(setConversation(data));
+ dispatch(
+ updateConversationId({
+ query: { conversationId: index },
+ }),
+ );
+ });
+ };
useOutsideAlerter(
navRef,
() => {
@@ -121,15 +173,56 @@ export default function Navigation({
Chat New Chat
-
+ {conversation.name} +
+