Merge branch 'arc53:main' into main

pull/925/head
Manish Madan committed 1 month ago via GitHub
commit c30c6d9f10

@ -0,0 +1,14 @@
# Security Policy
## Supported Versions
Currently, we deliver security patches by committing fixes and bumping the version published on GitHub.
## Reporting a Vulnerability
Found a vulnerability? Please email us:
security@arc53.com

@ -8,13 +8,12 @@ import traceback
from pymongo import MongoClient
from bson.objectid import ObjectId
from transformers import GPT2TokenizerFast
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.error import bad_request
@ -62,9 +61,6 @@ async def async_generate(chain, question, chat_history):
return result
def count_tokens(string):
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
return len(tokenizer(string)['input_ids'])
def run_async_chain(chain, question, chat_history):
@ -104,61 +100,11 @@ def get_vectorstore(data):
def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id, chunks=2):
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
if prompt_id == 'default':
prompt = chat_combine_template
elif prompt_id == 'creative':
prompt = chat_combine_creative
elif prompt_id == 'strict':
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
# join all page_content together with a newline
docs_together = "\n".join([doc.page_content for doc in docs])
p_chat_combine = prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
source_log_docs = []
for doc in docs:
if doc.metadata:
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
if len(chat_history) > 1:
tokens_current_history = 0
# count tokens in history
chat_history.reverse()
for i in chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": question})
response_full = ""
completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_combine)
for line in completion:
data = json.dumps({"answer": str(line)})
response_full += str(line)
yield f"data: {data}\n\n"
# save conversation to database
if conversation_id is not None:
def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}},
{"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}},
)
else:
@ -168,19 +114,50 @@ def complete_stream(question, docsearch, chat_history, prompt_id, conversation_i
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: " + question + "\n\n" +
"AI: " +
response_full},
response},
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system"}]
completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
completion = llm.gen(model=gpt_model,
messages=messages_summary, max_tokens=30)
conversation_id = conversations_collection.insert_one(
{"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]}
"queries": [{"prompt": question, "response": response, "sources": source_log_docs}]}
).inserted_id
return conversation_id
def get_prompt(prompt_id):
if prompt_id == 'default':
prompt = chat_combine_template
elif prompt_id == 'creative':
prompt = chat_combine_creative
elif prompt_id == 'strict':
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt
def complete_stream(question, retriever, conversation_id):
response_full = ""
source_log_docs = []
answer = retriever.gen()
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps(line)
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
# send the conversation id back to the client as json (the stream then ends)
data = json.dumps({"type": "id", "id": str(conversation_id)})
@ -207,29 +184,37 @@ def stream():
prompt_id = data["prompt_id"]
else:
prompt_id = 'default'
if 'chunks' in data:
if 'selectedDocs' in data and data['selectedDocs'] is None:
chunks = 0
elif 'chunks' in data:
chunks = int(data["chunks"])
else:
chunks = 2
prompt = get_prompt(prompt_id)
# check if active_docs is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
source = {"active_docs": data_key["source"]}
elif "active_docs" in data:
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
source = {"active_docs": data["active_docs"]}
else:
vectorstore = ""
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
source = {}
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
retriever_name = source['active_docs']
retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
)
return Response(
complete_stream(question, docsearch,
chat_history=history,
prompt_id=prompt_id,
conversation_id=conversation_id,
chunks=chunks), mimetype="text/event-stream"
)
complete_stream(question=question, retriever=retriever,
conversation_id=conversation_id), mimetype="text/event-stream")
@answer.route("/api/answer", methods=["POST"])
@ -253,110 +238,40 @@ def api_answer():
chunks = int(data["chunks"])
else:
chunks = 2
if prompt_id == 'default':
prompt = chat_combine_template
elif prompt_id == 'creative':
prompt = chat_combine_creative
elif prompt_id == 'strict':
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
prompt = get_prompt(prompt_id)
# answer the question, reporting any failure with a full traceback
try:
# check if the vectorstore is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
source = {"active_docs": data_key["source"]}
else:
vectorstore = get_vectorstore(data)
# loading the index and the store and the prompt template
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
source = {"active_docs": data["active_docs"]}
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
retriever_name = source['active_docs']
retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
)
source_log_docs = []
response_full = ""
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
# join all page_content together with a newline
docs_together = "\n".join([doc.page_content for doc in docs])
p_chat_combine = prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
source_log_docs = []
for doc in docs:
if doc.metadata:
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
# join all page_content together with a newline
if len(history) > 1:
tokens_current_history = 0
# count tokens in history
history.reverse()
for i in history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": question})
completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_combine)
result = {"answer": completion, "sources": source_log_docs}
logger.debug(result)
# generate conversationId
if conversation_id is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{"$push": {"queries": {"prompt": question,
"response": result["answer"], "sources": result['sources']}}},
)
else:
# create new conversation
# generate summary
messages_summary = [
{"role": "assistant", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the system \n\n"
"User: " + question + "\n\n" + "AI: " + result["answer"]},
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the system"}
]
completion = llm.gen(
model=gpt_model,
engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_summary,
max_tokens=30
)
conversation_id = conversations_collection.insert_one(
{"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]}
).inserted_id
result["conversation_id"] = str(conversation_id)
# mock result
# result = {
# "answer": "The answer is 42",
# "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"]
# }
return result
except Exception as e:
# print whole traceback
@ -373,27 +288,24 @@ def api_search():
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = data_key["source"]
source = {"active_docs": data_key["source"]}
elif "active_docs" in data:
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
source = {"active_docs": data["active_docs"]}
else:
vectorstore = ""
source = {}
if 'chunks' in data:
chunks = int(data["chunks"])
else:
chunks = 2
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
if chunks == 0:
docs = []
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
docs = docsearch.search(question, k=chunks)
retriever_name = source['active_docs']
source_log_docs = []
for doc in docs:
if doc.metadata:
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
#yield f"data:{data}\n\n"
return source_log_docs
retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
)
docs = retriever.search()
return docs

@ -1,6 +1,8 @@
import os
import uuid
import shutil
from flask import Blueprint, request, jsonify
from urllib.parse import urlparse
import requests
from pymongo import MongoClient
from bson.objectid import ObjectId
@ -135,30 +137,43 @@ def upload_file():
return {"status": "no name"}
job_name = secure_filename(request.form["name"])
# check if the post request has the file part
if "file" not in request.files:
print("No file part")
return {"status": "no file"}
file = request.files["file"]
if file.filename == "":
files = request.files.getlist("file")
if not files or all(file.filename == '' for file in files):
return {"status": "no file name"}
if file:
filename = secure_filename(file.filename)
# save dir
save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name)
# create dir if not exists
if not os.path.exists(save_dir):
os.makedirs(save_dir)
file.save(os.path.join(save_dir, filename))
task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt", ".docx",
".csv", ".epub", ".html", ".mdx"],
job_name, filename, user)
# task id
task_id = task.id
return {"status": "ok", "task_id": task_id}
# Directory where files will be saved
save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name)
os.makedirs(save_dir, exist_ok=True)
if len(files) > 1:
# Multiple files; prepare them for zip
temp_dir = os.path.join(save_dir, "temp")
os.makedirs(temp_dir, exist_ok=True)
for file in files:
filename = secure_filename(file.filename)
file.save(os.path.join(temp_dir, filename))
# Use shutil.make_archive to zip the temp directory
zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format='zip', root_dir=temp_dir)
final_filename = os.path.basename(zip_path)
# Clean up the temporary directory after zipping
shutil.rmtree(temp_dir)
else:
return {"status": "error"}
# Single file
file = files[0]
final_filename = secure_filename(file.filename)
file_path = os.path.join(save_dir, final_filename)
file.save(file_path)
# Call ingest with the single file or zipped file
task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt", ".docx",
".csv", ".epub", ".html", ".mdx"],
job_name, final_filename, user)
return {"status": "ok", "task_id": task.id}
@user.route("/api/remote", methods=["POST"])
def upload_remote():
@ -236,6 +251,34 @@ def combined_json():
for index in data_remote:
index["location"] = "remote"
data.append(index)
if 'duckduck_search' in settings.RETRIEVERS_ENABLED:
data.append(
{
"name": "DuckDuckGo Search",
"language": "en",
"version": "",
"description": "duckduck_search",
"fullName": "DuckDuckGo Search",
"date": "duckduck_search",
"docLink": "duckduck_search",
"model": settings.EMBEDDINGS_NAME,
"location": "custom",
}
)
if 'brave_search' in settings.RETRIEVERS_ENABLED:
data.append(
{
"name": "Brave Search",
"language": "en",
"version": "",
"description": "brave_search",
"fullName": "Brave Search",
"date": "brave_search",
"docLink": "brave_search",
"model": settings.EMBEDDINGS_NAME,
"location": "custom",
}
)
return jsonify(data)
@ -247,25 +290,32 @@ def check_docs():
# split docs on / and take first part
if data["docs"].split("/")[0] == "local":
return {"status": "exists"}
vectorstore = "vectors/" + data["docs"]
vectorstore = "vectors/" + secure_filename(data["docs"])
base_path = "https://raw.githubusercontent.com/arc53/DocsHUB/main/"
if os.path.exists(vectorstore) or data["docs"] == "default":
return {"status": "exists"}
else:
r = requests.get(base_path + vectorstore + "index.faiss")
if r.status_code != 200:
return {"status": "null"}
file_url = urlparse(base_path + vectorstore + "index.faiss")
if (
file_url.scheme in ['https'] and
file_url.netloc == 'raw.githubusercontent.com' and
file_url.path.startswith('/arc53/DocsHUB/main/')
):
r = requests.get(file_url.geturl())
if r.status_code != 200:
return {"status": "null"}
else:
if not os.path.exists(vectorstore):
os.makedirs(vectorstore)
with open(vectorstore + "index.faiss", "wb") as f:
f.write(r.content)
r = requests.get(base_path + vectorstore + "index.pkl")
with open(vectorstore + "index.pkl", "wb") as f:
f.write(r.content)
else:
if not os.path.exists(vectorstore):
os.makedirs(vectorstore)
with open(vectorstore + "index.faiss", "wb") as f:
f.write(r.content)
# download the store
r = requests.get(base_path + vectorstore + "index.pkl")
with open(vectorstore + "index.pkl", "wb") as f:
f.write(r.content)
return {"status": "null"}
return {"status": "loaded"}
@ -351,7 +401,14 @@ def get_api_keys():
keys = api_key_collection.find({"user": user})
list_keys = []
for key in keys:
list_keys.append({"id": str(key["_id"]), "name": key["name"], "key": key["key"][:4] + "..." + key["key"][-4:], "source": key["source"]})
list_keys.append({
"id": str(key["_id"]),
"name": key["name"],
"key": key["key"][:4] + "..." + key["key"][-4:],
"source": key["source"],
"prompt_id": key["prompt_id"],
"chunks": key["chunks"]
})
return jsonify(list_keys)
@user.route("/api/create_api_key", methods=["POST"])
@ -359,6 +416,8 @@ def create_api_key():
data = request.get_json()
name = data["name"]
source = data["source"]
prompt_id = data["prompt_id"]
chunks = data["chunks"]
key = str(uuid.uuid4())
user = "local"
resp = api_key_collection.insert_one(
@ -367,6 +426,8 @@ def create_api_key():
"key": key,
"source": source,
"user": user,
"prompt_id": prompt_id,
"chunks": chunks
}
)
new_id = str(resp.inserted_id)

@ -40,5 +40,5 @@ def after_request(response):
return response
if __name__ == "__main__":
app.run(debug=True, port=7091)
app.run(debug=settings.FLASK_DEBUG_MODE, port=7091)

@ -9,7 +9,7 @@ current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__
class Settings(BaseSettings):
LLM_NAME: str = "docsgpt"
MODEL_NAME: Optional[str] = None # when LLM_NAME is openai, MODEL_NAME can be e.g. gpt-4-turbo-preview or gpt-3.5-turbo
MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
@ -18,6 +18,7 @@ class Settings(BaseSettings):
TOKENS_MAX_HISTORY: int = 150
UPLOAD_FOLDER: str = "inputs"
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant"
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
API_URL: str = "http://localhost:7091" # backend url for celery worker
@ -59,6 +60,10 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
BRAVE_SEARCH_API_KEY: Optional[str] = None
FLASK_DEBUG_MODE: bool = False
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

@ -10,7 +10,7 @@ class AnthropicLLM(BaseLLM):
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs):
def gen(self, model, messages, max_tokens=300, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
@ -25,7 +25,7 @@ class AnthropicLLM(BaseLLM):
)
return completion.completion
def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs):
def gen_stream(self, model, messages, max_tokens=300, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Context \n {context} \n ### Question \n {user_question}"

@ -8,7 +8,7 @@ class DocsGPTAPILLM(BaseLLM):
self.endpoint = "https://llm.docsgpt.co.uk"
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@ -24,7 +24,7 @@ class DocsGPTAPILLM(BaseLLM):
return response_clean
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

@ -29,7 +29,7 @@ class HuggingFaceLLM(BaseLLM):
)
hf = HuggingFacePipeline(pipeline=pipe)
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@ -38,7 +38,7 @@ class HuggingFaceLLM(BaseLLM):
return result.content
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")

@ -12,7 +12,7 @@ class LlamaCpp(BaseLLM):
llama = Llama(model_path=llm_name, n_ctx=2048)
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@ -24,7 +24,7 @@ class LlamaCpp(BaseLLM):
return result['choices'][0]['text'].split('### Answer \n')[-1]
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

@ -18,7 +18,7 @@ class OpenAILLM(BaseLLM):
return openai
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
response = self.client.chat.completions.create(model=model,
messages=messages,
stream=stream,
@ -26,7 +26,7 @@ class OpenAILLM(BaseLLM):
return response.choices[0].message.content
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
response = self.client.chat.completions.create(model=model,
messages=messages,
stream=stream,

@ -12,7 +12,7 @@ class PremAILLM(BaseLLM):
self.api_key = api_key
self.project_id = settings.PREMAI_PROJECT_ID
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
response = self.client.chat.completions.create(model=model,
project_id=self.project_id,
messages=messages,
@ -21,7 +21,7 @@ class PremAILLM(BaseLLM):
return response.choices[0].message["content"]
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
response = self.client.chat.completions.create(model=model,
project_id=self.project_id,
messages=messages,

@ -74,7 +74,7 @@ class SagemakerAPILLM(BaseLLM):
self.runtime = runtime
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@ -103,7 +103,7 @@ class SagemakerAPILLM(BaseLLM):
print(result[0]['generated_text'], file=sys.stderr)
return result[0]['generated_text'][len(prompt):]
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

@ -22,7 +22,10 @@ def group_documents(documents: List[Document], min_tokens: int, max_tokens: int)
doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
# Check if current group is empty or if the document can be added based on token count and matching metadata
if current_group is None or (len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and doc_len < min_tokens and current_group.extra_info == doc.extra_info):
if (current_group is None or
(len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and
doc_len < min_tokens and
current_group.extra_info == doc.extra_info)):
if current_group is None:
current_group = doc # Use the document directly to retain its metadata
else:

@ -3,6 +3,7 @@ boto3==1.34.6
celery==5.3.6
dataclasses_json==0.6.3
docx2txt==0.8
duckduckgo-search==5.3.0
EbookLib==0.18
elasticsearch==8.12.0
escodegen==1.0.11
@ -21,7 +22,7 @@ pydantic_settings==2.1.0
pymongo==4.6.1
PyPDF2==3.0.1
python-dotenv==1.0.1
qdrant-client==1.7.3
qdrant-client==1.8.2
redis==5.0.1
Requests==2.31.0
retry==0.9.2

@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
class BaseRetriever(ABC):
def __init__(self):
pass
@abstractmethod
def gen(self, *args, **kwargs):
pass
@abstractmethod
def search(self, *args, **kwargs):
pass

@ -0,0 +1,75 @@
import json
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens
from langchain_community.tools import BraveSearch
class BraveRetSearch(BaseRetriever):
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
self.question = question
self.source = source
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
def _get_data(self):
if self.chunks == 0:
docs = []
else:
search = BraveSearch.from_api_key(api_key=settings.BRAVE_SEARCH_API_KEY,
search_kwargs={"count": int(self.chunks)})
results = search.run(self.question)
results = json.loads(results)
docs = []
for i in results:
try:
title = i['title']
link = i['link']
snippet = i['snippet']
docs.append({"text": snippet, "title": title, "link": link})
except KeyError:
pass
if settings.LLM_NAME == "llama.cpp":
docs = docs[:1]  # avoid IndexError when the search returns nothing
return docs
def gen(self):
docs = self._get_data()
# join all retrieved snippets together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
completion = llm.gen_stream(model=self.gpt_model,
messages=messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()

@ -0,0 +1,91 @@
import os
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens
class ClassicRAG(BaseRetriever):
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
self.question = question
self.vectorstore = self._get_vectorstore(source=source)
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
def _get_vectorstore(self, source):
if "active_docs" in source:
if source["active_docs"].split("/")[0] == "default":
vectorstore = ""
elif source["active_docs"].split("/")[0] == "local":
vectorstore = "indexes/" + source["active_docs"]
else:
vectorstore = "vectors/" + source["active_docs"]
if source["active_docs"] == "default":
vectorstore = ""
else:
vectorstore = ""
vectorstore = os.path.join("application", vectorstore)
return vectorstore
def _get_data(self):
if self.chunks == 0:
docs = []
else:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
self.vectorstore,
settings.EMBEDDINGS_KEY
)
docs_temp = docsearch.search(self.question, k=self.chunks)
docs = [
{
"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content,
"text": i.page_content
}
for i in docs_temp
]
if settings.LLM_NAME == "llama.cpp":
docs = docs[:1]  # avoid IndexError when the search returns nothing
return docs
def gen(self):
docs = self._get_data()
# join all retrieved snippets together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
completion = llm.gen_stream(model=self.gpt_model,
messages=messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()

@ -0,0 +1,94 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
class DuckDuckSearch(BaseRetriever):
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
self.question = question
self.source = source
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
def _parse_lang_string(self, input_string):
result = []
current_item = ""
inside_brackets = False
for char in input_string:
if char == "[":
inside_brackets = True
elif char == "]":
inside_brackets = False
result.append(current_item)
current_item = ""
elif inside_brackets:
current_item += char
if inside_brackets:
result.append(current_item)
return result
def _get_data(self):
if self.chunks == 0:
docs = []
else:
wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
search = DuckDuckGoSearchResults(api_wrapper=wrapper)
results = search.run(self.question)
results = self._parse_lang_string(results)
docs = []
for i in results:
try:
text = i.split("title:")[0]
title = i.split("title:")[1].split("link:")[0]
link = i.split("link:")[1]
docs.append({"text": text, "title": title, "link": link})
except IndexError:
pass
if settings.LLM_NAME == "llama.cpp":
docs = docs[:1]  # avoid IndexError when the search returns nothing
return docs
def gen(self):
docs = self._get_data()
# join all retrieved snippets together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
completion = llm.gen_stream(model=self.gpt_model,
messages=messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()

@ -0,0 +1,19 @@
from application.retriever.classic_rag import ClassicRAG
from application.retriever.duckduck_search import DuckDuckSearch
from application.retriever.brave_search import BraveRetSearch
class RetrieverCreator:
retrievers = {
'classic': ClassicRAG,
'duckduck_search': DuckDuckSearch,
'brave_search': BraveRetSearch
}
@classmethod
def create_retriever(cls, type, *args, **kwargs):
retriever_class = cls.retrievers.get(type.lower())
if not retriever_class:
raise ValueError(f"No retriever class found for type {type}")
return retriever_class(*args, **kwargs)
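For orientation, here is a minimal usage sketch of the factory, mirroring how the `stream()` route wires it up; the argument values are illustrative, not prescriptive:

```python
from application.retriever.retriever_creator import RetrieverCreator

# Illustrative values; stream() derives these from the incoming request.
retriever = RetrieverCreator.create_retriever(
    "classic",                                # or "duckduck_search" / "brave_search"
    question="What is DocsGPT?",
    source={"active_docs": "local/my-docs/"},
    chat_history=[],
    prompt="... {summaries} ...",             # resolved via get_prompt(prompt_id)
    chunks=2,
    gpt_model="docsgpt",
)
for line in retriever.gen():
    if "answer" in line:
        print(line["answer"], end="")
    elif "source" in line:
        pass  # collect source metadata as needed
```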

@ -0,0 +1,6 @@
from transformers import GPT2TokenizerFast
def count_tokens(string):
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
return len(tokenizer(string)['input_ids'])

@ -36,6 +36,32 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
"""
Recursively extract zip files with a limit on recursion depth.
Args:
zip_path (str): Path to the zip file to be extracted.
extract_to (str): Destination path for extracted files.
current_depth (int): Current depth of recursion.
max_depth (int): Maximum allowed depth of recursion to prevent infinite loops.
"""
if current_depth > max_depth:
print(f"Reached maximum recursion depth of {max_depth}")
return
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to)
os.remove(zip_path) # Remove the zip file after extracting
# Check for nested zip files and extract them
for root, dirs, files in os.walk(extract_to):
for file in files:
if file.endswith(".zip"):
# If a nested zip file is found, extract it recursively
file_path = os.path.join(root, file)
extract_zip_recursive(file_path, root, current_depth + 1, max_depth)
# Define the main function for ingesting and processing documents.
def ingest_worker(self, directory, formats, name_job, filename, user):
@ -66,9 +92,11 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
token_check = True
min_tokens = 150
max_tokens = 1250
full_path = directory + "/" + user + "/" + name_job
recursion_depth = 2
full_path = os.path.join(directory, user, name_job)
import sys
print(full_path, file=sys.stderr)
# check if API_URL env variable is set
file_data = {"name": name_job, "file": filename, "user": user}
@ -81,14 +109,12 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
if not os.path.exists(full_path):
os.makedirs(full_path)
with open(full_path + "/" + filename, "wb") as f:
with open(os.path.join(full_path, filename), "wb") as f:
f.write(file)
# check if file is .zip and extract it
if filename.endswith(".zip"):
with zipfile.ZipFile(full_path + "/" + filename, "r") as zip_ref:
zip_ref.extractall(full_path)
os.remove(full_path + "/" + filename)
extract_zip_recursive(os.path.join(full_path, filename), full_path, 0, recursion_depth)
self.update_state(state="PROGRESS", meta={"current": 1})

@ -1,4 +1,5 @@
from application.app import app
from application.core.settings import settings
if __name__ == "__main__":
app.run(debug=True, port=7091)
app.run(debug=settings.FLASK_DEBUG_MODE, port=7091)

@ -281,6 +281,8 @@ Create a new API key for the user.
**Request Body**: JSON object with the following fields:
* `name` — A name for the API key.
* `source` — The source documents that will be used.
* `prompt_id` — The prompt ID.
* `chunks` — The number of chunks used to process an answer.
Here is a JavaScript Fetch Request example:
```js
@ -290,7 +292,10 @@ fetch("http://127.0.0.1:5000/api/create_api_key", {
"headers": {
"Content-Type": "application/json; charset=utf-8"
},
"body": JSON.stringify({"name":"Example Key Name","source":"Example Source"})
"body": JSON.stringify({"name":"Example Key Name",
"source":"Example Source",
"prompt_id":"creative",
"chunks":"2"})
})
.then((res) => res.json())
.then(console.log.bind(console))

@ -4,7 +4,11 @@
"href": "/Extensions/Chatwoot-extension"
},
"react-widget": {
"title": "🏗️ Widget setup",
"href": "/Extensions/react-widget"
}
"title": "🏗️ Widget setup",
"href": "/Extensions/react-widget"
},
"api-key-guide": {
"title": "🔐 API Keys guide",
"href": "/Extensions/api-key-guide"
}
}

@ -0,0 +1,30 @@
## Guide to DocsGPT API Keys
DocsGPT API keys let developers and users integrate the DocsGPT models into external applications, such as our widget. This guide walks you through obtaining an API key, from uploading your document to understanding the key's associated variables.
### Uploading Your Document
Before creating your first API key, you must upload the document that will be linked to this key. You can upload your document through two methods:
- **GUI Web App Upload:** A user-friendly graphical interface that allows for easy upload and management of documents.
- **Using the `/api/upload` endpoint:** For users comfortable with API calls, this method provides a direct way to upload documents (see the sketch after this list).
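A minimal sketch of calling `/api/upload` with Python's `requests`, assuming a local backend; `name` is the form field the route requires, the `user` field is an assumption, and repeating the `file` part uploads multiple documents, which the server zips before ingestion:

```python
import requests

# Hypothetical host and field values; "user" is an assumed form field.
files = [
    ("file", open("guide.pdf", "rb")),
    ("file", open("notes.md", "rb")),  # more than one file: the server zips them
]
resp = requests.post(
    "http://localhost:7091/api/upload",
    data={"name": "my-docs", "user": "local"},
    files=files,
)
print(resp.json())  # e.g. {"status": "ok", "task_id": "..."}
```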
### Obtaining Your API Key
After uploading your document, you can obtain an API key either through the graphical user interface or via an API call:
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the API Keys option, and press 'Create New' to generate your key.
- **API Call:** Alternatively, you can use the `/api/create_api_key` endpoint to create a new API key. For detailed instructions, visit [DocsGPT API Documentation](https://docs.docsgpt.co.uk/Developing/API-docs#8-apicreate_api_key). A minimal sketch follows.
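For instance, creating a key with Python's `requests`, assuming a local backend; the payload fields mirror the `/api/create_api_key` documentation, and the response shape is an assumption based on the backend returning the inserted id and generated key:

```python
import requests

resp = requests.post(
    "http://localhost:7091/api/create_api_key",  # hypothetical local host
    json={
        "name": "Example Key Name",
        "source": "Example Source",
        "prompt_id": "creative",
        "chunks": "2",
    },
)
print(resp.json())  # assumed to include the new key's id and the key itself
```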
### Understanding Key Variables
Upon creating your API key, you will encounter several key variables. Each serves a specific purpose:
- **Name:** Assign a name to your API key for easy identification.
- **Source:** Indicates the source document(s) linked to your API key, which DocsGPT will use to generate responses.
- **ID:** A unique identifier for your API key. You can view this by making a call to `/api/get_api_keys`.
- **Key:** The API key itself, which will be used in your application to authenticate API requests.
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via the `/api/answer` or `/stream` endpoints, as in the sketch below. The source document is preset with the API key, so you can omit fields like `selectDocs` and `active_docs` during implementation.
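A minimal sketch of querying `/api/answer` with an API key; the field names follow the backend handlers shown earlier, and the host is an assumption:

```python
import requests

resp = requests.post(
    "http://localhost:7091/api/answer",  # hypothetical local host
    json={
        "question": "How do I install DocsGPT?",
        "history": [],              # prior {"prompt": ..., "response": ...} pairs
        "conversation_id": None,
        "api_key": "your-api-key",  # source, prompt_id and chunks are preset on the key
    },
)
result = resp.json()
print(result["answer"], result["sources"])
```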
Congratulations on taking the first step towards enhancing your applications with DocsGPT! With this guide, you're now equipped to navigate the process of obtaining and understanding DocsGPT API keys.

@ -4,7 +4,7 @@ export default function MyApp({ Component, pageProps }) {
return (
<>
<Component {...pageProps} />
<DocsGPTWidget selectDocs="local/docsgpt-sep.zip/" apiKey="82962c9a-aa77-4152-94e5-a4f84fd44c6a" />
<DocsGPTWidget apiKey="d61a020c-ac8f-4f23-bb98-458e4da3c240" />
</>
)
}

@ -44,7 +44,7 @@
"prettier-plugin-tailwindcss": "^0.2.2",
"tailwindcss": "^3.2.4",
"typescript": "^4.9.5",
"vite": "^5.0.12",
"vite": "^5.0.13",
"vite-plugin-svgr": "^4.2.0"
}
},
@ -7855,9 +7855,9 @@
}
},
"node_modules/vite": {
"version": "5.0.12",
"resolved": "https://registry.npmjs.org/vite/-/vite-5.0.12.tgz",
"integrity": "sha512-4hsnEkG3q0N4Tzf1+t6NdN9dg/L3BM+q8SWgbSPnJvrgH2kgdyzfVJwbR1ic69/4uMJJ/3dqDZZE5/WwqW8U1w==",
"version": "5.0.13",
"resolved": "https://registry.npmjs.org/vite/-/vite-5.0.13.tgz",
"integrity": "sha512-/9ovhv2M2dGTuA+dY93B9trfyWMDRQw2jdVBhHNP6wr0oF34wG2i/N55801iZIpgUpnHDm4F/FabGQLyc+eOgg==",
"dev": true,
"dependencies": {
"esbuild": "^0.19.3",

@ -55,7 +55,7 @@
"prettier-plugin-tailwindcss": "^0.2.2",
"tailwindcss": "^3.2.4",
"typescript": "^4.9.5",
"vite": "^5.0.12",
"vite": "^5.0.13",
"vite-plugin-svgr": "^4.2.0"
}
}

@ -6,7 +6,7 @@ import PageNotFound from './PageNotFound';
import { inject } from '@vercel/analytics';
import { useMediaQuery } from './hooks';
import { useState } from 'react';
import Setting from './Setting';
import Setting from './settings';
inject();

File diff suppressed because it is too large.

@ -1,10 +1,16 @@
import { useState } from 'react';
import React from 'react';
import Arrow2 from '../assets/dropdown-arrow.svg';
import Edit from '../assets/edit.svg';
import Trash from '../assets/trash.svg';
function Dropdown({
options,
selectedValue,
onSelect,
size = 'w-32',
rounded = 'xl',
showEdit,
onEdit,
showDelete,
onDelete,
placeholder,
@ -18,35 +24,32 @@ function Dropdown({
| ((value: string) => void)
| ((value: { name: string; id: string; type: string }) => void)
| ((value: { label: string; value: string }) => void);
size?: string;
rounded?: 'xl' | '3xl';
showEdit?: boolean;
onEdit?: (value: { name: string; id: string; type: string }) => void;
showDelete?: boolean;
onDelete?: (value: string) => void;
placeholder?: string;
className?: string;
width?: string;
}) {
const [isOpen, setIsOpen] = useState(false);
const [isOpen, setIsOpen] = React.useState(false);
return (
<div
className={
className={[
typeof selectedValue === 'string'
? 'relative mt-2 w-32'
: 'relative w-full align-middle'
}
? 'relative mt-2'
: 'relative align-middle',
size,
].join(' ')}
>
<button
onClick={() => setIsOpen(!isOpen)}
className={`flex w-full cursor-pointer items-center justify-between border-2 border-silver bg-white p-3 dark:border-chinese-silver dark:bg-transparent ${
isOpen
? typeof selectedValue === 'string'
? 'rounded-t-xl'
: 'rounded-t-3xl'
: typeof selectedValue === 'string'
? 'rounded-xl'
: 'rounded-3xl'
className={`flex w-full cursor-pointer items-center justify-between border-2 border-silver bg-white px-5 py-3 dark:border-chinese-silver dark:bg-transparent ${
isOpen ? `rounded-t-${rounded}` : `rounded-${rounded}`
}`}
>
{typeof selectedValue === 'string' ? (
<span className="flex-1 overflow-hidden text-ellipsis dark:text-bright-gray">
<span className="overflow-hidden text-ellipsis dark:text-bright-gray">
{selectedValue}
</span>
) : (
@ -71,7 +74,7 @@ function Dropdown({
/>
</button>
{isOpen && (
<div className="absolute left-0 right-0 z-50 -mt-1 max-h-40 overflow-y-auto rounded-b-xl border-2 bg-white shadow-lg dark:border-chinese-silver dark:bg-dark-charcoal">
<div className="absolute left-0 right-0 z-20 -mt-1 max-h-40 overflow-y-auto rounded-b-xl border-2 border-silver bg-white shadow-lg dark:border-chinese-silver dark:bg-dark-charcoal">
{options.map((option: any, index) => (
<div
key={index}
@ -82,7 +85,7 @@ function Dropdown({
onSelect(option);
setIsOpen(false);
}}
className="ml-2 flex-1 overflow-hidden overflow-ellipsis whitespace-nowrap py-3 dark:text-light-gray"
className="ml-5 flex-1 overflow-hidden overflow-ellipsis whitespace-nowrap py-3 dark:text-light-gray"
>
{typeof option === 'string'
? option
@ -90,9 +93,35 @@ function Dropdown({
? option.name
: option.label}
</span>
{showEdit && onEdit && (
<img
src={Edit}
alt="Edit"
className="mr-4 h-4 w-4 cursor-pointer hover:opacity-50"
onClick={() => {
onEdit({
id: option.id,
name: option.name,
type: option.type,
});
setIsOpen(false);
}}
/>
)}
{showDelete && onDelete && (
<button onClick={() => onDelete(option)} className="p-2">
Delete
<button
onClick={() => onDelete(option.id)}
disabled={option.type === 'public'}
>
<img
src={Trash}
alt="Delete"
className={`mr-2 h-4 w-4 cursor-pointer hover:opacity-50 ${
option.type === 'public'
? 'cursor-not-allowed opacity-50'
: ''
}`}
/>
</button>
)}
</div>

@ -24,6 +24,12 @@ function SourceDropdown({
const embeddingsName =
import.meta.env.VITE_EMBEDDINGS_NAME ||
'huggingface_sentence-transformers/all-mpnet-base-v2';
const handleEmptyDocumentSelect = () => {
dispatch(setSelectedDocs(null));
setIsDocsListOpen(false);
};
return (
<div className="relative w-5/6 rounded-3xl">
<button
@ -35,7 +41,7 @@ function SourceDropdown({
<span className="ml-1 mr-2 flex-1 overflow-hidden text-ellipsis text-left dark:text-bright-gray">
<div className="flex flex-row gap-2">
<p className="max-w-3/4 truncate whitespace-nowrap">
{selectedDocs?.name}
{selectedDocs?.name || ''}
</p>
<p className="flex flex-col items-center justify-center">
{selectedDocs?.version}
@ -93,6 +99,14 @@ function SourceDropdown({
<p className="ml-5 py-3">No default documentation.</p>
</div>
)}
<div
className="flex cursor-pointer items-center justify-between hover:bg-gray-100 dark:text-bright-gray dark:hover:bg-purple-taupe"
onClick={handleEmptyDocumentSelect}
>
<span className="ml-4 flex-1 overflow-hidden overflow-ellipsis whitespace-nowrap py-3">
Empty
</span>
</div>
</div>
)}
</div>

@ -3,10 +3,37 @@ import { Doc } from '../preferences/preferenceApi';
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
function getDocPath(selectedDocs: Doc | null): string {
let docPath = 'default';
if (selectedDocs) {
let namePath = selectedDocs.name;
if (selectedDocs.language === namePath) {
namePath = '.project';
}
if (selectedDocs.location === 'local') {
docPath = 'local' + '/' + selectedDocs.name + '/';
} else if (selectedDocs.location === 'remote') {
docPath =
selectedDocs.language +
'/' +
namePath +
'/' +
selectedDocs.version +
'/' +
selectedDocs.model +
'/';
} else if (selectedDocs.location === 'custom') {
docPath = selectedDocs.docLink;
}
}
return docPath;
}
export function fetchAnswerApi(
question: string,
signal: AbortSignal,
selectedDocs: Doc,
selectedDocs: Doc | null,
history: Array<any> = [],
conversationId: string | null,
promptId: string | null,
@ -28,25 +55,7 @@ export function fetchAnswerApi(
title: any;
}
> {
let namePath = selectedDocs.name;
if (selectedDocs.language === namePath) {
namePath = '.project';
}
let docPath = 'default';
if (selectedDocs.location === 'local') {
docPath = 'local' + '/' + selectedDocs.name + '/';
} else if (selectedDocs.location === 'remote') {
docPath =
selectedDocs.language +
'/' +
namePath +
'/' +
selectedDocs.version +
'/' +
selectedDocs.model +
'/';
}
const docPath = getDocPath(selectedDocs);
// keep only the prompt and response keys in each history item
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
@ -89,32 +98,14 @@ export function fetchAnswerApi(
export function fetchAnswerSteaming(
question: string,
signal: AbortSignal,
selectedDocs: Doc,
selectedDocs: Doc | null,
history: Array<any> = [],
conversationId: string | null,
promptId: string | null,
chunks: string,
onEvent: (event: MessageEvent) => void,
): Promise<Answer> {
let namePath = selectedDocs.name;
if (selectedDocs.language === namePath) {
namePath = '.project';
}
let docPath = 'default';
if (selectedDocs.location === 'local') {
docPath = 'local' + '/' + selectedDocs.name + '/';
} else if (selectedDocs.location === 'remote') {
docPath =
selectedDocs.language +
'/' +
namePath +
'/' +
selectedDocs.version +
'/' +
selectedDocs.model +
'/';
}
const docPath = getDocPath(selectedDocs);
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
@ -186,35 +177,12 @@ export function fetchAnswerSteaming(
}
export function searchEndpoint(
question: string,
selectedDocs: Doc,
selectedDocs: Doc | null,
conversation_id: string | null,
history: Array<any> = [],
chunks: string,
) {
/*
"active_docs": "default",
"question": "Summarise",
"conversation_id": null,
"history": "[]" */
let namePath = selectedDocs.name;
if (selectedDocs.language === namePath) {
namePath = '.project';
}
let docPath = 'default';
if (selectedDocs.location === 'local') {
docPath = 'local' + '/' + selectedDocs.name + '/';
} else if (selectedDocs.location === 'remote') {
docPath =
selectedDocs.language +
'/' +
namePath +
'/' +
selectedDocs.version +
'/' +
selectedDocs.model +
'/';
}
const docPath = getDocPath(selectedDocs);
const body = {
question: question,

@ -3,3 +3,42 @@ export type ActiveState = 'ACTIVE' | 'INACTIVE';
export type User = {
avatar: string;
};
export type Doc = {
location: string;
name: string;
language: string;
version: string;
description: string;
fullName: string;
date: string;
docLink: string;
model: string;
};
export type PromptProps = {
prompts: { name: string; id: string; type: string }[];
selectedPrompt: { name: string; id: string; type: string };
onSelectPrompt: (name: string, id: string, type: string) => void;
setPrompts: (prompts: { name: string; id: string; type: string }[]) => void;
apiHost: string;
};
export type DocumentsProps = {
documents: Doc[] | null;
handleDeleteDocument: (index: number, document: Doc) => void;
};
export type CreateAPIKeyModalProps = {
close: () => void;
createAPIKey: (payload: {
name: string;
source: string;
prompt_id: string;
chunks: string;
}) => void;
};
export type SaveAPIKeyModalProps = {
apiKey: string;
close: () => void;
};

@ -0,0 +1,222 @@
import { ActiveState } from '../models/misc';
import Exit from '../assets/exit.svg';
function AddPrompt({
setModalState,
handleAddPrompt,
newPromptName,
setNewPromptName,
newPromptContent,
setNewPromptContent,
}: {
setModalState: (state: ActiveState) => void;
handleAddPrompt?: () => void;
newPromptName: string;
setNewPromptName: (name: string) => void;
newPromptContent: string;
setNewPromptContent: (content: string) => void;
}) {
return (
<div className="relative">
<button
className="absolute top-3 right-4 m-2 w-3"
onClick={() => {
setModalState('INACTIVE');
}}
>
<img className="filter dark:invert" src={Exit} />
</button>
<div className="p-8">
<p className="mb-1 text-xl text-jet dark:text-bright-gray">
Add Prompt
</p>
<p className="mb-7 text-xs text-[#747474] dark:text-[#7F7F82]">
Add your custom prompt and save it to DocsGPT
</p>
<div>
<input
placeholder="Prompt Name"
type="text"
className="h-10 w-full rounded-lg border-2 border-silver px-3 outline-none dark:bg-transparent dark:text-silver"
value={newPromptName}
onChange={(e) => setNewPromptName(e.target.value)}
></input>
<div className="relative bottom-12 left-3 mt-[-3.00px]">
<span className="bg-white px-1 text-xs text-silver dark:bg-outer-space dark:text-silver">
Prompt Name
</span>
</div>
<div className="relative top-[7px] left-3">
<span className="bg-white px-1 text-xs text-silver dark:bg-outer-space dark:text-silver">
Prompt Text
</span>
</div>
<textarea
className="h-56 w-full rounded-lg border-2 border-silver px-3 py-2 outline-none dark:bg-transparent dark:text-silver"
value={newPromptContent}
onChange={(e) => setNewPromptContent(e.target.value)}
></textarea>
</div>
<div className="mt-6 flex flex-row-reverse">
<button
onClick={handleAddPrompt}
className="rounded-3xl bg-purple-30 px-5 py-2 text-sm text-white transition-all hover:opacity-90"
>
Save
</button>
</div>
</div>
</div>
);
}
function EditPrompt({
setModalState,
handleEditPrompt,
editPromptName,
setEditPromptName,
editPromptContent,
setEditPromptContent,
currentPromptEdit,
}: {
setModalState: (state: ActiveState) => void;
handleEditPrompt?: (id: string, type: string) => void;
editPromptName: string;
setEditPromptName: (name: string) => void;
editPromptContent: string;
setEditPromptContent: (content: string) => void;
currentPromptEdit: { name: string; id: string; type: string };
}) {
return (
<div className="relative">
<button
className="absolute top-3 right-4 m-2 w-3"
onClick={() => {
setModalState('INACTIVE');
}}
>
<img className="filter dark:invert" src={Exit} />
</button>
<div className="p-8">
<p className="mb-1 text-xl text-jet dark:text-bright-gray">
Edit Prompt
</p>
<p className="mb-7 text-xs text-[#747474] dark:text-[#7F7F82]">
Edit your custom prompt and save it to DocsGPT
</p>
<div>
<input
placeholder="Prompt Name"
type="text"
className="h-10 w-full rounded-lg border-2 border-silver px-3 outline-none dark:bg-transparent dark:text-silver"
value={editPromptName}
onChange={(e) => setEditPromptName(e.target.value)}
></input>
<div className="relative bottom-12 left-3 mt-[-3.00px]">
<span className="bg-white px-1 text-xs text-silver dark:bg-outer-space dark:text-silver">
Prompt Name
</span>
</div>
<div className="relative top-[7px] left-3">
<span className="bg-white px-1 text-xs text-silver dark:bg-outer-space dark:text-silver">
Prompt Text
</span>
</div>
<textarea
className="h-56 w-full rounded-lg border-2 border-silver px-3 py-2 outline-none dark:bg-transparent dark:text-silver"
value={editPromptContent}
onChange={(e) => setEditPromptContent(e.target.value)}
></textarea>
</div>
<div className="mt-6 flex flex-row-reverse gap-4">
<button
className={`rounded-3xl bg-purple-30 px-5 py-2 text-sm text-white transition-all ${
currentPromptEdit.type === 'public'
? 'cursor-not-allowed opacity-50'
: 'hover:opacity-90'
}`}
onClick={() => {
handleEditPrompt &&
handleEditPrompt(currentPromptEdit.id, currentPromptEdit.type);
}}
disabled={currentPromptEdit.type === 'public'}
>
Save
</button>
</div>
</div>
</div>
);
}
export default function PromptsModal({
modalState,
setModalState,
type,
newPromptName,
setNewPromptName,
newPromptContent,
setNewPromptContent,
editPromptName,
setEditPromptName,
editPromptContent,
setEditPromptContent,
currentPromptEdit,
handleAddPrompt,
handleEditPrompt,
}: {
modalState: ActiveState;
setModalState: (state: ActiveState) => void;
type: 'ADD' | 'EDIT';
newPromptName: string;
setNewPromptName: (name: string) => void;
newPromptContent: string;
setNewPromptContent: (content: string) => void;
editPromptName: string;
setEditPromptName: (name: string) => void;
editPromptContent: string;
setEditPromptContent: (content: string) => void;
currentPromptEdit: { name: string; id: string; type: string };
handleAddPrompt?: () => void;
handleEditPrompt?: (id: string, type: string) => void;
}) {
let view;
if (type === 'ADD') {
view = (
<AddPrompt
setModalState={setModalState}
handleAddPrompt={handleAddPrompt}
newPromptName={newPromptName}
setNewPromptName={setNewPromptName}
newPromptContent={newPromptContent}
setNewPromptContent={setNewPromptContent}
/>
);
} else if (type === 'EDIT') {
view = (
<EditPrompt
setModalState={setModalState}
handleEditPrompt={handleEditPrompt}
editPromptName={editPromptName}
setEditPromptName={setEditPromptName}
editPromptContent={editPromptContent}
setEditPromptContent={setEditPromptContent}
currentPromptEdit={currentPromptEdit}
/>
);
} else {
view = <></>;
}
return (
<article
className={`${
modalState === 'ACTIVE' ? 'visible' : 'hidden'
} fixed top-0 left-0 z-30 h-screen w-screen bg-gray-alpha`}
>
<article className="mx-auto mt-24 flex w-[90vw] max-w-lg flex-col gap-4 rounded-2xl bg-white shadow-lg dark:bg-outer-space">
{view}
</article>
</article>
);
}

@ -0,0 +1,336 @@
import React from 'react';
import { useSelector } from 'react-redux';
import Dropdown from '../components/Dropdown';
import {
Doc,
CreateAPIKeyModalProps,
SaveAPIKeyModalProps,
} from '../models/misc';
import { selectSourceDocs } from '../preferences/preferenceSlice';
import Exit from '../assets/exit.svg';
import Trash from '../assets/trash.svg';
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
const embeddingsName =
import.meta.env.VITE_EMBEDDINGS_NAME ||
'huggingface_sentence-transformers/all-mpnet-base-v2';
const APIKeys: React.FC = () => {
const [isCreateModalOpen, setCreateModal] = React.useState(false);
const [isSaveKeyModalOpen, setSaveKeyModal] = React.useState(false);
const [newKey, setNewKey] = React.useState('');
const [apiKeys, setApiKeys] = React.useState<
{ name: string; key: string; source: string; id: string }[]
>([]);
const handleDeleteKey = (id: string) => {
fetch(`${apiHost}/api/delete_api_key`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ id }),
})
.then((response) => {
if (!response.ok) {
throw new Error('Failed to delete API Key');
}
return response.json();
})
.then((data) => {
data.status === 'ok' &&
setApiKeys((previous) => previous.filter((elem) => elem.id !== id));
})
.catch((error) => {
console.error(error);
});
};
React.useEffect(() => {
fetchAPIKeys();
}, []);
const fetchAPIKeys = async () => {
try {
const response = await fetch(`${apiHost}/api/get_api_keys`);
if (!response.ok) {
throw new Error('Failed to fetch API Keys');
}
      const keys = await response.json();
      setApiKeys(keys);
} catch (error) {
      console.error(error);
}
};
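  // POST the new key config; the response is assumed to carry the generated
  // key material ({ id, name, key, source }) so it can be shown once in the
  // save-key modal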
const createAPIKey = (payload: {
name: string;
source: string;
prompt_id: string;
chunks: string;
}) => {
fetch(`${apiHost}/api/create_api_key`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(payload),
})
.then((response) => {
if (!response.ok) {
throw new Error('Failed to create API Key');
}
return response.json();
})
      .then((data) => {
        setNewKey(data.key);
        setCreateModal(false);
        setSaveKeyModal(true);
        // re-sync the table from the server instead of appending to possibly stale state
        fetchAPIKeys();
      })
.catch((error) => {
console.error(error);
});
};
return (
<div className="mt-8">
<div className="flex w-full flex-col lg:w-max">
<div className="flex justify-end">
<button
onClick={() => setCreateModal(true)}
className="rounded-full bg-purple-30 px-4 py-3 text-sm text-white hover:opacity-90"
>
Create new
</button>
</div>
{isCreateModalOpen && (
<CreateAPIKeyModal
close={() => setCreateModal(false)}
createAPIKey={createAPIKey}
/>
)}
{isSaveKeyModalOpen && (
<SaveAPIKeyModal
apiKey={newKey}
close={() => setSaveKeyModal(false)}
/>
)}
<div className="mt-[27px] w-full">
<div className="w-full overflow-x-auto">
<table className="block w-max table-auto content-center justify-center rounded-xl border text-center dark:border-chinese-silver dark:text-bright-gray">
<thead>
<tr>
<th className="border-r p-4 md:w-[244px]">Name</th>
<th className="w-[244px] border-r px-4 py-2">
Source document
</th>
<th className="w-[244px] border-r px-4 py-2">API Key</th>
<th className="px-4 py-2"></th>
</tr>
</thead>
<tbody>
{apiKeys?.map((element, index) => (
              <tr key={element.id}>
<td className="border-r border-t p-4">{element.name}</td>
<td className="border-r border-t p-4">{element.source}</td>
<td className="border-r border-t p-4">{element.key}</td>
<td className="border-t p-4">
<img
src={Trash}
alt="Delete"
className="h-4 w-4 cursor-pointer hover:opacity-50"
id={`img-${index}`}
onClick={() => handleDeleteKey(element.id)}
/>
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
</div>
</div>
);
};
const CreateAPIKeyModal: React.FC<CreateAPIKeyModalProps> = ({
close,
createAPIKey,
}) => {
const [APIKeyName, setAPIKeyName] = React.useState<string>('');
const [sourcePath, setSourcePath] = React.useState<{
label: string;
value: string;
} | null>(null);
const chunkOptions = ['0', '2', '4', '6', '8', '10'];
const [chunk, setChunk] = React.useState<string>('2');
const [activePrompts, setActivePrompts] = React.useState<
{ name: string; id: string; type: string }[]
>([]);
const [prompt, setPrompt] = React.useState<{
name: string;
id: string;
type: string;
} | null>(null);
const docs = useSelector(selectSourceDocs);
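  // Fetch the available prompts once so the dropdown can offer them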
React.useEffect(() => {
const fetchPrompts = async () => {
try {
const response = await fetch(`${apiHost}/api/get_prompts`);
if (!response.ok) {
throw new Error('Failed to fetch prompts');
}
const promptsData = await response.json();
setActivePrompts(promptsData);
} catch (error) {
console.error(error);
}
};
fetchPrompts();
}, []);
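  // Map docs to dropdown options: local docs resolve to `local/<name>/`,
  // remote docs to `<language>/<namePath>/<version>/<model>/` — the path
  // layout the backend appears to expect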
const extractDocPaths = () =>
docs
? docs
.filter((doc) => doc.model === embeddingsName)
.map((doc: Doc) => {
let namePath = doc.name;
if (doc.language === namePath) {
namePath = '.project';
}
            let docPath = 'default';
            if (doc.location === 'local') {
              docPath = `local/${doc.name}/`;
            } else if (doc.location === 'remote') {
              docPath = `${doc.language}/${namePath}/${doc.version}/${doc.model}/`;
            }
return {
label: doc.name,
value: docPath,
};
})
: [];
return (
<div className="fixed top-0 left-0 z-30 flex h-screen w-screen items-center justify-center bg-gray-alpha bg-opacity-50">
<div className="relative w-11/12 rounded-2xl bg-white p-10 dark:bg-outer-space sm:w-[512px]">
<button className="absolute top-3 right-4 m-2 w-3" onClick={close}>
          <img className="filter dark:invert" src={Exit} alt="Close" />
</button>
<div className="mb-6">
<span className="text-xl text-jet dark:text-bright-gray">
Create New API Key
</span>
</div>
<div className="relative mt-5 mb-4">
<span className="absolute left-2 -top-2 bg-white px-2 text-xs text-gray-4000 dark:bg-outer-space dark:text-silver">
API Key Name
</span>
<input
type="text"
className="h-10 w-full rounded-md border-2 border-silver px-3 outline-none dark:bg-transparent dark:text-silver"
value={APIKeyName}
onChange={(e) => setAPIKeyName(e.target.value)}
/>
</div>
<div className="my-4">
<Dropdown
placeholder="Source document"
selectedValue={sourcePath}
onSelect={(selection: { label: string; value: string }) =>
setSourcePath(selection)
}
options={extractDocPaths()}
size="w-full"
rounded="xl"
/>
</div>
<div className="my-4">
<Dropdown
options={activePrompts}
selectedValue={prompt ? prompt.name : null}
placeholder="Select active prompt"
onSelect={(value: { name: string; id: string; type: string }) =>
setPrompt(value)
}
size="w-full"
/>
</div>
<div className="my-4">
<p className="mb-2 ml-2 font-bold text-jet dark:text-bright-gray">
Chunks processed per query
</p>
<Dropdown
options={chunkOptions}
selectedValue={chunk}
onSelect={(value: string) => setChunk(value)}
size="w-full"
/>
</div>
<button
disabled={!sourcePath || APIKeyName.length === 0 || !prompt}
onClick={() =>
sourcePath &&
prompt &&
createAPIKey({
name: APIKeyName,
source: sourcePath.value,
prompt_id: prompt.id,
chunks: chunk,
})
}
className="float-right mt-4 rounded-full bg-purple-30 px-4 py-3 text-white disabled:opacity-50"
>
Create
</button>
</div>
</div>
);
};
const SaveAPIKeyModal: React.FC<SaveAPIKeyModalProps> = ({ apiKey, close }) => {
const [isCopied, setIsCopied] = React.useState(false);
const handleCopyKey = () => {
navigator.clipboard.writeText(apiKey);
setIsCopied(true);
};
return (
<div className="fixed top-0 left-0 z-30 flex h-screen w-screen items-center justify-center bg-gray-alpha bg-opacity-50">
<div className="relative w-11/12 rounded-3xl bg-white px-6 py-8 dark:bg-outer-space dark:text-bright-gray sm:w-[512px]">
<button className="absolute top-3 right-4 m-2 w-3" onClick={close}>
          <img className="filter dark:invert" src={Exit} alt="Close" />
</button>
<h1 className="my-0 text-xl font-medium">Please save your Key</h1>
<h3 className="text-sm font-normal text-outer-space">
This is the only time your key will be shown.
</h3>
<div className="flex justify-between py-2">
<div>
<h2 className="text-base font-semibold">API Key</h2>
<span className="text-sm font-normal leading-7 ">{apiKey}</span>
</div>
<button
className="my-1 h-10 w-20 rounded-full border border-purple-30 p-2 text-sm text-purple-30 hover:bg-purple-30 hover:text-white dark:border-purple-500 dark:text-purple-500"
onClick={handleCopyKey}
>
{isCopied ? 'Copied' : 'Copy'}
</button>
</div>
<button
onClick={close}
className="rounded-full bg-philippine-yellow px-4 py-3 font-medium text-black hover:bg-[#E6B91A]"
>
I saved the Key
</button>
</div>
</div>
);
};
export default APIKeys;

@ -0,0 +1,60 @@
import { DocumentsProps } from '../models/misc';
import Trash from '../assets/trash.svg';
const Documents: React.FC<DocumentsProps> = ({
documents,
handleDeleteDocument,
}) => {
return (
<div className="mt-8">
<div className="flex flex-col">
<div className="mt-[27px] w-max overflow-x-auto rounded-xl border dark:border-chinese-silver">
<table className="block w-full table-auto content-center justify-center text-center dark:text-bright-gray">
<thead>
<tr>
<th className="border-r p-4 md:w-[244px]">Document Name</th>
<th className="w-[244px] border-r px-4 py-2">Vector Date</th>
<th className="w-[244px] border-r px-4 py-2">Type</th>
<th className="px-4 py-2"></th>
</tr>
</thead>
<tbody>
{documents &&
documents.map((document, index) => (
<tr key={index}>
<td className="border-r border-t px-4 py-2">
{document.name}
</td>
<td className="border-r border-t px-4 py-2">
{document.date}
</td>
<td className="border-r border-t px-4 py-2">
{document.location === 'remote'
? 'Pre-loaded'
: 'Private'}
</td>
<td className="border-t px-4 py-2">
{document.location !== 'remote' && (
<img
src={Trash}
alt="Delete"
className="h-4 w-4 cursor-pointer hover:opacity-50"
id={`img-${index}`}
onClick={(event) => {
event.stopPropagation();
handleDeleteDocument(index, document);
}}
/>
)}
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
</div>
);
};
export default Documents;

@ -0,0 +1,100 @@
import React from 'react';
import { useSelector, useDispatch } from 'react-redux';
import Prompts from './Prompts';
import { useDarkTheme } from '../hooks';
import Dropdown from '../components/Dropdown';
import {
selectPrompt,
setPrompt,
setChunks,
selectChunks,
} from '../preferences/preferenceSlice';
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
const General: React.FC = () => {
const themes = ['Light', 'Dark'];
const languages = ['English'];
const chunks = ['0', '2', '4', '6', '8', '10'];
const [prompts, setPrompts] = React.useState<
{ name: string; id: string; type: string }[]
>([]);
const selectedChunks = useSelector(selectChunks);
const [isDarkTheme, toggleTheme] = useDarkTheme();
const [selectedTheme, setSelectedTheme] = React.useState(
isDarkTheme ? 'Dark' : 'Light',
);
const dispatch = useDispatch();
const [selectedLanguage, setSelectedLanguage] = React.useState(languages[0]);
const selectedPrompt = useSelector(selectPrompt);
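  // Fetch available prompts once on mount to populate the prompt selector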
React.useEffect(() => {
const fetchPrompts = async () => {
try {
const response = await fetch(`${apiHost}/api/get_prompts`);
if (!response.ok) {
throw new Error('Failed to fetch prompts');
}
const promptsData = await response.json();
setPrompts(promptsData);
} catch (error) {
console.error(error);
}
};
fetchPrompts();
}, []);
return (
<div className="mt-[59px]">
<div className="mb-4">
<p className="font-bold text-jet dark:text-bright-gray">Select Theme</p>
<Dropdown
options={themes}
selectedValue={selectedTheme}
onSelect={(option: string) => {
setSelectedTheme(option);
            if (option !== selectedTheme) toggleTheme();
}}
size="w-56"
rounded="3xl"
/>
</div>
<div className="mb-4">
<p className="font-bold text-jet dark:text-bright-gray">
Select Language
</p>
<Dropdown
options={languages}
selectedValue={selectedLanguage}
onSelect={setSelectedLanguage}
size="w-56"
rounded="3xl"
/>
</div>
<div className="mb-4">
<p className="font-bold text-jet dark:text-bright-gray">
Chunks processed per query
</p>
<Dropdown
options={chunks}
selectedValue={selectedChunks}
onSelect={(value: string) => dispatch(setChunks(value))}
size="w-56"
rounded="3xl"
/>
</div>
<div>
<Prompts
prompts={prompts}
selectedPrompt={selectedPrompt}
onSelectPrompt={(name, id, type) =>
dispatch(setPrompt({ name: name, id: id, type: type }))
}
setPrompts={setPrompts}
apiHost={apiHost}
/>
</div>
</div>
);
};
export default General;

@ -0,0 +1,219 @@
import React from 'react';
import { PromptProps, ActiveState } from '../models/misc';
import Dropdown from '../components/Dropdown';
import PromptsModal from '../preferences/PromptsModal';
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
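// PromptProps (imported from ../models/misc, not shown here) is assumed to be
// roughly the following shape, based on how this component uses it:
//   prompts: { name: string; id: string; type: string }[];
//   selectedPrompt: { name: string; id: string; type: string };
//   onSelectPrompt: (name: string, id: string, type: string) => void;
//   setPrompts: (prompts: { name: string; id: string; type: string }[]) => void;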
const Prompts: React.FC<PromptProps> = ({
prompts,
selectedPrompt,
onSelectPrompt,
setPrompts,
}) => {
const handleSelectPrompt = ({
name,
id,
type,
}: {
name: string;
id: string;
type: string;
}) => {
setEditPromptName(name);
onSelectPrompt(name, id, type);
};
const [newPromptName, setNewPromptName] = React.useState('');
const [newPromptContent, setNewPromptContent] = React.useState('');
const [editPromptName, setEditPromptName] = React.useState('');
const [editPromptContent, setEditPromptContent] = React.useState('');
const [currentPromptEdit, setCurrentPromptEdit] = React.useState({
id: '',
name: '',
type: '',
});
const [modalType, setModalType] = React.useState<'ADD' | 'EDIT'>('ADD');
const [modalState, setModalState] = React.useState<ActiveState>('INACTIVE');
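  // Create the prompt server-side, append it to local state, and select it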
const handleAddPrompt = async () => {
try {
const response = await fetch(`${apiHost}/api/create_prompt`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
name: newPromptName,
content: newPromptContent,
}),
});
if (!response.ok) {
throw new Error('Failed to add prompt');
}
const newPrompt = await response.json();
if (setPrompts) {
setPrompts([
...prompts,
{ name: newPromptName, id: newPrompt.id, type: 'private' },
]);
}
setModalState('INACTIVE');
      onSelectPrompt(newPromptName, newPrompt.id, 'private');
      setNewPromptName('');
      setNewPromptContent('');
} catch (error) {
console.error(error);
}
};
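  // Optimistically remove the prompt locally, then delete it server-side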
  const handleDeletePrompt = (id: string) => {
    const remainingPrompts = prompts.filter((prompt) => prompt.id !== id);
    setPrompts(remainingPrompts);
    fetch(`${apiHost}/api/delete_prompt`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({ id: id }),
    })
      .then((response) => {
        if (!response.ok) {
          throw new Error('Failed to delete prompt');
        }
        // select the first remaining prompt so the deleted one is never re-selected
        if (remainingPrompts.length > 0) {
          onSelectPrompt(
            remainingPrompts[0].name,
            remainingPrompts[0].id,
            remainingPrompts[0].type,
          );
        }
})
.catch((error) => {
console.error(error);
});
};
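  // Prompt lists only carry name/id/type, so the content is fetched on demand
  // when the edit modal opens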
const fetchPromptContent = async (id: string) => {
try {
const response = await fetch(
`${apiHost}/api/get_single_prompt?id=${id}`,
{
method: 'GET',
headers: {
'Content-Type': 'application/json',
},
},
);
if (!response.ok) {
throw new Error('Failed to fetch prompt content');
}
const promptContent = await response.json();
setEditPromptContent(promptContent.content);
} catch (error) {
console.error(error);
}
};
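  // Persist edits, then upsert the prompt into local state and re-select it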
const handleSaveChanges = (id: string, type: string) => {
fetch(`${apiHost}/api/update_prompt`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
id: id,
name: editPromptName,
content: editPromptContent,
}),
})
.then((response) => {
if (!response.ok) {
throw new Error('Failed to update prompt');
}
if (setPrompts) {
const existingPromptIndex = prompts.findIndex(
(prompt) => prompt.id === id,
);
if (existingPromptIndex === -1) {
setPrompts([
...prompts,
{ name: editPromptName, id: id, type: type },
]);
} else {
const updatedPrompts = [...prompts];
updatedPrompts[existingPromptIndex] = {
name: editPromptName,
id: id,
type: type,
};
setPrompts(updatedPrompts);
}
}
setModalState('INACTIVE');
onSelectPrompt(editPromptName, id, type);
})
.catch((error) => {
console.error(error);
});
};
return (
<>
<div>
<div className="mb-4 flex flex-row items-center gap-8">
<div>
<p className="font-semibold dark:text-bright-gray">Active Prompt</p>
<Dropdown
options={prompts}
selectedValue={selectedPrompt.name}
onSelect={handleSelectPrompt}
size="w-56"
rounded="3xl"
showEdit
showDelete
onEdit={({
id,
name,
type,
}: {
id: string;
name: string;
type: string;
}) => {
setModalType('EDIT');
setEditPromptName(name);
fetchPromptContent(id);
setCurrentPromptEdit({ id: id, name: name, type: type });
setModalState('ACTIVE');
}}
onDelete={handleDeletePrompt}
/>
</div>
<button
className="mt-[24px] rounded-3xl border-2 border-solid border-purple-30 px-5 py-3 text-purple-30 hover:bg-purple-30 hover:text-white"
onClick={() => {
setModalType('ADD');
setModalState('ACTIVE');
}}
>
Add new
</button>
</div>
</div>
<PromptsModal
type={modalType}
modalState={modalState}
setModalState={setModalState}
newPromptName={newPromptName}
setNewPromptName={setNewPromptName}
newPromptContent={newPromptContent}
setNewPromptContent={setNewPromptContent}
editPromptName={editPromptName}
setEditPromptName={setEditPromptName}
editPromptContent={editPromptContent}
setEditPromptContent={setEditPromptContent}
currentPromptEdit={currentPromptEdit}
handleAddPrompt={handleAddPrompt}
handleEditPrompt={handleSaveChanges}
/>
</>
);
};
export default Prompts;

@ -0,0 +1,112 @@
import React from 'react';
import Dropdown from '../components/Dropdown';
const Widgets: React.FC<{
widgetScreenshot: File | null;
onWidgetScreenshotChange: (screenshot: File | null) => void;
}> = ({ widgetScreenshot, onWidgetScreenshotChange }) => {
const widgetSources = ['Source 1', 'Source 2', 'Source 3'];
const widgetMethods = ['Method 1', 'Method 2', 'Method 3'];
const widgetTypes = ['Type 1', 'Type 2', 'Type 3'];
const [selectedWidgetSource, setSelectedWidgetSource] = React.useState(
widgetSources[0],
);
const [selectedWidgetMethod, setSelectedWidgetMethod] = React.useState(
widgetMethods[0],
);
const [selectedWidgetType, setSelectedWidgetType] = React.useState(
widgetTypes[0],
);
  // Code snippet shown in the textarea below
  const [widgetCode, setWidgetCode] = React.useState<string>('');
const handleScreenshotChange = (
event: React.ChangeEvent<HTMLInputElement>,
) => {
const files = event.target.files;
if (files && files.length > 0) {
const selectedScreenshot = files[0];
onWidgetScreenshotChange(selectedScreenshot); // Update the screenshot in the parent component
}
};
  const handleCopyToClipboard = () => {
    // Use the async Clipboard API (as SaveAPIKeyModal does) rather than the
    // deprecated document.execCommand('copy') textarea workaround
    navigator.clipboard.writeText(widgetCode);
  };
return (
<div>
<div className="mt-[59px]">
<p className="font-bold text-jet">Widget Source</p>
<Dropdown
options={widgetSources}
selectedValue={selectedWidgetSource}
onSelect={setSelectedWidgetSource}
/>
</div>
<div className="mt-5">
<p className="font-bold text-jet">Widget Method</p>
<Dropdown
options={widgetMethods}
selectedValue={selectedWidgetMethod}
onSelect={setSelectedWidgetMethod}
/>
</div>
<div className="mt-5">
<p className="font-bold text-jet">Widget Type</p>
<Dropdown
options={widgetTypes}
selectedValue={selectedWidgetType}
onSelect={setSelectedWidgetType}
/>
</div>
<div className="mt-6">
<p className="font-bold text-jet">Widget Code Snippet</p>
<textarea
rows={4}
value={widgetCode}
onChange={(e) => setWidgetCode(e.target.value)}
className="mt-3 w-full rounded-lg border-2 p-2"
/>
</div>
<div className="mt-1">
<button
onClick={handleCopyToClipboard}
className="rounded-lg bg-blue-400 px-2 py-2 font-bold text-white transition-all hover:bg-blue-600"
>
Copy
</button>
</div>
<div className="mt-4">
<p className="text-lg font-semibold">Widget Screenshot</p>
<input type="file" accept="image/*" onChange={handleScreenshotChange} />
</div>
{widgetScreenshot && (
<div className="mt-4">
<img
src={URL.createObjectURL(widgetScreenshot)}
alt="Widget Screenshot"
className="max-w-full rounded-lg border border-gray-300"
/>
</div>
)}
</div>
);
};
export default Widgets;

@ -0,0 +1,127 @@
import React from 'react';
import { useSelector, useDispatch } from 'react-redux';
import General from './General';
import Documents from './Documents';
import APIKeys from './APIKeys';
import Widgets from './Widgets';
import {
selectSourceDocs,
setSourceDocs,
} from '../preferences/preferenceSlice';
import { Doc } from '../preferences/preferenceApi';
import ArrowLeft from '../assets/arrow-left.svg';
import ArrowRight from '../assets/arrow-right.svg';
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
const Settings: React.FC = () => {
const dispatch = useDispatch();
const tabs = ['General', 'Documents', 'API Keys'];
const [activeTab, setActiveTab] = React.useState('General');
const [widgetScreenshot, setWidgetScreenshot] = React.useState<File | null>(
null,
);
const documents = useSelector(selectSourceDocs);
const updateWidgetScreenshot = (screenshot: File | null) => {
setWidgetScreenshot(screenshot);
};
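  // Deletion goes through /api/delete_old with an index path of the form
  // indexes/local/<name>; only locally uploaded documents reach this handler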
const handleDeleteClick = (index: number, doc: Doc) => {
const docPath = 'indexes/' + 'local' + '/' + doc.name;
fetch(`${apiHost}/api/delete_old?path=${docPath}`, {
method: 'GET',
})
.then((response) => {
if (response.ok && documents) {
const updatedDocuments = [
...documents.slice(0, index),
...documents.slice(index + 1),
];
dispatch(setSourceDocs(updatedDocuments));
}
})
.catch((error) => console.error(error));
};
return (
<div className="wa p-4 pt-20 md:p-12">
<p className="text-2xl font-bold text-eerie-black dark:text-bright-gray">
Settings
</p>
<div className="mt-6 flex flex-row items-center space-x-4 overflow-x-auto md:space-x-8 ">
<div className="md:hidden">
<button
onClick={() => scrollTabs(-1)}
className="flex h-8 w-8 items-center justify-center rounded-full border-2 border-purple-30 transition-all hover:bg-gray-100"
>
<img src={ArrowLeft} alt="left-arrow" className="h-6 w-6" />
</button>
</div>
<div className="flex flex-nowrap space-x-4 overflow-x-auto md:space-x-8">
{tabs.map((tab, index) => (
<button
key={index}
onClick={() => setActiveTab(tab)}
className={`h-9 rounded-3xl px-4 font-bold ${
activeTab === tab
? 'bg-purple-3000 text-purple-30 dark:bg-dark-charcoal'
: 'text-gray-6000'
}`}
>
{tab}
</button>
))}
</div>
<div className="md:hidden">
<button
onClick={() => scrollTabs(1)}
className="flex h-8 w-8 items-center justify-center rounded-full border-2 border-purple-30 hover:bg-gray-100"
>
<img src={ArrowRight} alt="right-arrow" className="h-6 w-6" />
</button>
</div>
</div>
{renderActiveTab()}
{/* {activeTab === 'Widgets' && (
<Widgets
widgetScreenshot={widgetScreenshot}
onWidgetScreenshotChange={updateWidgetScreenshot}
/>
)} */}
</div>
);
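  // Shift the horizontal tab strip; assumes a single .flex-nowrap container
  // is rendered on the page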
function scrollTabs(direction: number) {
const container = document.querySelector('.flex-nowrap');
if (container) {
container.scrollLeft += direction * 100; // Adjust the scroll amount as needed
}
}
function renderActiveTab() {
switch (activeTab) {
case 'General':
return <General />;
case 'Documents':
return (
<Documents
documents={documents}
handleDeleteDocument={handleDeleteClick}
/>
);
case 'Widgets':
return (
          <Widgets
            widgetScreenshot={widgetScreenshot}
            onWidgetScreenshotChange={updateWidgetScreenshot}
          />
);
case 'API Keys':
return <APIKeys />;
default:
return null;
}
}
};
export default Settings;

@ -201,7 +201,7 @@ export default function Upload({
const { getRootProps, getInputProps, isDragActive } = useDropzone({
onDrop,
multiple: false,
multiple: true,
onDragEnter: doNothing,
onDragOver: doNothing,
onDragLeave: doNothing,
@ -307,6 +307,8 @@ export default function Upload({
onSelect={(value: { label: string; value: string }) =>
setUrlType(value)
}
size="w-full"
rounded="3xl"
/>
{urlType.label !== 'Reddit' ? (
<>
@ -417,7 +419,7 @@ export default function Upload({
disabled={
(files.length === 0 || docName.trim().length === 0) &&
activeTab === 'file'
} // Disable the button if no file is selected or docName is empty
}
>
Train
</button>
@ -448,4 +450,3 @@ export default function Upload({
</article>
);
}
// TODO: sanitize all inputs

@ -357,9 +357,9 @@
}
},
"node_modules/cookie": {
"version": "0.5.0",
"resolved": "https://registry.npmjs.org/cookie/-/cookie-0.5.0.tgz",
"integrity": "sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==",
"version": "0.6.0",
"resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz",
"integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==",
"engines": {
"node": ">= 0.6"
}
@ -458,16 +458,16 @@
}
},
"node_modules/express": {
"version": "4.18.2",
"resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz",
"integrity": "sha512-5/PsL6iGPdfQ/lKM1UuielYgv3BUoJfz1aUwU9vHZ+J7gyvwdQXFEBIEIaxeGf0GIcreATNyBExtalisDbuMqQ==",
"version": "4.19.2",
"resolved": "https://registry.npmjs.org/express/-/express-4.19.2.tgz",
"integrity": "sha512-5T6nhjsT+EOMzuck8JjBHARTHfMht0POzlA60WV2pMD3gyXw2LZnZ+ueGdNxG+0calOJcWKbpFcuzLZ91YWq9Q==",
"dependencies": {
"accepts": "~1.3.8",
"array-flatten": "1.1.1",
"body-parser": "1.20.1",
"body-parser": "1.20.2",
"content-disposition": "0.5.4",
"content-type": "~1.0.4",
"cookie": "0.5.0",
"cookie": "0.6.0",
"cookie-signature": "1.0.6",
"debug": "2.6.9",
"depd": "2.0.0",
@ -515,43 +515,6 @@
"isarray": "0.0.1"
}
},
"node_modules/express/node_modules/body-parser": {
"version": "1.20.1",
"resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.1.tgz",
"integrity": "sha512-jWi7abTbYwajOytWCQc37VulmWiRae5RyTpaCyDcS5/lMdtwSz5lOpDE67srw/HYe35f1z3fDQw+3txg7gNtWw==",
"dependencies": {
"bytes": "3.1.2",
"content-type": "~1.0.4",
"debug": "2.6.9",
"depd": "2.0.0",
"destroy": "1.2.0",
"http-errors": "2.0.0",
"iconv-lite": "0.4.24",
"on-finished": "2.4.1",
"qs": "6.11.0",
"raw-body": "2.5.1",
"type-is": "~1.6.18",
"unpipe": "1.0.0"
},
"engines": {
"node": ">= 0.8",
"npm": "1.2.8000 || >= 1.4.16"
}
},
"node_modules/express/node_modules/raw-body": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.1.tgz",
"integrity": "sha512-qqJBtEyVgS0ZmPGdCFPWJ3FreoqvG4MVQln/kCgF7Olq95IbOp0/BWyMwbdtn4VTvkM8Y7khCQ2Xgk/tcrCXig==",
"dependencies": {
"bytes": "3.1.2",
"http-errors": "2.0.0",
"iconv-lite": "0.4.24",
"unpipe": "1.0.0"
},
"engines": {
"node": ">= 0.8"
}
},
"node_modules/express/node_modules/safe-buffer": {
"version": "5.2.1",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",

@ -54,7 +54,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
def test_gen(self):
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
return_value=self.response) as mock_invoke_endpoint:
output = self.sagemaker.gen(None, None, self.messages)
output = self.sagemaker.gen(None, self.messages)
mock_invoke_endpoint.assert_called_once_with(
EndpointName=self.sagemaker.endpoint,
ContentType='application/json',
@ -66,7 +66,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
def test_gen_stream(self):
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
return_value=self.response) as mock_invoke_endpoint:
output = list(self.sagemaker.gen_stream(None, None, self.messages))
output = list(self.sagemaker.gen_stream(None, self.messages))
mock_invoke_endpoint.assert_called_once_with(
EndpointName=self.sagemaker.endpoint,
ContentType='application/json',
