diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..ba034aa --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,14 @@ +# Security Policy + +## Supported Versions + +Supported Versions: + +Currently, we support security patches by committing changes and bumping the version published on Github. + +## Reporting a Vulnerability + +Found a vulnerability? Please email us: + +security@arc53.com + diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index bd1fa21..fa0ac4f 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -8,13 +8,12 @@ import traceback from pymongo import MongoClient from bson.objectid import ObjectId -from transformers import GPT2TokenizerFast from application.core.settings import settings -from application.vectorstore.vector_creator import VectorCreator from application.llm.llm_creator import LLMCreator +from application.retriever.retriever_creator import RetrieverCreator from application.error import bad_request @@ -62,9 +61,6 @@ async def async_generate(chain, question, chat_history): return result -def count_tokens(string): - tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') - return len(tokenizer(string)['input_ids']) def run_async_chain(chain, question, chat_history): @@ -104,61 +100,11 @@ def get_vectorstore(data): def is_azure_configured(): return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME - -def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id, chunks=2): - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - if prompt_id == 'default': - prompt = chat_combine_template - elif prompt_id == 'creative': - prompt = chat_combine_creative - elif prompt_id == 'strict': - prompt = chat_combine_strict - else: - prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] - - if chunks == 0: - docs = [] - else: - docs = docsearch.search(question, k=chunks) - if settings.LLM_NAME == "llama.cpp": - docs = [docs[0]] - # join all page_content together with a newline - docs_together = "\n".join([doc.page_content for doc in docs]) - p_chat_combine = prompt.replace("{summaries}", docs_together) - messages_combine = [{"role": "system", "content": p_chat_combine}] - source_log_docs = [] - for doc in docs: - if doc.metadata: - source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}) - else: - source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - - if len(chat_history) > 1: - tokens_current_history = 0 - # count tokens in history - chat_history.reverse() - for i in chat_history: - if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: - tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) - messages_combine.append({"role": "user", "content": question}) - - response_full = "" - completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME, - messages=messages_combine) - for line in completion: - data = json.dumps({"answer": str(line)}) - response_full += str(line) - yield f"data: {data}\n\n" - - # save conversation to database - if conversation_id is not None: +def save_conversation(conversation_id, question, response, source_log_docs, llm): + if 
conversation_id is not None and conversation_id != "None": conversations_collection.update_one( {"_id": ObjectId(conversation_id)}, - {"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}}, + {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}}, ) else: @@ -168,19 +114,50 @@ def complete_stream(question, docsearch, chat_history, prompt_id, conversation_i "words, respond ONLY with the summary, use the same " "language as the system \n\nUser: " + question + "\n\n" + "AI: " + - response_full}, + response}, {"role": "user", "content": "Summarise following conversation in no more than 3 words, " "respond ONLY with the summary, use the same language as the " "system"}] - completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME, + completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30) conversation_id = conversations_collection.insert_one( {"user": "local", "date": datetime.datetime.utcnow(), "name": completion, - "queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]} + "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]} ).inserted_id + return conversation_id + +def get_prompt(prompt_id): + if prompt_id == 'default': + prompt = chat_combine_template + elif prompt_id == 'creative': + prompt = chat_combine_creative + elif prompt_id == 'strict': + prompt = chat_combine_strict + else: + prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] + return prompt + + +def complete_stream(question, retriever, conversation_id): + + + response_full = "" + source_log_docs = [] + answer = retriever.gen() + for line in answer: + if "answer" in line: + response_full += str(line["answer"]) + data = json.dumps(line) + yield f"data: {data}\n\n" + elif "source" in line: + source_log_docs.append(line["source"]) + + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm) # send data.type = "end" to indicate that the stream has ended as json data = json.dumps({"type": "id", "id": str(conversation_id)}) @@ -207,29 +184,37 @@ def stream(): prompt_id = data["prompt_id"] else: prompt_id = 'default' - if 'chunks' in data: + if 'selectedDocs' in data and data['selectedDocs'] is None: + chunks = 0 + elif 'chunks' in data: chunks = int(data["chunks"]) else: chunks = 2 + + prompt = get_prompt(prompt_id) # check if active_docs is set if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) - vectorstore = get_vectorstore({"active_docs": data_key["source"]}) + source = {"active_docs": data_key["source"]} elif "active_docs" in data: - vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) + source = {"active_docs": data["active_docs"]} else: - vectorstore = "" - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) + source = {} + + if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + retriever_name = "classic" + else: + retriever_name = source['active_docs'] + + retriever = RetrieverCreator.create_retriever(retriever_name, question=question, + source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model + ) return Response( - complete_stream(question, docsearch, - chat_history=history, - prompt_id=prompt_id, - 
conversation_id=conversation_id,
-            chunks=chunks), mimetype="text/event-stream"
-    )
+        complete_stream(question=question, retriever=retriever,
+                        conversation_id=conversation_id), mimetype="text/event-stream")


 @answer.route("/api/answer", methods=["POST"])
 def api_answer():
@@ -253,110 +238,40 @@ def api_answer():
         chunks = int(data["chunks"])
     else:
         chunks = 2
-
-    if prompt_id == 'default':
-        prompt = chat_combine_template
-    elif prompt_id == 'creative':
-        prompt = chat_combine_creative
-    elif prompt_id == 'strict':
-        prompt = chat_combine_strict
-    else:
-        prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
+
+    prompt = get_prompt(prompt_id)

     # use try and except to check for exception
     try:
         # check if the vectorstore is set
         if "api_key" in data:
             data_key = get_data_from_api_key(data["api_key"])
-            vectorstore = get_vectorstore({"active_docs": data_key["source"]})
+            source = {"active_docs": data_key["source"]}
         else:
-            vectorstore = get_vectorstore(data)
-        # loading the index and the store and the prompt template
-        # Note if you have used other embeddings than OpenAI, you need to change the embeddings
-        docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
+            source = {"active_docs": data["active_docs"]}
+        if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
+            retriever_name = "classic"
+        else:
+            retriever_name = source['active_docs']
+        retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
+                                                      source=source, chat_history=history, prompt=prompt,
+                                                      chunks=chunks, gpt_model=gpt_model
+                                                      )
+        source_log_docs = []
+        response_full = ""
+        for line in retriever.gen():
+            if "source" in line:
+                source_log_docs.append(line["source"])
+            elif "answer" in line:
+                response_full += line["answer"]
+        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
+
+        result = {"answer": response_full, "sources": source_log_docs}
+        result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
-
-        if chunks == 0:
-            docs = []
-        else:
-            docs = docsearch.search(question, k=chunks)
-        # join all page_content together with a newline
-        docs_together = "\n".join([doc.page_content for doc in docs])
-        p_chat_combine = prompt.replace("{summaries}", docs_together)
-        messages_combine = [{"role": "system", "content": p_chat_combine}]
-        source_log_docs = []
-        for doc in docs:
-            if doc.metadata:
-                source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
-            else:
-                source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
-        # join all page_content together with a newline
-
-
-        if len(history) > 1:
-            tokens_current_history = 0
-            # count tokens in history
-            history.reverse()
-            for i in history:
-                if "prompt" in i and "response" in i:
-                    tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
-                    if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
-                        tokens_current_history += tokens_batch
-                        messages_combine.append({"role": "user", "content": i["prompt"]})
-                        messages_combine.append({"role": "system", "content": i["response"]})
-        messages_combine.append({"role": "user", "content": question})
-
-
-        completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
-                             messages=messages_combine)
-
-
-        result = {"answer": completion, "sources": source_log_docs}
-        logger.debug(result)
-
-        # generate conversationId
-        if conversation_id is not None:
-            conversations_collection.update_one(
-
{"_id": ObjectId(conversation_id)}, - {"$push": {"queries": {"prompt": question, - "response": result["answer"], "sources": result['sources']}}}, - ) - - else: - # create new conversation - # generate summary - messages_summary = [ - {"role": "assistant", "content": "Summarise following conversation in no more than 3 words, " - "respond ONLY with the summary, use the same language as the system \n\n" - "User: " + question + "\n\n" + "AI: " + result["answer"]}, - {"role": "user", "content": "Summarise following conversation in no more than 3 words, " - "respond ONLY with the summary, use the same language as the system"} - ] - - completion = llm.gen( - model=gpt_model, - engine=settings.AZURE_DEPLOYMENT_NAME, - messages=messages_summary, - max_tokens=30 - ) - conversation_id = conversations_collection.insert_one( - {"user": "local", - "date": datetime.datetime.utcnow(), - "name": completion, - "queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]} - ).inserted_id - - result["conversation_id"] = str(conversation_id) - - # mock result - # result = { - # "answer": "The answer is 42", - # "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"] - # } return result except Exception as e: # print whole traceback @@ -373,27 +288,24 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) - vectorstore = data_key["source"] + source = {"active_docs": data_key["source"]} elif "active_docs" in data: - vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) + source = {"active_docs": data["active_docs"]} else: - vectorstore = "" + source = {} if 'chunks' in data: chunks = int(data["chunks"]) else: chunks = 2 - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) - if chunks == 0: - docs = [] + + if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + retriever_name = "classic" else: - docs = docsearch.search(question, k=chunks) + retriever_name = source['active_docs'] - source_log_docs = [] - for doc in docs: - if doc.metadata: - source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}) - else: - source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - #yield f"data:{data}\n\n" - return source_log_docs + retriever = RetrieverCreator.create_retriever(retriever_name, question=question, + source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model + ) + docs = retriever.search() + return docs diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 239278b..de5a3c0 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1,6 +1,8 @@ import os import uuid +import shutil from flask import Blueprint, request, jsonify +from urllib.parse import urlparse import requests from pymongo import MongoClient from bson.objectid import ObjectId @@ -135,30 +137,43 @@ def upload_file(): return {"status": "no name"} job_name = secure_filename(request.form["name"]) # check if the post request has the file part - if "file" not in request.files: - print("No file part") - return {"status": "no file"} - file = request.files["file"] - if file.filename == "": + files = request.files.getlist("file") + + if not files or all(file.filename == '' for file in files): return {"status": "no file name"} - if file: - filename = secure_filename(file.filename) - # save dir 
- save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name) - # create dir if not exists - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - file.save(os.path.join(save_dir, filename)) - task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt", ".docx", - ".csv", ".epub", ".html", ".mdx"], - job_name, filename, user) - # task id - task_id = task.id - return {"status": "ok", "task_id": task_id} + # Directory where files will be saved + save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name) + os.makedirs(save_dir, exist_ok=True) + + if len(files) > 1: + # Multiple files; prepare them for zip + temp_dir = os.path.join(save_dir, "temp") + os.makedirs(temp_dir, exist_ok=True) + + for file in files: + filename = secure_filename(file.filename) + file.save(os.path.join(temp_dir, filename)) + + # Use shutil.make_archive to zip the temp directory + zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format='zip', root_dir=temp_dir) + final_filename = os.path.basename(zip_path) + + # Clean up the temporary directory after zipping + shutil.rmtree(temp_dir) else: - return {"status": "error"} + # Single file + file = files[0] + final_filename = secure_filename(file.filename) + file_path = os.path.join(save_dir, final_filename) + file.save(file_path) + + # Call ingest with the single file or zipped file + task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt", ".docx", + ".csv", ".epub", ".html", ".mdx"], + job_name, final_filename, user) + + return {"status": "ok", "task_id": task.id} @user.route("/api/remote", methods=["POST"]) def upload_remote(): @@ -236,6 +251,34 @@ def combined_json(): for index in data_remote: index["location"] = "remote" data.append(index) + if 'duckduck_search' in settings.RETRIEVERS_ENABLED: + data.append( + { + "name": "DuckDuckGo Search", + "language": "en", + "version": "", + "description": "duckduck_search", + "fullName": "DuckDuckGo Search", + "date": "duckduck_search", + "docLink": "duckduck_search", + "model": settings.EMBEDDINGS_NAME, + "location": "custom", + } + ) + if 'brave_search' in settings.RETRIEVERS_ENABLED: + data.append( + { + "name": "Brave Search", + "language": "en", + "version": "", + "description": "brave_search", + "fullName": "Brave Search", + "date": "brave_search", + "docLink": "brave_search", + "model": settings.EMBEDDINGS_NAME, + "location": "custom", + } + ) return jsonify(data) @@ -247,25 +290,32 @@ def check_docs(): # split docs on / and take first part if data["docs"].split("/")[0] == "local": return {"status": "exists"} - vectorstore = "vectors/" + data["docs"] + vectorstore = "vectors/" + secure_filename(data["docs"]) base_path = "https://raw.githubusercontent.com/arc53/DocsHUB/main/" if os.path.exists(vectorstore) or data["docs"] == "default": return {"status": "exists"} else: - r = requests.get(base_path + vectorstore + "index.faiss") - - if r.status_code != 200: - return {"status": "null"} + file_url = urlparse(base_path + vectorstore + "index.faiss") + + if ( + file_url.scheme in ['https'] and + file_url.netloc == 'raw.githubusercontent.com' and + file_url.path.startswith('/arc53/DocsHUB/main/') + ): + r = requests.get(file_url.geturl()) + if r.status_code != 200: + return {"status": "null"} + else: + if not os.path.exists(vectorstore): + os.makedirs(vectorstore) + with open(vectorstore + "index.faiss", "wb") as f: + f.write(r.content) + + r = requests.get(base_path + vectorstore + "index.pkl") + with open(vectorstore + 
"index.pkl", "wb") as f: + f.write(r.content) else: - if not os.path.exists(vectorstore): - os.makedirs(vectorstore) - with open(vectorstore + "index.faiss", "wb") as f: - f.write(r.content) - - # download the store - r = requests.get(base_path + vectorstore + "index.pkl") - with open(vectorstore + "index.pkl", "wb") as f: - f.write(r.content) + return {"status": "null"} return {"status": "loaded"} @@ -351,7 +401,14 @@ def get_api_keys(): keys = api_key_collection.find({"user": user}) list_keys = [] for key in keys: - list_keys.append({"id": str(key["_id"]), "name": key["name"], "key": key["key"][:4] + "..." + key["key"][-4:], "source": key["source"]}) + list_keys.append({ + "id": str(key["_id"]), + "name": key["name"], + "key": key["key"][:4] + "..." + key["key"][-4:], + "source": key["source"], + "prompt_id": key["prompt_id"], + "chunks": key["chunks"] + }) return jsonify(list_keys) @user.route("/api/create_api_key", methods=["POST"]) @@ -359,6 +416,8 @@ def create_api_key(): data = request.get_json() name = data["name"] source = data["source"] + prompt_id = data["prompt_id"] + chunks = data["chunks"] key = str(uuid.uuid4()) user = "local" resp = api_key_collection.insert_one( @@ -367,6 +426,8 @@ def create_api_key(): "key": key, "source": source, "user": user, + "prompt_id": prompt_id, + "chunks": chunks } ) new_id = str(resp.inserted_id) diff --git a/application/app.py b/application/app.py index ae61997..e646ffb 100644 --- a/application/app.py +++ b/application/app.py @@ -40,5 +40,5 @@ def after_request(response): return response if __name__ == "__main__": - app.run(debug=True, port=7091) + app.run(debug=settings.FLASK_DEBUG_MODE, port=7091) diff --git a/application/core/settings.py b/application/core/settings.py index 0e1909e..26c27ed 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -9,7 +9,7 @@ current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__ class Settings(BaseSettings): LLM_NAME: str = "docsgpt" - MODEL_NAME: Optional[str] = None # when LLM_NAME is openai, MODEL_NAME can be e.g. 
gpt-4-turbo-preview or gpt-3.5-turbo + MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2" CELERY_BROKER_URL: str = "redis://localhost:6379/0" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" @@ -18,6 +18,7 @@ class Settings(BaseSettings): TOKENS_MAX_HISTORY: int = 150 UPLOAD_FOLDER: str = "inputs" VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" + RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search API_URL: str = "http://localhost:7091" # backend url for celery worker @@ -59,6 +60,10 @@ class Settings(BaseSettings): QDRANT_PATH: Optional[str] = None QDRANT_DISTANCE_FUNC: str = "Cosine" + BRAVE_SEARCH_API_KEY: Optional[str] = None + + FLASK_DEBUG_MODE: bool = False + path = Path(__file__).parent.parent.absolute() settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8") diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index a64d71e..6b0d646 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -10,7 +10,7 @@ class AnthropicLLM(BaseLLM): self.HUMAN_PROMPT = HUMAN_PROMPT self.AI_PROMPT = AI_PROMPT - def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs): + def gen(self, model, messages, max_tokens=300, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Context \n {context} \n ### Question \n {user_question}" @@ -25,7 +25,7 @@ class AnthropicLLM(BaseLLM): ) return completion.completion - def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs): + def gen_stream(self, model, messages, max_tokens=300, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Context \n {context} \n ### Question \n {user_question}" diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py index e0c5dba..d540a91 100644 --- a/application/llm/docsgpt_provider.py +++ b/application/llm/docsgpt_provider.py @@ -8,7 +8,7 @@ class DocsGPTAPILLM(BaseLLM): self.endpoint = "https://llm.docsgpt.co.uk" - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -24,7 +24,7 @@ class DocsGPTAPILLM(BaseLLM): return response_clean - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py index ef3b1fb..554bee2 100644 --- a/application/llm/huggingface.py +++ b/application/llm/huggingface.py @@ -29,7 +29,7 @@ class HuggingFaceLLM(BaseLLM): ) hf = HuggingFacePipeline(pipeline=pipe) - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -38,7 +38,7 @@ class HuggingFaceLLM(BaseLLM): return result.content - def gen_stream(self, model, 
engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.") diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index f18d437..be34d4f 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -12,7 +12,7 @@ class LlamaCpp(BaseLLM): llama = Llama(model_path=llm_name, n_ctx=2048) - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -24,7 +24,7 @@ class LlamaCpp(BaseLLM): return result['choices'][0]['text'].split('### Answer \n')[-1] - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/llm/openai.py b/application/llm/openai.py index a132399..4b0ed25 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -18,7 +18,7 @@ class OpenAILLM(BaseLLM): return openai - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): response = self.client.chat.completions.create(model=model, messages=messages, stream=stream, @@ -26,7 +26,7 @@ class OpenAILLM(BaseLLM): return response.choices[0].message.content - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): response = self.client.chat.completions.create(model=model, messages=messages, stream=stream, diff --git a/application/llm/premai.py b/application/llm/premai.py index 4bc8a89..5faa5fe 100644 --- a/application/llm/premai.py +++ b/application/llm/premai.py @@ -12,7 +12,7 @@ class PremAILLM(BaseLLM): self.api_key = api_key self.project_id = settings.PREMAI_PROJECT_ID - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): response = self.client.chat.completions.create(model=model, project_id=self.project_id, messages=messages, @@ -21,7 +21,7 @@ class PremAILLM(BaseLLM): return response.choices[0].message["content"] - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): response = self.client.chat.completions.create(model=model, project_id=self.project_id, messages=messages, diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 84ae09a..b81f638 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -74,7 +74,7 @@ class SagemakerAPILLM(BaseLLM): self.runtime = runtime - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -103,7 +103,7 @@ class SagemakerAPILLM(BaseLLM): print(result[0]['generated_text'], file=sys.stderr) return result[0]['generated_text'][len(prompt):] - def gen_stream(self, model, engine, messages, 
stream=True, **kwargs):
+    def gen_stream(self, model, messages, stream=True, **kwargs):
         context = messages[0]['content']
         user_question = messages[-1]['content']
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
diff --git a/application/parser/token_func.py b/application/parser/token_func.py
index 36ae7e5..7511cde 100644
--- a/application/parser/token_func.py
+++ b/application/parser/token_func.py
@@ -22,7 +22,10 @@ def group_documents(documents: List[Document], min_tokens: int, max_tokens: int)
         doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))

         # Check if current group is empty or if the document can be added based on token count and matching metadata
-        if current_group is None or (len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and doc_len < min_tokens and current_group.extra_info == doc.extra_info):
+        if (current_group is None or
+                (len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and
+                 doc_len < min_tokens and
+                 current_group.extra_info == doc.extra_info)):
             if current_group is None:
                 current_group = doc  # Use the document directly to retain its metadata
             else:
diff --git a/application/requirements.txt b/application/requirements.txt
index 0874a7c..dbbdbc8 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -3,6 +3,7 @@ boto3==1.34.6
 celery==5.3.6
 dataclasses_json==0.6.3
 docx2txt==0.8
+duckduckgo-search==5.3.0
 EbookLib==0.18
 elasticsearch==8.12.0
 escodegen==1.0.11
@@ -21,7 +22,7 @@ pydantic_settings==2.1.0
 pymongo==4.6.1
 PyPDF2==3.0.1
 python-dotenv==1.0.1
-qdrant-client==1.7.3
+qdrant-client==1.8.2
 redis==5.0.1
 Requests==2.31.0
 retry==0.9.2
diff --git a/application/retriever/__init__.py b/application/retriever/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/application/retriever/base.py b/application/retriever/base.py
new file mode 100644
index 0000000..4a37e81
--- /dev/null
+++ b/application/retriever/base.py
@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+
+
+class BaseRetriever(ABC):
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def gen(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def search(self, *args, **kwargs):
+        pass
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
new file mode 100644
index 0000000..0cc7bd4
--- /dev/null
+++ b/application/retriever/brave_search.py
@@ -0,0 +1,75 @@
+import json
+from application.retriever.base import BaseRetriever
+from application.core.settings import settings
+from application.llm.llm_creator import LLMCreator
+from application.utils import count_tokens
+from langchain_community.tools import BraveSearch
+
+
+
+class BraveRetSearch(BaseRetriever):
+
+    def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
+        self.question = question
+        self.source = source
+        self.chat_history = chat_history
+        self.prompt = prompt
+        self.chunks = chunks
+        self.gpt_model = gpt_model
+
+    def _get_data(self):
+        if self.chunks == 0:
+            docs = []
+        else:
+            search = BraveSearch.from_api_key(api_key=settings.BRAVE_SEARCH_API_KEY,
+                                              search_kwargs={"count": int(self.chunks)})
+            results = search.run(self.question)
+            results = json.loads(results)
+
+            docs = []
+            for i in results:
+                try:
+                    title = i['title']
+                    link = i['link']
+                    snippet = i['snippet']
+                    docs.append({"text": snippet, "title": title, "link": link})
+                except KeyError:  # results are dicts, so a missing field raises KeyError
+                    pass
+            if settings.LLM_NAME == "llama.cpp":
+                docs = 
[docs[0]] + + return docs + + def gen(self): + docs = self._get_data() + + # join all page_content together with a newline + docs_together = "\n".join([doc["text"] for doc in docs]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + yield {"source": doc} + + if len(self.chat_history) > 1: + tokens_current_history = 0 + # count tokens in history + self.chat_history.reverse() + for i in self.chat_history: + if "prompt" in i and "response" in i: + tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) + if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_current_history += tokens_batch + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append({"role": "user", "content": self.question}) + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + completion = llm.gen_stream(model=self.gpt_model, + messages=messages_combine) + for line in completion: + yield {"answer": str(line)} + + def search(self): + return self._get_data() + diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py new file mode 100644 index 0000000..b5f1eb9 --- /dev/null +++ b/application/retriever/classic_rag.py @@ -0,0 +1,91 @@ +import os +from application.retriever.base import BaseRetriever +from application.core.settings import settings +from application.vectorstore.vector_creator import VectorCreator +from application.llm.llm_creator import LLMCreator + +from application.utils import count_tokens + + + +class ClassicRAG(BaseRetriever): + + def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + self.question = question + self.vectorstore = self._get_vectorstore(source=source) + self.chat_history = chat_history + self.prompt = prompt + self.chunks = chunks + self.gpt_model = gpt_model + + def _get_vectorstore(self, source): + if "active_docs" in source: + if source["active_docs"].split("/")[0] == "default": + vectorstore = "" + elif source["active_docs"].split("/")[0] == "local": + vectorstore = "indexes/" + source["active_docs"] + else: + vectorstore = "vectors/" + source["active_docs"] + if source["active_docs"] == "default": + vectorstore = "" + else: + vectorstore = "" + vectorstore = os.path.join("application", vectorstore) + return vectorstore + + + def _get_data(self): + if self.chunks == 0: + docs = [] + else: + docsearch = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, + self.vectorstore, + settings.EMBEDDINGS_KEY + ) + docs_temp = docsearch.search(self.question, k=self.chunks) + docs = [ + { + "title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, + "text": i.page_content + } + for i in docs_temp + ] + if settings.LLM_NAME == "llama.cpp": + docs = [docs[0]] + + return docs + + def gen(self): + docs = self._get_data() + + # join all page_content together with a newline + docs_together = "\n".join([doc["text"] for doc in docs]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + yield {"source": doc} + + if len(self.chat_history) > 1: + tokens_current_history = 0 + # count tokens in history + self.chat_history.reverse() + for i in self.chat_history: + if "prompt" in i and "response" in i: + tokens_batch = 
count_tokens(i["prompt"]) + count_tokens(i["response"]) + if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_current_history += tokens_batch + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append({"role": "user", "content": self.question}) + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + completion = llm.gen_stream(model=self.gpt_model, + messages=messages_combine) + for line in completion: + yield {"answer": str(line)} + + def search(self): + return self._get_data() + diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py new file mode 100644 index 0000000..d662bb0 --- /dev/null +++ b/application/retriever/duckduck_search.py @@ -0,0 +1,94 @@ +from application.retriever.base import BaseRetriever +from application.core.settings import settings +from application.llm.llm_creator import LLMCreator +from application.utils import count_tokens +from langchain_community.tools import DuckDuckGoSearchResults +from langchain_community.utilities import DuckDuckGoSearchAPIWrapper + + + +class DuckDuckSearch(BaseRetriever): + + def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + self.question = question + self.source = source + self.chat_history = chat_history + self.prompt = prompt + self.chunks = chunks + self.gpt_model = gpt_model + + def _parse_lang_string(self, input_string): + result = [] + current_item = "" + inside_brackets = False + for char in input_string: + if char == "[": + inside_brackets = True + elif char == "]": + inside_brackets = False + result.append(current_item) + current_item = "" + elif inside_brackets: + current_item += char + + if inside_brackets: + result.append(current_item) + + return result + + def _get_data(self): + if self.chunks == 0: + docs = [] + else: + wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks) + search = DuckDuckGoSearchResults(api_wrapper=wrapper) + results = search.run(self.question) + results = self._parse_lang_string(results) + + docs = [] + for i in results: + try: + text = i.split("title:")[0] + title = i.split("title:")[1].split("link:")[0] + link = i.split("link:")[1] + docs.append({"text": text, "title": title, "link": link}) + except IndexError: + pass + if settings.LLM_NAME == "llama.cpp": + docs = [docs[0]] + + return docs + + def gen(self): + docs = self._get_data() + + # join all page_content together with a newline + docs_together = "\n".join([doc["text"] for doc in docs]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + yield {"source": doc} + + if len(self.chat_history) > 1: + tokens_current_history = 0 + # count tokens in history + self.chat_history.reverse() + for i in self.chat_history: + if "prompt" in i and "response" in i: + tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) + if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_current_history += tokens_batch + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append({"role": "user", "content": self.question}) + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + completion = llm.gen_stream(model=self.gpt_model, + 
messages=messages_combine)
+        for line in completion:
+            yield {"answer": str(line)}
+
+    def search(self):
+        return self._get_data()
+
diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py
new file mode 100644
index 0000000..ad07140
--- /dev/null
+++ b/application/retriever/retriever_creator.py
@@ -0,0 +1,19 @@
+from application.retriever.classic_rag import ClassicRAG
+from application.retriever.duckduck_search import DuckDuckSearch
+from application.retriever.brave_search import BraveRetSearch
+
+
+
+class RetrieverCreator:
+    retrievers = {
+        'classic': ClassicRAG,
+        'duckduck_search': DuckDuckSearch,
+        'brave_search': BraveRetSearch
+    }
+
+    @classmethod
+    def create_retriever(cls, type, *args, **kwargs):
+        retriever_class = cls.retrievers.get(type.lower())
+        if not retriever_class:
+            raise ValueError(f"No retriever class found for type {type}")
+        return retriever_class(*args, **kwargs)
\ No newline at end of file
diff --git a/application/utils.py b/application/utils.py
new file mode 100644
index 0000000..ac98efc
--- /dev/null
+++ b/application/utils.py
@@ -0,0 +1,6 @@
+from transformers import GPT2TokenizerFast
+
+
+def count_tokens(string):
+    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
+    return len(tokenizer(string)['input_ids'])
\ No newline at end of file
diff --git a/application/worker.py b/application/worker.py
index 3891fde..eb28242 100644
--- a/application/worker.py
+++ b/application/worker.py
@@ -36,6 +36,32 @@ current_dir = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 )

+def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
+    """
+    Recursively extract zip files with a limit on recursion depth.
+
+    Args:
+        zip_path (str): Path to the zip file to be extracted.
+        extract_to (str): Destination path for extracted files.
+        current_depth (int): Current depth of recursion.
+        max_depth (int): Maximum allowed depth of recursion to prevent infinite loops.
+    """
+    if current_depth > max_depth:
+        print(f"Reached maximum recursion depth of {max_depth}")
+        return
+
+    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+        zip_ref.extractall(extract_to)
+    os.remove(zip_path)  # Remove the zip file after extracting
+
+    # Check for nested zip files and extract them
+    for root, dirs, files in os.walk(extract_to):
+        for file in files:
+            if file.endswith(".zip"):
+                # If a nested zip file is found, extract it recursively
+                file_path = os.path.join(root, file)
+                extract_zip_recursive(file_path, root, current_depth + 1, max_depth)
+
 # Define the main function for ingesting and processing documents.
def ingest_worker(self, directory, formats, name_job, filename, user):
@@ -66,9 +92,11 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
     token_check = True
     min_tokens = 150
     max_tokens = 1250
-    full_path = directory + "/" + user + "/" + name_job
+    recursion_depth = 2
+    full_path = os.path.join(directory, user, name_job)
     import sys
+    print(full_path, file=sys.stderr)
     # check if API_URL env variable is set
     file_data = {"name": name_job, "file": filename, "user": user}
@@ -81,14 +109,12 @@
     if not os.path.exists(full_path):
         os.makedirs(full_path)
-    with open(full_path + "/" + filename, "wb") as f:
+    with open(os.path.join(full_path, filename), "wb") as f:
         f.write(file)

     # check if file is .zip and extract it
     if filename.endswith(".zip"):
-        with zipfile.ZipFile(full_path + "/" + filename, "r") as zip_ref:
-            zip_ref.extractall(full_path)
-        os.remove(full_path + "/" + filename)
+        extract_zip_recursive(os.path.join(full_path, filename), full_path, 0, recursion_depth)

     self.update_state(state="PROGRESS", meta={"current": 1})

diff --git a/application/wsgi.py b/application/wsgi.py
index 5160e11..d0a7db0 100644
--- a/application/wsgi.py
+++ b/application/wsgi.py
@@ -1,4 +1,5 @@
 from application.app import app
+from application.core.settings import settings

 if __name__ == "__main__":
-    app.run(debug=True, port=7091)
+    app.run(debug=settings.FLASK_DEBUG_MODE, port=7091)
diff --git a/docs/pages/Developing/API-docs.md b/docs/pages/Developing/API-docs.md
index c5bd897..a85ed6f 100644
--- a/docs/pages/Developing/API-docs.md
+++ b/docs/pages/Developing/API-docs.md
@@ -281,6 +281,8 @@ Create a new API key for the user.
 **Request Body**: JSON object with the following fields:
 * `name` — A name for the API key.
 * `source` — The source documents that will be used.
+* `prompt_id` — The prompt ID.
+* `chunks` — The number of chunks used to process an answer.

 Here is a JavaScript Fetch Request example:
 ```js
@@ -290,7 +292,10 @@
 fetch("http://127.0.0.1:5000/api/create_api_key", {
   "method": "POST",
   "headers": {
     "Content-Type": "application/json; charset=utf-8"
   },
-  "body": JSON.stringify({"name":"Example Key Name","source":"Example Source"})
+  "body": JSON.stringify({"name":"Example Key Name",
+                          "source":"Example Source",
+                          "prompt_id":"creative",
+                          "chunks":"2"})
 })
 .then((res) => res.json())
 .then(console.log.bind(console))
diff --git a/docs/pages/Extensions/_meta.json b/docs/pages/Extensions/_meta.json
index 05a4991..8ce9b78 100644
--- a/docs/pages/Extensions/_meta.json
+++ b/docs/pages/Extensions/_meta.json
@@ -4,7 +4,11 @@
     "href": "/Extensions/Chatwoot-extension"
   },
   "react-widget": {
-    "title": "🏗️ Widget setup",
-    "href": "/Extensions/react-widget"
-  }
+    "title": "🏗️ Widget setup",
+    "href": "/Extensions/react-widget"
+  },
+  "api-key-guide": {
+    "title": "🔐 API Keys guide",
+    "href": "/Extensions/api-key-guide"
+  }
 }
\ No newline at end of file
diff --git a/docs/pages/Extensions/api-key-guide.md b/docs/pages/Extensions/api-key-guide.md
new file mode 100644
index 0000000..6ba30a0
--- /dev/null
+++ b/docs/pages/Extensions/api-key-guide.md
@@ -0,0 +1,30 @@
+## Guide to DocsGPT API Keys
+
+DocsGPT API keys are essential for developers and users who wish to integrate the DocsGPT models into external applications, such as our widget. This guide will walk you through the steps of obtaining an API key, starting from uploading your document to understanding the key variables associated with API keys.
+ +### Uploading Your Document + +Before creating your first API key, you must upload the document that will be linked to this key. You can upload your document through two methods: + +- **GUI Web App Upload:** A user-friendly graphical interface that allows for easy upload and management of documents. +- **Using `/api/upload` Method:** For users comfortable with API calls, this method provides a direct way to upload documents. + +### Obtaining Your API Key + +After uploading your document, you can obtain an API key either through the graphical user interface or via an API call: + +- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the API Keys option, and press 'Create New' to generate your key. +- **API Call:** Alternatively, you can use the `/api/create_api_key` endpoint to create a new API key. For detailed instructions, visit [DocsGPT API Documentation](https://docs.docsgpt.co.uk/Developing/API-docs#8-apicreate_api_key). + +### Understanding Key Variables + +Upon creating your API key, you will encounter several key variables. Each serves a specific purpose: + +- **Name:** Assign a name to your API key for easy identification. +- **Source:** Indicates the source document(s) linked to your API key, which DocsGPT will use to generate responses. +- **ID:** A unique identifier for your API key. You can view this by making a call to `/api/get_api_keys`. +- **Key:** The API key itself, which will be used in your application to authenticate API requests. + +With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the API key, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation. + +Congratulations on taking the first step towards enhancing your applications with DocsGPT! With this guide, you're now equipped to navigate the process of obtaining and understanding DocsGPT API keys. 
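+
+### Example: Querying with an API Key
+
+As a rough sketch of what an integration call can look like, here is a hypothetical request to the `/api/answer` endpoint. The host and port are assumptions for a locally running backend, and the field values are illustrative; adjust both to your deployment:
+
+```js
+// Hypothetical example; substitute your own host and API key.
+fetch("http://localhost:7091/api/answer", {
+  method: "POST",
+  headers: { "Content-Type": "application/json; charset=utf-8" },
+  body: JSON.stringify({
+    question: "What is DocsGPT?",
+    history: [],             // no prior chat turns in this example
+    conversation_id: null,   // lets the backend start a new conversation
+    api_key: "your-api-key", // the source document is preset by the key
+  }),
+})
+  .then((res) => res.json())
+  .then((data) => console.log(data.answer, data.sources));
+```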
\ No newline at end of file diff --git a/docs/pages/_app.mdx b/docs/pages/_app.mdx index 992283c..39c0157 100644 --- a/docs/pages/_app.mdx +++ b/docs/pages/_app.mdx @@ -4,7 +4,7 @@ export default function MyApp({ Component, pageProps }) { return ( <> - + ) } \ No newline at end of file diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 8b8f023..98aedce 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -44,7 +44,7 @@ "prettier-plugin-tailwindcss": "^0.2.2", "tailwindcss": "^3.2.4", "typescript": "^4.9.5", - "vite": "^5.0.12", + "vite": "^5.0.13", "vite-plugin-svgr": "^4.2.0" } }, @@ -7855,9 +7855,9 @@ } }, "node_modules/vite": { - "version": "5.0.12", - "resolved": "https://registry.npmjs.org/vite/-/vite-5.0.12.tgz", - "integrity": "sha512-4hsnEkG3q0N4Tzf1+t6NdN9dg/L3BM+q8SWgbSPnJvrgH2kgdyzfVJwbR1ic69/4uMJJ/3dqDZZE5/WwqW8U1w==", + "version": "5.0.13", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.0.13.tgz", + "integrity": "sha512-/9ovhv2M2dGTuA+dY93B9trfyWMDRQw2jdVBhHNP6wr0oF34wG2i/N55801iZIpgUpnHDm4F/FabGQLyc+eOgg==", "dev": true, "dependencies": { "esbuild": "^0.19.3", diff --git a/frontend/package.json b/frontend/package.json index 0c7dead..6e8d8e6 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -55,7 +55,7 @@ "prettier-plugin-tailwindcss": "^0.2.2", "tailwindcss": "^3.2.4", "typescript": "^4.9.5", - "vite": "^5.0.12", + "vite": "^5.0.13", "vite-plugin-svgr": "^4.2.0" } } diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 6a1fd00..987f38e 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -6,7 +6,7 @@ import PageNotFound from './PageNotFound'; import { inject } from '@vercel/analytics'; import { useMediaQuery } from './hooks'; import { useState } from 'react'; -import Setting from './Setting'; +import Setting from './settings'; inject(); diff --git a/frontend/src/Setting.tsx b/frontend/src/Setting.tsx deleted file mode 100644 index 172691a..0000000 --- a/frontend/src/Setting.tsx +++ /dev/null @@ -1,1002 +0,0 @@ -import React, { useState, useEffect } from 'react'; -import { useSelector, useDispatch } from 'react-redux'; -import ArrowLeft from './assets/arrow-left.svg'; -import ArrowRight from './assets/arrow-right.svg'; -import Exit from './assets/exit.svg'; -import Trash from './assets/trash.svg'; -import { - selectPrompt, - setPrompt, - selectSourceDocs, - setSourceDocs, - setChunks, - selectChunks, -} from './preferences/preferenceSlice'; -import { Doc } from './preferences/preferenceApi'; -import { useDarkTheme } from './hooks'; -import Dropdown from './components/Dropdown'; -const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com'; - -const embeddingsName = - import.meta.env.VITE_EMBEDDINGS_NAME || - 'huggingface_sentence-transformers/all-mpnet-base-v2'; -const Setting: React.FC = () => { - const tabs = ['General', 'Prompts', 'Documents', 'API Keys']; - //const tabs = ['General', 'Prompts', 'Documents', 'Widgets']; - - const [activeTab, setActiveTab] = useState('General'); - const [prompts, setPrompts] = useState< - { name: string; id: string; type: string }[] - >([]); - const selectedPrompt = useSelector(selectPrompt); - const [isAddPromptModalOpen, setAddPromptModalOpen] = useState(false); - const documents = useSelector(selectSourceDocs); - const [isAddDocumentModalOpen, setAddDocumentModalOpen] = useState(false); - - const dispatch = useDispatch(); - - const [widgetScreenshot, setWidgetScreenshot] = useState(null); - - const updateWidgetScreenshot = 
(screenshot: File | null) => { - setWidgetScreenshot(screenshot); - }; - - useEffect(() => { - const fetchPrompts = async () => { - try { - const response = await fetch(`${apiHost}/api/get_prompts`); - if (!response.ok) { - throw new Error('Failed to fetch prompts'); - } - const promptsData = await response.json(); - setPrompts(promptsData); - } catch (error) { - console.error(error); - } - }; - fetchPrompts(); - }, []); - - const onDeletePrompt = (name: string, id: string) => { - setPrompts(prompts.filter((prompt) => prompt.id !== id)); - - fetch(`${apiHost}/api/delete_prompt`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - // send id in body only - body: JSON.stringify({ id: id }), - }) - .then((response) => { - if (!response.ok) { - throw new Error('Failed to delete prompt'); - } - }) - .catch((error) => { - console.error(error); - }); - }; - - const handleDeleteClick = (index: number, doc: Doc) => { - const docPath = 'indexes/' + 'local' + '/' + doc.name; - - fetch(`${apiHost}/api/delete_old?path=${docPath}`, { - method: 'GET', - }) - .then((response) => { - if (response.ok && documents) { - const updatedDocuments = [ - ...documents.slice(0, index), - ...documents.slice(index + 1), - ]; - dispatch(setSourceDocs(updatedDocuments)); - } - }) - .catch((error) => console.error(error)); - }; - - return ( -
-      [JSX markup lost in extraction; the deleted block rendered the Settings page
-       shell: a "Settings" heading, tab-scroll arrow buttons, the tab buttons via
-       {tabs.map((tab, index) => ...)}, the active panel via {renderActiveTab()},
-       and a commented-out Widgets panel]
- ); - - function scrollTabs(direction: number) { - const container = document.querySelector('.flex-nowrap'); - if (container) { - container.scrollLeft += direction * 100; // Adjust the scroll amount as needed - } - } - - function renderActiveTab() { - switch (activeTab) { - case 'General': - return ; - case 'Prompts': - return ( - - dispatch(setPrompt({ name: name, id: id, type: type })) - } - setPrompts={setPrompts} - /> - ); - case 'Documents': - return ( - - ); - case 'Widgets': - return ( - - ); - case 'API Keys': - return ; - default: - return null; - } - } -}; - -const General: React.FC = () => { - const themes = ['Light', 'Dark']; - const languages = ['English']; - const chunks = ['0', '2', '4', '6', '8', '10']; - const selectedChunks = useSelector(selectChunks); - const [isDarkTheme, toggleTheme] = useDarkTheme(); - const [selectedTheme, setSelectedTheme] = useState( - isDarkTheme ? 'Dark' : 'Light', - ); - const dispatch = useDispatch(); - const [selectedLanguage, setSelectedLanguage] = useState(languages[0]); - return ( -
-      [JSX markup lost in extraction; the deleted block rendered the General tab: a
-       "Select Theme" dropdown whose onSelect toggles the theme, a "Select Language"
-       dropdown, and a "Chunks processed per query" dropdown wired to
-       dispatch(setChunks(value))]
- ); -}; - -export default Setting; -type PromptProps = { - prompts: { name: string; id: string; type: string }[]; - selectedPrompt: { name: string; id: string; type: string }; - onSelectPrompt: (name: string, id: string, type: string) => void; - setPrompts: (prompts: { name: string; id: string; type: string }[]) => void; -}; - -const Prompts: React.FC = ({ - prompts, - selectedPrompt, - onSelectPrompt, - setPrompts, -}) => { - const handleSelectPrompt = ({ - name, - id, - type, - }: { - name: string; - id: string; - type: string; - }) => { - setNewPromptName(name); - onSelectPrompt(name, id, type); - }; - const [newPromptName, setNewPromptName] = useState(selectedPrompt.name); - const [newPromptContent, setNewPromptContent] = useState(''); - - const handleAddPrompt = async () => { - try { - const response = await fetch(`${apiHost}/api/create_prompt`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - name: newPromptName, - content: newPromptContent, - }), - }); - if (!response.ok) { - throw new Error('Failed to add prompt'); - } - const newPrompt = await response.json(); - if (setPrompts) { - setPrompts([ - ...prompts, - { name: newPromptName, id: newPrompt.id, type: 'private' }, - ]); - } - onSelectPrompt(newPromptName, newPrompt.id, newPromptContent); - setNewPromptName(newPromptName); - } catch (error) { - console.error(error); - } - }; - - const handleDeletePrompt = () => { - setPrompts(prompts.filter((prompt) => prompt.id !== selectedPrompt.id)); - console.log('selectedPrompt.id', selectedPrompt.id); - - fetch(`${apiHost}/api/delete_prompt`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ id: selectedPrompt.id }), - }) - .then((response) => { - if (!response.ok) { - throw new Error('Failed to delete prompt'); - } - // get 1st prompt and set it as selected - if (prompts.length > 0) { - onSelectPrompt(prompts[0].name, prompts[0].id, prompts[0].type); - setNewPromptName(prompts[0].name); - } - }) - .catch((error) => { - console.error(error); - }); - }; - - useEffect(() => { - const fetchPromptContent = async () => { - console.log('fetching prompt content'); - try { - const response = await fetch( - `${apiHost}/api/get_single_prompt?id=${selectedPrompt.id}`, - { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - }, - }, - ); - if (!response.ok) { - throw new Error('Failed to fetch prompt content'); - } - const promptContent = await response.json(); - setNewPromptContent(promptContent.content); - } catch (error) { - console.error(error); - } - }; - - fetchPromptContent(); - }, [selectedPrompt]); - - const handleSaveChanges = () => { - fetch(`${apiHost}/api/update_prompt`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - id: selectedPrompt.id, - name: newPromptName, - content: newPromptContent, - }), - }) - .then((response) => { - if (!response.ok) { - throw new Error('Failed to update prompt'); - } - onSelectPrompt(newPromptName, selectedPrompt.id, selectedPrompt.type); - setNewPromptName(newPromptName); - }) - .catch((error) => { - console.error(error); - }); - }; - - return ( -
-      [JSX markup lost in extraction; the deleted block rendered the Prompts tab: an
-       "Active Prompt" dropdown, a "Prompt name" input ("start by editing name") bound
-       to newPromptName, and a "Prompt content" editor]

[extraction gap: the remainder of the deleted Setting.tsx and the diff header of the
new prompts-modal component file are missing here]

+      [JSX markup lost in extraction; closing markup of the preceding modal view,
+       apparently the AddPrompt form and its action buttons]
+ + ); +} + +function EditPrompt({ + setModalState, + handleEditPrompt, + editPromptName, + setEditPromptName, + editPromptContent, + setEditPromptContent, + currentPromptEdit, +}: { + setModalState: (state: ActiveState) => void; + handleEditPrompt?: (id: string, type: string) => void; + editPromptName: string; + setEditPromptName: (name: string) => void; + editPromptContent: string; + setEditPromptContent: (content: string) => void; + currentPromptEdit: { name: string; id: string; type: string }; +}) { + return ( +
+      [JSX markup lost in extraction; the added block renders the "Edit Prompt" view:
+       an "Edit Prompt" heading, an "Edit your custom prompt and save it to DocsGPT"
+       subtitle, a "Prompt Name" input bound to editPromptName, a "Prompt Text"
+       textarea bound to editPromptContent, and a save button presumably wired to
+       handleEditPrompt(currentPromptEdit.id, currentPromptEdit.type)]
+ ); +} + +export default function PromptsModal({ + modalState, + setModalState, + type, + newPromptName, + setNewPromptName, + newPromptContent, + setNewPromptContent, + editPromptName, + setEditPromptName, + editPromptContent, + setEditPromptContent, + currentPromptEdit, + handleAddPrompt, + handleEditPrompt, +}: { + modalState: ActiveState; + setModalState: (state: ActiveState) => void; + type: 'ADD' | 'EDIT'; + newPromptName: string; + setNewPromptName: (name: string) => void; + newPromptContent: string; + setNewPromptContent: (content: string) => void; + editPromptName: string; + setEditPromptName: (name: string) => void; + editPromptContent: string; + setEditPromptContent: (content: string) => void; + currentPromptEdit: { name: string; id: string; type: string }; + handleAddPrompt?: () => void; + handleEditPrompt?: (id: string, type: string) => void; +}) { + let view; + + if (type === 'ADD') { + view = ( + + ); + } else if (type === 'EDIT') { + view = ( + + ); + } else { + view = <>; + } + return ( +
+    [… modal wrapper JSX elided; it renders {view} …]
+  );
+}
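For orientation, here is a hypothetical caller of the new PromptsModal, trimmed to the 'ADD' path. It is a sketch written against the prop types above, not code from this PR; the AddPromptButton name, the import paths, and the inert edit-side values are illustrative.

// Hypothetical usage sketch, not part of this PR.
import React from 'react';
import PromptsModal from './preferences/PromptsModal'; // path assumed
import { ActiveState } from './models/misc'; // values inferred from usage: 'ACTIVE' | 'INACTIVE'

const AddPromptButton: React.FC = () => {
  const [modalState, setModalState] = React.useState<ActiveState>('INACTIVE');
  const [name, setName] = React.useState('');
  const [content, setContent] = React.useState('');

  return (
    <>
      <button onClick={() => setModalState('ACTIVE')}>New prompt</button>
      <PromptsModal
        modalState={modalState}
        setModalState={setModalState}
        type="ADD"
        newPromptName={name}
        setNewPromptName={setName}
        newPromptContent={content}
        setNewPromptContent={setContent}
        // The EDIT-side props are required by the signature, so pass inert values.
        editPromptName=""
        setEditPromptName={() => undefined}
        editPromptContent=""
        setEditPromptContent={() => undefined}
        currentPromptEdit={{ id: '', name: '', type: '' }}
        handleAddPrompt={() => console.log('create', name, content)}
      />
    </>
  );
};

export default AddPromptButton;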
diff --git a/frontend/src/settings/APIKeys.tsx b/frontend/src/settings/APIKeys.tsx
new file mode 100644
index 0000000..7803ff4
--- /dev/null
+++ b/frontend/src/settings/APIKeys.tsx
@@ -0,0 +1,336 @@
+import React from 'react';
+import { useSelector } from 'react-redux';
+import Dropdown from '../components/Dropdown';
+import {
+  Doc,
+  CreateAPIKeyModalProps,
+  SaveAPIKeyModalProps,
+} from '../models/misc';
+import { selectSourceDocs } from '../preferences/preferenceSlice';
+import Exit from '../assets/exit.svg';
+import Trash from '../assets/trash.svg';
+
+const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
+const embeddingsName =
+  import.meta.env.VITE_EMBEDDINGS_NAME ||
+  'huggingface_sentence-transformers/all-mpnet-base-v2';
+
+const APIKeys: React.FC = () => {
+  const [isCreateModalOpen, setCreateModal] = React.useState(false);
+  const [isSaveKeyModalOpen, setSaveKeyModal] = React.useState(false);
+  const [newKey, setNewKey] = React.useState('');
+  const [apiKeys, setApiKeys] = React.useState<
+    { name: string; key: string; source: string; id: string }[]
+  >([]);
+  const handleDeleteKey = (id: string) => {
+    fetch(`${apiHost}/api/delete_api_key`, {
+      method: 'POST',
+      headers: {
+        'Content-Type': 'application/json',
+      },
+      body: JSON.stringify({ id }),
+    })
+      .then((response) => {
+        if (!response.ok) {
+          throw new Error('Failed to delete API Key');
+        }
+        return response.json();
+      })
+      .then((data) => {
+        data.status === 'ok' &&
+          setApiKeys((previous) => previous.filter((elem) => elem.id !== id));
+      })
+      .catch((error) => {
+        console.error(error);
+      });
+  };
+  React.useEffect(() => {
+    fetchAPIKeys();
+  }, []);
+  const fetchAPIKeys = async () => {
+    try {
+      const response = await fetch(`${apiHost}/api/get_api_keys`);
+      if (!response.ok) {
+        throw new Error('Failed to fetch API Keys');
+      }
+      const apiKeys = await response.json();
+      setApiKeys(apiKeys);
+    } catch (error) {
+      console.log(error);
+    }
+  };
+  const createAPIKey = (payload: {
+    name: string;
+    source: string;
+    prompt_id: string;
+    chunks: string;
+  }) => {
+    fetch(`${apiHost}/api/create_api_key`, {
+      method: 'POST',
+      headers: {
+        'Content-Type': 'application/json',
+      },
+      body: JSON.stringify(payload),
+    })
+      .then((response) => {
+        if (!response.ok) {
+          throw new Error('Failed to create API Key');
+        }
+        return response.json();
+      })
+      .then((data) => {
+        setApiKeys([...apiKeys, data]);
+        setCreateModal(false);
+        setNewKey(data.key);
+        setSaveKeyModal(true);
+        fetchAPIKeys();
+      })
+      .catch((error) => {
+        console.error(error);
+      });
+  };
+  return (
+    [… layout JSX elided: a control that opens the create-key modal …]
+      {isCreateModalOpen && (
+        <CreateAPIKeyModal
+          close={() => setCreateModal(false)}
+          createAPIKey={createAPIKey}
+        />
+      )}
+      {isSaveKeyModalOpen && (
+        <SaveAPIKeyModal apiKey={newKey} close={() => setSaveKeyModal(false)} />
+      )}
+      [… table JSX elided: header row "Name" / "Source document" / "API Key";
+          {apiKeys?.map((element, index) => …)} rows showing element.name,
+          element.source, element.key, plus a Trash/Delete icon that calls
+          handleDeleteKey(element.id) …]
+  );
+};
+
+const CreateAPIKeyModal: React.FC<CreateAPIKeyModalProps> = ({
+  close,
+  createAPIKey,
+}) => {
+  const [APIKeyName, setAPIKeyName] = React.useState('');
+  const [sourcePath, setSourcePath] = React.useState<{
+    label: string;
+    value: string;
+  } | null>(null);
+
+  const chunkOptions = ['0', '2', '4', '6', '8', '10'];
+  const [chunk, setChunk] = React.useState('2');
+  const [activePrompts, setActivePrompts] = React.useState<
+    { name: string; id: string; type: string }[]
+  >([]);
+  const [prompt, setPrompt] = React.useState<{
+    name: string;
+    id: string;
+    type: string;
+  } | null>(null);
+  const docs = useSelector(selectSourceDocs);
+  React.useEffect(() => {
+    const fetchPrompts = async () => {
+      try {
+        const response = await fetch(`${apiHost}/api/get_prompts`);
+        if (!response.ok) {
+          throw new Error('Failed to fetch prompts');
+        }
+        const promptsData = await response.json();
+        setActivePrompts(promptsData);
+      } catch (error) {
+        console.error(error);
+      }
+    };
+    fetchPrompts();
+  }, []);
+  const extractDocPaths = () =>
+    docs
+      ? docs
+          .filter((doc) => doc.model === embeddingsName)
+          .map((doc: Doc) => {
+            let namePath = doc.name;
+            if (doc.language === namePath) {
+              namePath = '.project';
+            }
+            let docPath = 'default';
+            if (doc.location === 'local') {
+              docPath = 'local' + '/' + doc.name + '/';
+            } else if (doc.location === 'remote') {
+              docPath =
+                doc.language +
+                '/' +
+                namePath +
+                '/' +
+                doc.version +
+                '/' +
+                doc.model +
+                '/';
+            }
+            return {
+              label: doc.name,
+              value: docPath,
+            };
+          })
+      : [];
+
+  return (
+    [… CreateAPIKeyModal JSX elided: "Create New API Key" heading; "API Key
+        Name" input wired to setAPIKeyName; a source-document Dropdown with
+        options={extractDocPaths()} wired to setSourcePath (size="w-full",
+        rounded="xl"); a prompt Dropdown wired to setPrompt (size="w-full");
+        a "Chunks processed per query" Dropdown over chunkOptions wired to
+        setChunk; and a create control that presumably calls createAPIKey
+        with the collected values …]
+  );
+};
+
+const SaveAPIKeyModal: React.FC<SaveAPIKeyModalProps> = ({ apiKey, close }) => {
+  const [isCopied, setIsCopied] = React.useState(false);
+  const handleCopyKey = () => {
+    navigator.clipboard.writeText(apiKey);
+    setIsCopied(true);
+  };
+  return (
+    [… SaveAPIKeyModal JSX elided: "Please save your Key" heading; "This is
+        the only time your key will be shown." note; an "API Key" field
+        rendering {apiKey}; a copy control wired to handleCopyKey (isCopied
+        presumably flips its label); and a close control wired to close …]
+  );
+};
+
+export default APIKeys;
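Outside React, the endpoints APIKeys.tsx relies on compose into a simple lifecycle. The sketch below is inferred from the component code above; the ApiKey type name and helper names are mine, and the response shapes are assumptions rather than a documented contract. Note that source follows extractDocPaths(): 'default', 'local/<name>/', or '<language>/<name>/<version>/<model>/'.

// Sketch of the API-key lifecycle inferred from APIKeys.tsx; names are illustrative.
type ApiKey = { name: string; key: string; source: string; id: string };

const apiHost = 'https://docsapi.arc53.com';
const jsonHeaders = { 'Content-Type': 'application/json' };

async function createKey(): Promise<ApiKey> {
  const res = await fetch(`${apiHost}/api/create_api_key`, {
    method: 'POST',
    headers: jsonHeaders,
    body: JSON.stringify({
      name: 'my-widget-key',
      source: 'local/my-docs/', // assumed path, shaped like extractDocPaths() output
      prompt_id: 'default',
      chunks: '2', // the UI sends chunks as a string
    }),
  });
  if (!res.ok) throw new Error('Failed to create API Key');
  return res.json(); // the UI reads data.key from this response
}

async function listKeys(): Promise<ApiKey[]> {
  const res = await fetch(`${apiHost}/api/get_api_keys`);
  if (!res.ok) throw new Error('Failed to fetch API Keys');
  return res.json();
}

async function revokeKey(id: string): Promise<boolean> {
  const res = await fetch(`${apiHost}/api/delete_api_key`, {
    method: 'POST',
    headers: jsonHeaders,
    body: JSON.stringify({ id }),
  });
  const data = await res.json();
  return data.status === 'ok'; // the UI filters the key out only on 'ok'
}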
diff --git a/frontend/src/settings/Documents.tsx b/frontend/src/settings/Documents.tsx
new file mode 100644
index 0000000..bccefd3
--- /dev/null
+++ b/frontend/src/settings/Documents.tsx
@@ -0,0 +1,60 @@
+import { DocumentsProps } from '../models/misc';
+import Trash from '../assets/trash.svg';
+
+const Documents: React.FC<DocumentsProps> = ({
+  documents,
+  handleDeleteDocument,
+}) => {
+  return (
+    [… table JSX elided: header row "Document Name" / "Vector Date" / "Type";
+        {documents && documents.map((document, index) => …)} rows showing
+        document.name, document.date, and (document.location === 'remote'
+        ? 'Pre-loaded' : 'Private'); when document.location !== 'remote', a
+        Trash icon whose onClick stops propagation and calls
+        handleDeleteDocument(index, document) …]
+  );
+};
+
+export default Documents;
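A minimal, hypothetical render of Documents, with the document objects trimmed to the fields the table actually reads; the real Doc shape in models/misc carries more (language, version, model), so this may need widening to typecheck against DocumentsProps.

// Hypothetical usage sketch; the sample data is made up.
import Documents from './settings/Documents'; // path assumed

const sampleDocs = [
  { name: 'my-notes', date: '2024-01-12', location: 'local' }, // renders 'Private', deletable
  { name: 'react docs', date: '2024-01-01', location: 'remote' }, // renders 'Pre-loaded'
];

export const DocumentsPage = () => (
  <Documents
    documents={sampleDocs}
    handleDeleteDocument={(index, doc) => console.log('delete', index, doc.name)}
  />
);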
diff --git a/frontend/src/settings/General.tsx b/frontend/src/settings/General.tsx
new file mode 100644
index 0000000..215628e
--- /dev/null
+++ b/frontend/src/settings/General.tsx
@@ -0,0 +1,100 @@
+import React from 'react';
+import { useSelector, useDispatch } from 'react-redux';
+import Prompts from './Prompts';
+import { useDarkTheme } from '../hooks';
+import Dropdown from '../components/Dropdown';
+import {
+  selectPrompt,
+  setPrompt,
+  setChunks,
+  selectChunks,
+} from '../preferences/preferenceSlice';
+
+const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
+
+const General: React.FC = () => {
+  const themes = ['Light', 'Dark'];
+  const languages = ['English'];
+  const chunks = ['0', '2', '4', '6', '8', '10'];
+  const [prompts, setPrompts] = React.useState<
+    { name: string; id: string; type: string }[]
+  >([]);
+  const selectedChunks = useSelector(selectChunks);
+  const [isDarkTheme, toggleTheme] = useDarkTheme();
+  const [selectedTheme, setSelectedTheme] = React.useState(
+    isDarkTheme ? 'Dark' : 'Light',
+  );
+  const dispatch = useDispatch();
+  const [selectedLanguage, setSelectedLanguage] = React.useState(languages[0]);
+  const selectedPrompt = useSelector(selectPrompt);
+
+  React.useEffect(() => {
+    const fetchPrompts = async () => {
+      try {
+        const response = await fetch(`${apiHost}/api/get_prompts`);
+        if (!response.ok) {
+          throw new Error('Failed to fetch prompts');
+        }
+        const promptsData = await response.json();
+        setPrompts(promptsData);
+      } catch (error) {
+        console.error(error);
+      }
+    };
+    fetchPrompts();
+  }, []);
+  return (
+    [… settings JSX elided: "Select Theme" Dropdown (size="w-56",
+        rounded="3xl") whose onSelect runs setSelectedTheme(option) and
+        toggles the theme when the option changes; "Select Language" Dropdown
+        over languages; "Chunks processed per query" Dropdown whose onSelect
+        dispatches setChunks(value); and the <Prompts /> panel, whose
+        onSelectPrompt dispatches setPrompt({ name: name, id: id, type: type })
+        and which also receives setPrompts={setPrompts} and apiHost={apiHost} …]
+  );
+};
+
+export default General;
diff --git a/frontend/src/settings/Prompts.tsx b/frontend/src/settings/Prompts.tsx
new file mode 100644
index 0000000..6956df6
--- /dev/null
+++ b/frontend/src/settings/Prompts.tsx
@@ -0,0 +1,219 @@
+import React from 'react';
+import { PromptProps, ActiveState } from '../models/misc';
+import Dropdown from '../components/Dropdown';
+import PromptsModal from '../preferences/PromptsModal';
+
+const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
+
+const Prompts: React.FC<PromptProps> = ({
+  prompts,
+  selectedPrompt,
+  onSelectPrompt,
+  setPrompts,
+}) => {
+  const handleSelectPrompt = ({
+    name,
+    id,
+    type,
+  }: {
+    name: string;
+    id: string;
+    type: string;
+  }) => {
+    setEditPromptName(name);
+    onSelectPrompt(name, id, type);
+  };
+  const [newPromptName, setNewPromptName] = React.useState('');
+  const [newPromptContent, setNewPromptContent] = React.useState('');
+  const [editPromptName, setEditPromptName] = React.useState('');
+  const [editPromptContent, setEditPromptContent] = React.useState('');
+  const [currentPromptEdit, setCurrentPromptEdit] = React.useState({
+    id: '',
+    name: '',
+    type: '',
+  });
+  const [modalType, setModalType] = React.useState<'ADD' | 'EDIT'>('ADD');
+  const [modalState, setModalState] = React.useState<ActiveState>('INACTIVE');
+
+  const handleAddPrompt = async () => {
+    try {
+      const response = await fetch(`${apiHost}/api/create_prompt`, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+        },
+        body: JSON.stringify({
+          name: newPromptName,
+          content: newPromptContent,
+        }),
+      });
+      if (!response.ok) {
+        throw new Error('Failed to add prompt');
+      }
+      const newPrompt = await response.json();
+      if (setPrompts) {
+        setPrompts([
+          ...prompts,
+          { name: newPromptName, id: newPrompt.id, type: 'private' },
+        ]);
+      }
+      setModalState('INACTIVE');
+      onSelectPrompt(newPromptName, newPrompt.id, newPromptContent);
+      setNewPromptName(newPromptName);
+    } catch (error) {
+      console.error(error);
+    }
+  };
+
+  const handleDeletePrompt = (id: string) => {
+    setPrompts(prompts.filter((prompt) => prompt.id !== id));
+    fetch(`${apiHost}/api/delete_prompt`, {
+      method: 'POST',
+      headers: {
+        'Content-Type': 'application/json',
+      },
+      body: JSON.stringify({ id: id }),
+    })
+      .then((response) => {
+        if (!response.ok) {
+          throw new Error('Failed to delete prompt');
+        }
+        // get 1st prompt and set it as selected
+        if (prompts.length > 0) {
+          onSelectPrompt(prompts[0].name, prompts[0].id, prompts[0].type);
+        }
+      })
+      .catch((error) => {
+        console.error(error);
+      });
+  };
+
+  const fetchPromptContent = async (id: string) => {
+    console.log('fetching prompt content');
+    try {
+      const response = await fetch(
+        `${apiHost}/api/get_single_prompt?id=${id}`,
+        {
+          method: 'GET',
+          headers: {
+            'Content-Type': 'application/json',
+          },
+        },
+      );
+      if (!response.ok) {
+        throw new Error('Failed to fetch prompt content');
+      }
+      const promptContent = await response.json();
+      setEditPromptContent(promptContent.content);
+    } catch (error) {
+      console.error(error);
+    }
+  };
+
+  const handleSaveChanges = (id: string, type: string) => {
+    fetch(`${apiHost}/api/update_prompt`, {
+      method: 'POST',
+      headers: {
+        'Content-Type': 'application/json',
+      },
+      body: JSON.stringify({
+        id: id,
+        name: editPromptName,
+        content: editPromptContent,
+      }),
+    })
+      .then((response) => {
+        if (!response.ok) {
+          throw new Error('Failed to update prompt');
+        }
+        if (setPrompts) {
+          const existingPromptIndex = prompts.findIndex(
+            (prompt) => prompt.id === id,
+          );
+          if (existingPromptIndex === -1) {
+            setPrompts([
+              ...prompts,
+              { name: editPromptName, id: id, type: type },
+            ]);
+          } else {
+            const updatedPrompts = [...prompts];
+            updatedPrompts[existingPromptIndex] = {
+              name: editPromptName,
+              id: id,
+              type: type,
+            };
+            setPrompts(updatedPrompts);
+          }
+        }
+        setModalState('INACTIVE');
+        onSelectPrompt(editPromptName, id, type);
+      })
+      .catch((error) => {
+        console.error(error);
+      });
+  };
+
+  return (
+    <>
+      [… panel JSX elided: "Active Prompt" label; a Dropdown over prompts
+          wired to handleSelectPrompt, with onDelete={handleDeletePrompt} and
+          an onEdit handler that runs setModalType('EDIT'),
+          setEditPromptName(name), fetchPromptContent(id),
+          setCurrentPromptEdit({ id: id, name: name, type: type }), and
+          setModalState('ACTIVE'); an add-prompt control; and a
+          <PromptsModal /> bound to the state and handlers above …]
+    </>
+  );
+};
+
+export default Prompts;
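Taken together, Prompts.tsx exercises four prompt endpoints. Below is a condensed round-trip inferred from the fetches above; the roundTrip helper is mine, and the response shapes are assumptions drawn from how the component reads them.

// Sketch of the prompt CRUD flow as Prompts.tsx uses it; illustrative only.
const apiHost = 'https://docsapi.arc53.com';
const jsonHeaders = { 'Content-Type': 'application/json' };

async function roundTrip(): Promise<void> {
  // create_prompt returns at least { id } (the component reads newPrompt.id)
  const { id } = await (
    await fetch(`${apiHost}/api/create_prompt`, {
      method: 'POST',
      headers: jsonHeaders,
      body: JSON.stringify({ name: 'terse', content: 'Answer briefly.' }),
    })
  ).json();

  // get_single_prompt returns { content } (the component reads promptContent.content)
  const { content } = await (
    await fetch(`${apiHost}/api/get_single_prompt?id=${id}`)
  ).json();

  await fetch(`${apiHost}/api/update_prompt`, {
    method: 'POST',
    headers: jsonHeaders,
    body: JSON.stringify({ id, name: 'terse', content: content + ' Cite sources.' }),
  });

  await fetch(`${apiHost}/api/delete_prompt`, {
    method: 'POST',
    headers: jsonHeaders,
    body: JSON.stringify({ id }),
  });
}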
diff --git a/frontend/src/settings/Widgets.tsx b/frontend/src/settings/Widgets.tsx
new file mode 100644
index 0000000..27066dc
--- /dev/null
+++ b/frontend/src/settings/Widgets.tsx
@@ -0,0 +1,112 @@
+import React from 'react';
+import Dropdown from '../components/Dropdown';
+
+const Widgets: React.FC<{
+  widgetScreenshot: File | null;
+  onWidgetScreenshotChange: (screenshot: File | null) => void;
+}> = ({ widgetScreenshot, onWidgetScreenshotChange }) => {
+  const widgetSources = ['Source 1', 'Source 2', 'Source 3'];
+  const widgetMethods = ['Method 1', 'Method 2', 'Method 3'];
+  const widgetTypes = ['Type 1', 'Type 2', 'Type 3'];
+
+  const [selectedWidgetSource, setSelectedWidgetSource] = React.useState(
+    widgetSources[0],
+  );
+  const [selectedWidgetMethod, setSelectedWidgetMethod] = React.useState(
+    widgetMethods[0],
+  );
+  const [selectedWidgetType, setSelectedWidgetType] = React.useState(
+    widgetTypes[0],
+  );
+
+  // const [widgetScreenshot, setWidgetScreenshot] = useState(null);
+  const [widgetCode, setWidgetCode] = React.useState(''); // Your widget code state
+
+  const handleScreenshotChange = (
+    event: React.ChangeEvent<HTMLInputElement>,
+  ) => {
+    const files = event.target.files;
+
+    if (files && files.length > 0) {
+      const selectedScreenshot = files[0];
+      onWidgetScreenshotChange(selectedScreenshot); // Update the screenshot in the parent component
+    }
+  };
+
+  const handleCopyToClipboard = () => {
+    // Create a new textarea element to select the text
+    const textArea = document.createElement('textarea');
+    textArea.value = widgetCode;
+    document.body.appendChild(textArea);
+
+    // Select and copy the text
+    textArea.select();
+    document.execCommand('copy');
+
+    // Clean up the textarea element
+    document.body.removeChild(textArea);
+  };
+
+  return (
+    [… form JSX elided: "Widget Source", "Widget Method", and "Widget Type"
+        Dropdowns bound to the state above, then a "Widget Code Snippet"
+        section …]
[… remainder of the diff truncated …]
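One observation on this last file: handleCopyToClipboard relies on the deprecated document.execCommand('copy') via a temporary textarea, while SaveAPIKeyModal above already calls navigator.clipboard.writeText. A sketch of a fallback-aware copy helper for the widget code, assuming a secure (https) context for the async Clipboard API; the function name is illustrative.

// Sketch: async Clipboard API first, execCommand fallback; illustrative only.
async function copyWidgetCode(widgetCode: string): Promise<boolean> {
  if (navigator.clipboard && window.isSecureContext) {
    await navigator.clipboard.writeText(widgetCode);
    return true;
  }
  // Legacy path, mirroring the textarea trick used in Widgets.tsx.
  const textArea = document.createElement('textarea');
  textArea.value = widgetCode;
  document.body.appendChild(textArea);
  textArea.select();
  const ok = document.execCommand('copy');
  document.body.removeChild(textArea);
  return ok;
}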