From 391f6861733d78e589db221f967408f9fa95893a Mon Sep 17 00:00:00 2001
From: Alex
Date: Tue, 9 Apr 2024 14:02:33 +0100
Subject: [PATCH 1/7] Update application files, fix LLM model signatures, and
 create a new retriever class
---
application/api/answer/routes.py | 245 +++++++--------------
application/llm/anthropic.py | 4 +-
application/llm/docsgpt_provider.py | 4 +-
application/llm/huggingface.py | 4 +-
application/llm/llama_cpp.py | 4 +-
application/llm/openai.py | 4 +-
application/llm/premai.py | 4 +-
application/llm/sagemaker.py | 4 +-
application/retriever/__init__.py | 0
application/retriever/base.py | 10 +
application/retriever/classic_rag.py | 83 +++++++
application/retriever/retriever_creator.py | 15 ++
application/utils.py | 6 +
13 files changed, 202 insertions(+), 185 deletions(-)
create mode 100644 application/retriever/__init__.py
create mode 100644 application/retriever/base.py
create mode 100644 application/retriever/classic_rag.py
create mode 100644 application/retriever/retriever_creator.py
create mode 100644 application/utils.py
diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index b122eac..1b3c9b9 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -8,13 +8,14 @@ import traceback
from pymongo import MongoClient
from bson.objectid import ObjectId
-from transformers import GPT2TokenizerFast
+from application.utils import count_tokens
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
+from application.retriever.retriever_creator import RetrieverCreator
from application.error import bad_request
@@ -62,9 +63,6 @@ async def async_generate(chain, question, chat_history):
return result
-def count_tokens(string):
- tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
- return len(tokenizer(string)['input_ids'])
def run_async_chain(chain, question, chat_history):
@@ -104,61 +102,11 @@ def get_vectorstore(data):
def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
-
-def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id, chunks=2):
- llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
- if prompt_id == 'default':
- prompt = chat_combine_template
- elif prompt_id == 'creative':
- prompt = chat_combine_creative
- elif prompt_id == 'strict':
- prompt = chat_combine_strict
- else:
- prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
-
- if chunks == 0:
- docs = []
- else:
- docs = docsearch.search(question, k=chunks)
- if settings.LLM_NAME == "llama.cpp":
- docs = [docs[0]]
- # join all page_content together with a newline
- docs_together = "\n".join([doc.page_content for doc in docs])
- p_chat_combine = prompt.replace("{summaries}", docs_together)
- messages_combine = [{"role": "system", "content": p_chat_combine}]
- source_log_docs = []
- for doc in docs:
- if doc.metadata:
- source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
- else:
- source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
-
- if len(chat_history) > 1:
- tokens_current_history = 0
- # count tokens in history
- chat_history.reverse()
- for i in chat_history:
- if "prompt" in i and "response" in i:
- tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
- if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
- tokens_current_history += tokens_batch
- messages_combine.append({"role": "user", "content": i["prompt"]})
- messages_combine.append({"role": "system", "content": i["response"]})
- messages_combine.append({"role": "user", "content": question})
-
- response_full = ""
- completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
- messages=messages_combine)
- for line in completion:
- data = json.dumps({"answer": str(line)})
- response_full += str(line)
- yield f"data: {data}\n\n"
-
- # save conversation to database
+def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
- {"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}},
+ {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}},
)
else:
@@ -168,20 +116,50 @@ def complete_stream(question, docsearch, chat_history, prompt_id, conversation_i
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: " + question + "\n\n" +
"AI: " +
- response_full},
+ response},
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system"}]
- completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
+ completion = llm.gen(model=gpt_model,
messages=messages_summary, max_tokens=30)
conversation_id = conversations_collection.insert_one(
{"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
- "queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]}
+ "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]}
).inserted_id
+def get_prompt(prompt_id):
+ if prompt_id == 'default':
+ prompt = chat_combine_template
+ elif prompt_id == 'creative':
+ prompt = chat_combine_creative
+ elif prompt_id == 'strict':
+ prompt = chat_combine_strict
+ else:
+ prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
+ return prompt
+
+
+def complete_stream(question, retriever, conversation_id):
+
+ response_full = ""
+ source_log_docs = []
+ answer = retriever.gen()
+ for line in answer:
+ if "answer" in line:
+ response_full += str(line["answer"])
+ data = json.dumps(line)
+ yield f"data: {data}\n\n"
+ elif "source" in line:
+ source_log_docs.append(line["source"])
+
+ llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
+ conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
+
    # send a final {"type": "id"} event with the conversation id to mark the end of the stream
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
@@ -213,25 +191,26 @@ def stream():
chunks = int(data["chunks"])
else:
chunks = 2
+
+ prompt = get_prompt(prompt_id)
# check if active_docs is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
- vectorstore = get_vectorstore({"active_docs": data_key["source"]})
+ source = {"active_docs": data_key["source"]}
elif "active_docs" in data:
- vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
+ source = {"active_docs": data["active_docs"]}
else:
- vectorstore = ""
- docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
+ source = {}
+
+ retriever = RetrieverCreator.create_retriever("classic", question=question,
+ source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
+ )
return Response(
- complete_stream(question, docsearch,
- chat_history=history,
- prompt_id=prompt_id,
- conversation_id=conversation_id,
- chunks=chunks), mimetype="text/event-stream"
- )
+ complete_stream(question=question, retriever=retriever,
+ conversation_id=conversation_id), mimetype="text/event-stream")
@answer.route("/api/answer", methods=["POST"])
@@ -255,110 +234,35 @@ def api_answer():
chunks = int(data["chunks"])
else:
chunks = 2
-
- if prompt_id == 'default':
- prompt = chat_combine_template
- elif prompt_id == 'creative':
- prompt = chat_combine_creative
- elif prompt_id == 'strict':
- prompt = chat_combine_strict
- else:
- prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
+
+ prompt = get_prompt(prompt_id)
# use try and except to check for exception
try:
# check if the vectorstore is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
- vectorstore = get_vectorstore({"active_docs": data_key["source"]})
+ source = {"active_docs": data_key["source"]}
else:
- vectorstore = get_vectorstore(data)
- # loading the index and the store and the prompt template
- # Note if you have used other embeddings than OpenAI, you need to change the embeddings
- docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
-
+        source = data
+ retriever = RetrieverCreator.create_retriever("classic", question=question,
+ source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
+ )
+ source_log_docs = []
+ response_full = ""
+ for line in retriever.gen():
+ if "source" in line:
+ source_log_docs.append(line["source"])
+ elif "answer" in line:
+ response_full += line["answer"]
+
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
+
+ result = {"answer": response_full, "sources": source_log_docs}
+ result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
-
- if chunks == 0:
- docs = []
- else:
- docs = docsearch.search(question, k=chunks)
- # join all page_content together with a newline
- docs_together = "\n".join([doc.page_content for doc in docs])
- p_chat_combine = prompt.replace("{summaries}", docs_together)
- messages_combine = [{"role": "system", "content": p_chat_combine}]
- source_log_docs = []
- for doc in docs:
- if doc.metadata:
- source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
- else:
- source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
- # join all page_content together with a newline
-
-
- if len(history) > 1:
- tokens_current_history = 0
- # count tokens in history
- history.reverse()
- for i in history:
- if "prompt" in i and "response" in i:
- tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
- if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
- tokens_current_history += tokens_batch
- messages_combine.append({"role": "user", "content": i["prompt"]})
- messages_combine.append({"role": "system", "content": i["response"]})
- messages_combine.append({"role": "user", "content": question})
-
-
- completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
- messages=messages_combine)
-
-
- result = {"answer": completion, "sources": source_log_docs}
- logger.debug(result)
-
- # generate conversationId
- if conversation_id is not None:
- conversations_collection.update_one(
- {"_id": ObjectId(conversation_id)},
- {"$push": {"queries": {"prompt": question,
- "response": result["answer"], "sources": result['sources']}}},
- )
-
- else:
- # create new conversation
- # generate summary
- messages_summary = [
- {"role": "assistant", "content": "Summarise following conversation in no more than 3 words, "
- "respond ONLY with the summary, use the same language as the system \n\n"
- "User: " + question + "\n\n" + "AI: " + result["answer"]},
- {"role": "user", "content": "Summarise following conversation in no more than 3 words, "
- "respond ONLY with the summary, use the same language as the system"}
- ]
-
- completion = llm.gen(
- model=gpt_model,
- engine=settings.AZURE_DEPLOYMENT_NAME,
- messages=messages_summary,
- max_tokens=30
- )
- conversation_id = conversations_collection.insert_one(
- {"user": "local",
- "date": datetime.datetime.utcnow(),
- "name": completion,
- "queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]}
- ).inserted_id
-
- result["conversation_id"] = str(conversation_id)
-
- # mock result
- # result = {
- # "answer": "The answer is 42",
- # "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"]
- # }
return result
except Exception as e:
# print whole traceback
@@ -375,20 +279,20 @@ def api_search():
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
- vectorstore = data_key["source"]
+ source = {"active_docs": data_key["source"]}
elif "active_docs" in data:
- vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
+ source = {"active_docs": data["active_docs"]}
else:
- vectorstore = ""
+ source = {}
if 'chunks' in data:
chunks = int(data["chunks"])
else:
chunks = 2
- docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
- if chunks == 0:
- docs = []
- else:
- docs = docsearch.search(question, k=chunks)
+
+ retriever = RetrieverCreator.create_retriever("classic", question=question,
+ source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
+ )
+ docs = retriever.search()
source_log_docs = []
for doc in docs:
@@ -396,6 +300,5 @@ def api_search():
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
- #yield f"data:{data}\n\n"
return source_log_docs
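Note: after this refactor the streaming route emits the retriever's dicts as
server-sent events, plus a final id event. A minimal client sketch, assuming
the blueprint serves the endpoint at /stream on the default API_URL (both
taken from the settings and routes above; the requests usage is illustrative,
not part of the patch):

    import json
    import requests

    payload = {"question": "What is DocsGPT?", "history": "[]",
               "conversation_id": None, "active_docs": "default"}
    with requests.post("http://localhost:7091/stream", json=payload, stream=True) as r:
        for raw in r.iter_lines():
            if not raw.startswith(b"data: "):
                continue
            event = json.loads(raw[len(b"data: "):])
            if "answer" in event:            # one streamed answer chunk
                print(event["answer"], end="")
            elif "source" in event:          # one retrieved document
                print("\nsource:", event["source"]["title"])
            elif event.get("type") == "id":  # end of stream: conversation id
                print("\nconversation:", event["id"])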
diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py
index a64d71e..6b0d646 100644
--- a/application/llm/anthropic.py
+++ b/application/llm/anthropic.py
@@ -10,7 +10,7 @@ class AnthropicLLM(BaseLLM):
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
- def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs):
+ def gen(self, model, messages, max_tokens=300, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
@@ -25,7 +25,7 @@ class AnthropicLLM(BaseLLM):
)
return completion.completion
- def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs):
+ def gen_stream(self, model, messages, max_tokens=300, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index e0c5dba..d540a91 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -8,7 +8,7 @@ class DocsGPTAPILLM(BaseLLM):
self.endpoint = "https://llm.docsgpt.co.uk"
- def gen(self, model, engine, messages, stream=False, **kwargs):
+ def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -24,7 +24,7 @@ class DocsGPTAPILLM(BaseLLM):
return response_clean
- def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+ def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py
index ef3b1fb..554bee2 100644
--- a/application/llm/huggingface.py
+++ b/application/llm/huggingface.py
@@ -29,7 +29,7 @@ class HuggingFaceLLM(BaseLLM):
)
hf = HuggingFacePipeline(pipeline=pipe)
- def gen(self, model, engine, messages, stream=False, **kwargs):
+ def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -38,7 +38,7 @@ class HuggingFaceLLM(BaseLLM):
return result.content
- def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+ def gen_stream(self, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")
diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py
index f18d437..be34d4f 100644
--- a/application/llm/llama_cpp.py
+++ b/application/llm/llama_cpp.py
@@ -12,7 +12,7 @@ class LlamaCpp(BaseLLM):
llama = Llama(model_path=llm_name, n_ctx=2048)
- def gen(self, model, engine, messages, stream=False, **kwargs):
+ def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -24,7 +24,7 @@ class LlamaCpp(BaseLLM):
return result['choices'][0]['text'].split('### Answer \n')[-1]
- def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+ def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
diff --git a/application/llm/openai.py b/application/llm/openai.py
index a132399..4b0ed25 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -18,7 +18,7 @@ class OpenAILLM(BaseLLM):
return openai
- def gen(self, model, engine, messages, stream=False, **kwargs):
+ def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
response = self.client.chat.completions.create(model=model,
messages=messages,
stream=stream,
@@ -26,7 +26,7 @@ class OpenAILLM(BaseLLM):
return response.choices[0].message.content
- def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+ def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
response = self.client.chat.completions.create(model=model,
messages=messages,
stream=stream,
diff --git a/application/llm/premai.py b/application/llm/premai.py
index 4bc8a89..5faa5fe 100644
--- a/application/llm/premai.py
+++ b/application/llm/premai.py
@@ -12,7 +12,7 @@ class PremAILLM(BaseLLM):
self.api_key = api_key
self.project_id = settings.PREMAI_PROJECT_ID
- def gen(self, model, engine, messages, stream=False, **kwargs):
+ def gen(self, model, messages, stream=False, **kwargs):
response = self.client.chat.completions.create(model=model,
project_id=self.project_id,
messages=messages,
@@ -21,7 +21,7 @@ class PremAILLM(BaseLLM):
return response.choices[0].message["content"]
- def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+ def gen_stream(self, model, messages, stream=True, **kwargs):
response = self.client.chat.completions.create(model=model,
project_id=self.project_id,
messages=messages,
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index 84ae09a..b81f638 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -74,7 +74,7 @@ class SagemakerAPILLM(BaseLLM):
self.runtime = runtime
- def gen(self, model, engine, messages, stream=False, **kwargs):
+ def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -103,7 +103,7 @@ class SagemakerAPILLM(BaseLLM):
print(result[0]['generated_text'], file=sys.stderr)
return result[0]['generated_text'][len(prompt):]
- def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+ def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
diff --git a/application/retriever/__init__.py b/application/retriever/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/application/retriever/base.py b/application/retriever/base.py
new file mode 100644
index 0000000..3bfaa5e
--- /dev/null
+++ b/application/retriever/base.py
@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+
+
+class BaseRetriever(ABC):
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def gen(self, *args, **kwargs):
+ pass
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
new file mode 100644
index 0000000..dc75794
--- /dev/null
+++ b/application/retriever/classic_rag.py
@@ -0,0 +1,83 @@
+import os
+import json
+from application.retriever.base import BaseRetriever
+from application.core.settings import settings
+from application.vectorstore.vector_creator import VectorCreator
+from application.llm.llm_creator import LLMCreator
+
+from application.utils import count_tokens
+
+
+
+class ClassicRAG(BaseRetriever):
+
+ def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
+ self.question = question
+ self.vectorstore = self._get_vectorstore(source=source)
+ self.chat_history = chat_history
+ self.prompt = prompt
+ self.chunks = chunks
+ self.gpt_model = gpt_model
+
+ def _get_vectorstore(self, source):
+ if "active_docs" in source:
+ if source["active_docs"].split("/")[0] == "default":
+ vectorstore = ""
+ elif source["active_docs"].split("/")[0] == "local":
+ vectorstore = "indexes/" + source["active_docs"]
+ else:
+ vectorstore = "vectors/" + source["active_docs"]
+ if source["active_docs"] == "default":
+ vectorstore = ""
+ else:
+ vectorstore = ""
+ vectorstore = os.path.join("application", vectorstore)
+ return vectorstore
+
+
+ def _get_data(self):
+ if self.chunks == 0:
+ docs = []
+ else:
+ docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY)
+ docs = docsearch.search(self.question, k=self.chunks)
+ if settings.LLM_NAME == "llama.cpp":
+ docs = [docs[0]]
+ return docs
+
+ def gen(self):
+ docs = self._get_data()
+
+ # join all page_content together with a newline
+ docs_together = "\n".join([doc.page_content for doc in docs])
+ p_chat_combine = self.prompt.replace("{summaries}", docs_together)
+ messages_combine = [{"role": "system", "content": p_chat_combine}]
+ for doc in docs:
+ if doc.metadata:
+ yield {"source": {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}}
+ else:
+ yield {"source": {"title": doc.page_content, "text": doc.page_content}}
+
+ if len(self.chat_history) > 1:
+ tokens_current_history = 0
+ # count tokens in history
+ self.chat_history.reverse()
+ for i in self.chat_history:
+ if "prompt" in i and "response" in i:
+ tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
+ if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
+ tokens_current_history += tokens_batch
+ messages_combine.append({"role": "user", "content": i["prompt"]})
+ messages_combine.append({"role": "system", "content": i["response"]})
+ messages_combine.append({"role": "user", "content": self.question})
+
+ llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
+
+ completion = llm.gen_stream(model=self.gpt_model,
+ messages=messages_combine)
+ for line in completion:
+ yield {"answer": str(line)}
+
+ def search(self):
+ return self._get_data()
+
diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py
new file mode 100644
index 0000000..9255f4e
--- /dev/null
+++ b/application/retriever/retriever_creator.py
@@ -0,0 +1,15 @@
+from application.retriever.classic_rag import ClassicRAG
+
+
+
+class RetrieverCreator:
+    retrievers = {
+ 'classic': ClassicRAG,
+ }
+
+ @classmethod
+ def create_retriever(cls, type, *args, **kwargs):
+        retriever_class = cls.retrievers.get(type.lower())
+        if not retriever_class:
+            raise ValueError(f"No retriever class found for type {type}")
+        return retriever_class(*args, **kwargs)
\ No newline at end of file
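Note: assembled from the routes.py calls above, the factory and the generator
it returns are used like this (values illustrative; get_prompt is the helper
introduced in routes.py):

    from application.retriever.retriever_creator import RetrieverCreator

    retriever = RetrieverCreator.create_retriever(
        "classic",
        question="What is DocsGPT?",
        source={"active_docs": "local/my-docs/"},
        chat_history=[],
        prompt=get_prompt("default"),
        chunks=2,
        gpt_model="docsgpt",
    )
    for line in retriever.gen():
        if "source" in line:    # retrieved document: {"title": ..., "text": ...}
            print("source:", line["source"]["title"])
        elif "answer" in line:  # one streamed answer chunk
            print(line["answer"], end="")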
diff --git a/application/utils.py b/application/utils.py
new file mode 100644
index 0000000..ac98efc
--- /dev/null
+++ b/application/utils.py
@@ -0,0 +1,6 @@
+from transformers import GPT2TokenizerFast
+
+
+def count_tokens(string):
+ tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
+ return len(tokenizer(string)['input_ids'])
\ No newline at end of file
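Note: count_tokens instantiates GPT2TokenizerFast on every call, and the
history loop invokes it twice per entry. A possible follow-up, not part of
this patch, is to cache the tokenizer:

    from functools import lru_cache

    from transformers import GPT2TokenizerFast


    @lru_cache(maxsize=1)
    def _get_tokenizer():
        # loaded once per process instead of once per call
        return GPT2TokenizerFast.from_pretrained('gpt2')


    def count_tokens(string):
        return len(_get_tokenizer()(string)['input_ids'])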
From 1e26943c3e2cf1ccb4526a9686d79968e63ba6bc Mon Sep 17 00:00:00 2001
From: Alex
Date: Tue, 9 Apr 2024 15:45:24 +0100
Subject: [PATCH 2/7] Add a DuckDuckGo search retriever and return the
 conversation id from save_conversation
---
application/api/answer/routes.py | 12 +--
application/requirements.txt | 1 +
application/retriever/base.py | 4 +
application/retriever/classic_rag.py | 11 ++-
application/retriever/duckduck_search.py | 97 ++++++++++++++++++++++
application/retriever/retriever_creator.py | 2 +
6 files changed, 112 insertions(+), 15 deletions(-)
create mode 100644 application/retriever/duckduck_search.py
diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index 1b3c9b9..11e6c4a 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -103,7 +103,7 @@ def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def save_conversation(conversation_id, question, response, source_log_docs, llm):
- if conversation_id is not None:
+ if conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}},
@@ -129,6 +129,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
"name": completion,
"queries": [{"prompt": question, "response": response, "sources": source_log_docs}]}
).inserted_id
+ return conversation_id
def get_prompt(prompt_id):
if prompt_id == 'default':
@@ -293,12 +294,5 @@ def api_search():
source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
)
docs = retriever.search()
-
- source_log_docs = []
- for doc in docs:
- if doc.metadata:
- source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
- else:
- source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
- return source_log_docs
+ return docs
diff --git a/application/requirements.txt b/application/requirements.txt
index 0874a7c..c984534 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -3,6 +3,7 @@ boto3==1.34.6
celery==5.3.6
dataclasses_json==0.6.3
docx2txt==0.8
+duckduckgo-search=5.3.0
EbookLib==0.18
elasticsearch==8.12.0
escodegen==1.0.11
diff --git a/application/retriever/base.py b/application/retriever/base.py
index 3bfaa5e..4a37e81 100644
--- a/application/retriever/base.py
+++ b/application/retriever/base.py
@@ -8,3 +8,7 @@ class BaseRetriever(ABC):
@abstractmethod
def gen(self, *args, **kwargs):
pass
+
+ @abstractmethod
+ def search(self, *args, **kwargs):
+ pass
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index dc75794..a5bf8e3 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -40,23 +40,22 @@ class ClassicRAG(BaseRetriever):
docs = []
else:
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY)
- docs = docsearch.search(self.question, k=self.chunks)
+ docs_temp = docsearch.search(self.question, k=self.chunks)
+ docs = [{"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, "text": i.page_content} for i in docs_temp]
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
+
return docs
def gen(self):
docs = self._get_data()
# join all page_content together with a newline
- docs_together = "\n".join([doc.page_content for doc in docs])
+ docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
- if doc.metadata:
- yield {"source": {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}}
- else:
- yield {"source": {"title": doc.page_content, "text": doc.page_content}}
+ yield {"source": doc}
if len(self.chat_history) > 1:
tokens_current_history = 0
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
new file mode 100644
index 0000000..778313c
--- /dev/null
+++ b/application/retriever/duckduck_search.py
@@ -0,0 +1,97 @@
+import json
+import ast
+from application.retriever.base import BaseRetriever
+from application.core.settings import settings
+from application.llm.llm_creator import LLMCreator
+from application.utils import count_tokens
+from langchain_community.tools import DuckDuckGoSearchResults
+from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
+
+
+
+class DuckDuckSearch(BaseRetriever):
+
+ def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
+ self.question = question
+ self.source = source
+ self.chat_history = chat_history
+ self.prompt = prompt
+ self.chunks = chunks
+ self.gpt_model = gpt_model
+
+ def _parse_lang_string(self, input_string):
+ result = []
+ current_item = ""
+ inside_brackets = False
+ for char in input_string:
+ if char == "[":
+ inside_brackets = True
+ elif char == "]":
+ inside_brackets = False
+ result.append(current_item)
+ current_item = ""
+ elif inside_brackets:
+ current_item += char
+
+ # Check if there is an unmatched opening bracket at the end
+ if inside_brackets:
+ result.append(current_item)
+
+ return result
+
+ def _get_data(self):
+ if self.chunks == 0:
+ docs = []
+ else:
+ wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
+ search = DuckDuckGoSearchResults(api_wrapper=wrapper)
+ results = search.run(self.question)
+ results = self._parse_lang_string(results)
+
+ docs = []
+ for i in results:
+ try:
+ text = i.split("title:")[0]
+ title = i.split("title:")[1].split("link:")[0]
+ link = i.split("link:")[1]
+ docs.append({"text": text, "title": title, "link": link})
+ except IndexError:
+ pass
+ if settings.LLM_NAME == "llama.cpp":
+ docs = [docs[0]]
+
+ return docs
+
+ def gen(self):
+ docs = self._get_data()
+
+ # join all page_content together with a newline
+ docs_together = "\n".join([doc["text"] for doc in docs])
+ p_chat_combine = self.prompt.replace("{summaries}", docs_together)
+ messages_combine = [{"role": "system", "content": p_chat_combine}]
+ for doc in docs:
+ yield {"source": doc}
+
+ if len(self.chat_history) > 1:
+ tokens_current_history = 0
+ # count tokens in history
+ self.chat_history.reverse()
+ for i in self.chat_history:
+ if "prompt" in i and "response" in i:
+ tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
+ if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
+ tokens_current_history += tokens_batch
+ messages_combine.append({"role": "user", "content": i["prompt"]})
+ messages_combine.append({"role": "system", "content": i["response"]})
+ messages_combine.append({"role": "user", "content": self.question})
+
+ llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
+
+ completion = llm.gen_stream(model=self.gpt_model,
+ messages=messages_combine)
+ for line in completion:
+ yield {"answer": str(line)}
+
+ def search(self):
+ return self._get_data()
+
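Note: the bracket parser assumes DuckDuckGoSearchResults.run() returns results
as "[snippet: ..., title: ..., link: ...]" segments; the sample string below
is an assumption about that format, not output captured from the tool:

    retriever = DuckDuckSearch(question="what is docsgpt", source={},
                               chat_history=[], prompt="{summaries}")
    raw = "[snippet: DocsGPT is an open-source assistant., title: DocsGPT, link: https://example.com]"
    print(retriever._parse_lang_string(raw))
    # ['snippet: DocsGPT is an open-source assistant., title: DocsGPT, link: https://example.com']
    # _get_data() then splits each segment on "title:" and "link:" to build
    # {"text": ..., "title": ..., "link": ...}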
diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py
index 9255f4e..892d63a 100644
--- a/application/retriever/retriever_creator.py
+++ b/application/retriever/retriever_creator.py
@@ -1,10 +1,12 @@
from application.retriever.classic_rag import ClassicRAG
+from application.retriever.duckduck_search import DuckDuckSearch
class RetrieverCreator:
    retrievers = {
'classic': ClassicRAG,
+ 'duckduck': DuckDuckSearch
}
@classmethod
From 19494685ba72a338ce71b3a4fa068f669bd12574 Mon Sep 17 00:00:00 2001
From: Alex
Date: Tue, 9 Apr 2024 16:38:42 +0100
Subject: [PATCH 3/7] Select the retriever based on the active source and
 expose enabled retrievers to the frontend
---
application/api/answer/routes.py | 25 ++++-
application/api/user/routes.py | 14 +++
application/core/settings.py | 1 +
application/requirements.txt | 2 +-
application/retriever/retriever_creator.py | 2 +-
frontend/src/conversation/conversationApi.ts | 97 ++++++--------------
6 files changed, 67 insertions(+), 74 deletions(-)
diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index 11e6c4a..97eb36c 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -205,7 +205,12 @@ def stream():
else:
source = {}
- retriever = RetrieverCreator.create_retriever("classic", question=question,
+ if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
+ retriever_name = "classic"
+ else:
+ retriever_name = source['active_docs']
+
+ retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
)
@@ -247,7 +252,12 @@ def api_answer():
else:
        source = data
- retriever = RetrieverCreator.create_retriever("classic", question=question,
+ if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
+ retriever_name = "classic"
+ else:
+ retriever_name = source['active_docs']
+
+ retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
)
source_log_docs = []
@@ -290,9 +300,14 @@ def api_search():
else:
chunks = 2
- retriever = RetrieverCreator.create_retriever("classic", question=question,
- source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
- )
+ if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
+ retriever_name = "classic"
+ else:
+ retriever_name = source['active_docs']
+
+ retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
+ source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
+ )
docs = retriever.search()
return docs
diff --git a/application/api/user/routes.py b/application/api/user/routes.py
index e80ec52..b159d6c 100644
--- a/application/api/user/routes.py
+++ b/application/api/user/routes.py
@@ -237,6 +237,20 @@ def combined_json():
for index in data_remote:
index["location"] = "remote"
data.append(index)
+ if 'duckduck_search' in settings.RETRIEVERS_ENABLED:
+ data.append(
+ {
+ "name": "DuckDuckGo Search",
+ "language": "en",
+ "version": "",
+ "description": "duckduck_search",
+ "fullName": "DuckDuckGo Search",
+ "date": "duckduck_search",
+ "docLink": "duckduck_search",
+ "model": settings.EMBEDDINGS_NAME,
+ "location": "custom",
+ }
+ )
return jsonify(data)
diff --git a/application/core/settings.py b/application/core/settings.py
index 7eac3cb..b069340 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -18,6 +18,7 @@ class Settings(BaseSettings):
TOKENS_MAX_HISTORY: int = 150
UPLOAD_FOLDER: str = "inputs"
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant"
+ RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
API_URL: str = "http://localhost:7091" # backend url for celery worker
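Note: because Settings extends pydantic's BaseSettings, the new list field can
be overridden from the environment with a JSON-encoded value; a sketch:

    import os
    os.environ["RETRIEVERS_ENABLED"] = '["classic_rag"]'  # JSON list, parsed by pydantic

    from application.core.settings import Settings
    assert Settings().RETRIEVERS_ENABLED == ["classic_rag"]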
diff --git a/application/requirements.txt b/application/requirements.txt
index c984534..4642525 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -3,7 +3,7 @@ boto3==1.34.6
celery==5.3.6
dataclasses_json==0.6.3
docx2txt==0.8
-duckduckgo-search=5.3.0
+duckduckgo-search==5.3.0
EbookLib==0.18
elasticsearch==8.12.0
escodegen==1.0.11
diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py
index 892d63a..5ec341a 100644
--- a/application/retriever/retriever_creator.py
+++ b/application/retriever/retriever_creator.py
@@ -6,7 +6,7 @@ from application.retriever.duckduck_search import DuckDuckSearch
class RetrieverCreator:
    retrievers = {
'classic': ClassicRAG,
- 'duckduck': DuckDuckSearch
+ 'duckduck_search': DuckDuckSearch
}
@classmethod
diff --git a/frontend/src/conversation/conversationApi.ts b/frontend/src/conversation/conversationApi.ts
index e3a8219..0e49572 100644
--- a/frontend/src/conversation/conversationApi.ts
+++ b/frontend/src/conversation/conversationApi.ts
@@ -3,6 +3,33 @@ import { Doc } from '../preferences/preferenceApi';
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
+function getDocPath(selectedDocs: Doc | null): string {
+ let docPath = 'default';
+
+ if (selectedDocs) {
+ let namePath = selectedDocs.name;
+ if (selectedDocs.language === namePath) {
+ namePath = '.project';
+ }
+ if (selectedDocs.location === 'local') {
+ docPath = 'local' + '/' + selectedDocs.name + '/';
+ } else if (selectedDocs.location === 'remote') {
+ docPath =
+ selectedDocs.language +
+ '/' +
+ namePath +
+ '/' +
+ selectedDocs.version +
+ '/' +
+ selectedDocs.model +
+ '/';
+ } else if (selectedDocs.location === 'custom') {
+ docPath = selectedDocs.docLink;
+ }
+ }
+
+ return docPath;
+}
export function fetchAnswerApi(
question: string,
signal: AbortSignal,
@@ -28,27 +55,7 @@ export function fetchAnswerApi(
title: any;
}
> {
- let docPath = 'default';
-
- if (selectedDocs) {
- let namePath = selectedDocs.name;
- if (selectedDocs.language === namePath) {
- namePath = '.project';
- }
- if (selectedDocs.location === 'local') {
- docPath = 'local' + '/' + selectedDocs.name + '/';
- } else if (selectedDocs.location === 'remote') {
- docPath =
- selectedDocs.language +
- '/' +
- namePath +
- '/' +
- selectedDocs.version +
- '/' +
- selectedDocs.model +
- '/';
- }
- }
+ const docPath = getDocPath(selectedDocs);
//in history array remove all keys except prompt and response
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
@@ -98,27 +105,7 @@ export function fetchAnswerSteaming(
chunks: string,
onEvent: (event: MessageEvent) => void,
): Promise<Answer> {
- let docPath = 'default';
-
- if (selectedDocs) {
- let namePath = selectedDocs.name;
- if (selectedDocs.language === namePath) {
- namePath = '.project';
- }
- if (selectedDocs.location === 'local') {
- docPath = 'local' + '/' + selectedDocs.name + '/';
- } else if (selectedDocs.location === 'remote') {
- docPath =
- selectedDocs.language +
- '/' +
- namePath +
- '/' +
- selectedDocs.version +
- '/' +
- selectedDocs.model +
- '/';
- }
- }
+ const docPath = getDocPath(selectedDocs);
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
@@ -195,31 +182,7 @@ export function searchEndpoint(
  history: Array<any> = [],
chunks: string,
) {
- /*
- "active_docs": "default",
- "question": "Summarise",
- "conversation_id": null,
- "history": "[]" */
- let docPath = 'default';
- if (selectedDocs) {
- let namePath = selectedDocs.name;
- if (selectedDocs.language === namePath) {
- namePath = '.project';
- }
- if (selectedDocs.location === 'local') {
- docPath = 'local' + '/' + selectedDocs.name + '/';
- } else if (selectedDocs.location === 'remote') {
- docPath =
- selectedDocs.language +
- '/' +
- namePath +
- '/' +
- selectedDocs.version +
- '/' +
- selectedDocs.model +
- '/';
- }
- }
+ const docPath = getDocPath(selectedDocs);
const body = {
question: question,
From 7a02df558849c43f9e257a7d5bdca5c74ed5b0fc Mon Sep 17 00:00:00 2001
From: Pavel
Date: Tue, 9 Apr 2024 19:56:07 +0400
Subject: [PATCH 4/7] Multiple uploads
---
application/api/user/routes.py | 56 +++++++++++++++++++++-------------
application/worker.py | 36 +++++++++++++++++++---
frontend/src/upload/Upload.tsx | 2 +-
3 files changed, 67 insertions(+), 27 deletions(-)
diff --git a/application/api/user/routes.py b/application/api/user/routes.py
index e80ec52..7e5462b 100644
--- a/application/api/user/routes.py
+++ b/application/api/user/routes.py
@@ -1,5 +1,6 @@
import os
import uuid
+import shutil
from flask import Blueprint, request, jsonify
from urllib.parse import urlparse
import requests
@@ -136,30 +137,43 @@ def upload_file():
return {"status": "no name"}
job_name = secure_filename(request.form["name"])
# check if the post request has the file part
- if "file" not in request.files:
- print("No file part")
- return {"status": "no file"}
- file = request.files["file"]
- if file.filename == "":
+ files = request.files.getlist("file")
+
+ if not files or all(file.filename == '' for file in files):
return {"status": "no file name"}
- if file:
- filename = secure_filename(file.filename)
- # save dir
- save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name)
- # create dir if not exists
- if not os.path.exists(save_dir):
- os.makedirs(save_dir)
-
- file.save(os.path.join(save_dir, filename))
- task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt", ".docx",
- ".csv", ".epub", ".html", ".mdx"],
- job_name, filename, user)
- # task id
- task_id = task.id
- return {"status": "ok", "task_id": task_id}
+ # Directory where files will be saved
+ save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name)
+ os.makedirs(save_dir, exist_ok=True)
+
+ if len(files) > 1:
+ # Multiple files; prepare them for zip
+ temp_dir = os.path.join(save_dir, "temp")
+ os.makedirs(temp_dir, exist_ok=True)
+
+ for file in files:
+ filename = secure_filename(file.filename)
+ file.save(os.path.join(temp_dir, filename))
+
+ # Use shutil.make_archive to zip the temp directory
+ zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format='zip', root_dir=temp_dir)
+ final_filename = os.path.basename(zip_path)
+
+ # Clean up the temporary directory after zipping
+ shutil.rmtree(temp_dir)
else:
- return {"status": "error"}
+ # Single file
+ file = files[0]
+ final_filename = secure_filename(file.filename)
+ file_path = os.path.join(save_dir, final_filename)
+ file.save(file_path)
+
+ # Call ingest with the single file or zipped file
+ task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt", ".docx",
+ ".csv", ".epub", ".html", ".mdx"],
+ job_name, final_filename, user)
+
+ return {"status": "ok", "task_id": task.id}
@user.route("/api/remote", methods=["POST"])
def upload_remote():
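Note: the reworked handler accepts several parts under the same "file" field;
a client-side sketch, assuming the handler is mounted at /api/upload and takes
a "user" form field like the surrounding routes (both assumptions):

    import requests

    files = [("file", open("intro.md", "rb")),
             ("file", open("guide.pdf", "rb"))]  # both parts share the "file" field
    resp = requests.post("http://localhost:7091/api/upload",
                         data={"name": "my-docs", "user": "local"},
                         files=files)
    print(resp.json())  # {"status": "ok", "task_id": "..."} once the zip is queued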
diff --git a/application/worker.py b/application/worker.py
index 3891fde..eb28242 100644
--- a/application/worker.py
+++ b/application/worker.py
@@ -36,6 +36,32 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
+def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
+ """
+ Recursively extract zip files with a limit on recursion depth.
+
+ Args:
+ zip_path (str): Path to the zip file to be extracted.
+ extract_to (str): Destination path for extracted files.
+ current_depth (int): Current depth of recursion.
+ max_depth (int): Maximum allowed depth of recursion to prevent infinite loops.
+ """
+ if current_depth > max_depth:
+ print(f"Reached maximum recursion depth of {max_depth}")
+ return
+
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+ zip_ref.extractall(extract_to)
+ os.remove(zip_path) # Remove the zip file after extracting
+
+ # Check for nested zip files and extract them
+ for root, dirs, files in os.walk(extract_to):
+ for file in files:
+ if file.endswith(".zip"):
+ # If a nested zip file is found, extract it recursively
+ file_path = os.path.join(root, file)
+ extract_zip_recursive(file_path, root, current_depth + 1, max_depth)
+
# Define the main function for ingesting and processing documents.
def ingest_worker(self, directory, formats, name_job, filename, user):
@@ -66,9 +92,11 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
token_check = True
min_tokens = 150
max_tokens = 1250
- full_path = directory + "/" + user + "/" + name_job
+ recursion_depth = 2
+ full_path = os.path.join(directory, user, name_job)
import sys
+
print(full_path, file=sys.stderr)
# check if API_URL env variable is set
file_data = {"name": name_job, "file": filename, "user": user}
@@ -81,14 +109,12 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
if not os.path.exists(full_path):
os.makedirs(full_path)
- with open(full_path + "/" + filename, "wb") as f:
+ with open(os.path.join(full_path, filename), "wb") as f:
f.write(file)
# check if file is .zip and extract it
if filename.endswith(".zip"):
- with zipfile.ZipFile(full_path + "/" + filename, "r") as zip_ref:
- zip_ref.extractall(full_path)
- os.remove(full_path + "/" + filename)
+ extract_zip_recursive(os.path.join(full_path, filename), full_path, 0, recursion_depth)
self.update_state(state="PROGRESS", meta={"current": 1})
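Note: a quick illustration of the new helper (the worker calls it with
recursion_depth = 2); run it in an empty directory, since it walks the whole
extraction target looking for nested archives:

    import os
    import zipfile

    from application.worker import extract_zip_recursive

    with zipfile.ZipFile("inner.zip", "w") as z:
        z.writestr("doc.txt", "hello")
    with zipfile.ZipFile("outer.zip", "w") as z:
        z.write("inner.zip")  # outer.zip contains inner.zip
    os.remove("inner.zip")

    extract_zip_recursive("outer.zip", ".", current_depth=0, max_depth=2)
    print(os.path.exists("doc.txt"))  # True: the nested archive was unpacked too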
diff --git a/frontend/src/upload/Upload.tsx b/frontend/src/upload/Upload.tsx
index 39c2a09..3ae2178 100644
--- a/frontend/src/upload/Upload.tsx
+++ b/frontend/src/upload/Upload.tsx
@@ -201,7 +201,7 @@ export default function Upload({
const { getRootProps, getInputProps, isDragActive } = useDropzone({
onDrop,
- multiple: false,
+ multiple: true,
onDragEnter: doNothing,
onDragOver: doNothing,
onDragLeave: doNothing,
From e03e185d30ded4215ed3045994902ab8c1936fec Mon Sep 17 00:00:00 2001
From: Alex
Date: Tue, 9 Apr 2024 17:11:09 +0100
Subject: [PATCH 5/7] Add Brave Search retriever and update application files
---
application/api/user/routes.py | 14 ++++
application/core/settings.py | 2 +
application/retriever/brave_search.py | 75 ++++++++++++++++++++++
application/retriever/duckduck_search.py | 3 -
application/retriever/retriever_creator.py | 4 +-
5 files changed, 94 insertions(+), 4 deletions(-)
create mode 100644 application/retriever/brave_search.py
diff --git a/application/api/user/routes.py b/application/api/user/routes.py
index b159d6c..3222832 100644
--- a/application/api/user/routes.py
+++ b/application/api/user/routes.py
@@ -251,6 +251,20 @@ def combined_json():
"location": "custom",
}
)
+ if 'brave_search' in settings.RETRIEVERS_ENABLED:
+ data.append(
+ {
+ "name": "Brave Search",
+ "language": "en",
+ "version": "",
+ "description": "brave_search",
+ "fullName": "Brave Search",
+ "date": "brave_search",
+ "docLink": "brave_search",
+ "model": settings.EMBEDDINGS_NAME,
+ "location": "custom",
+ }
+ )
return jsonify(data)
diff --git a/application/core/settings.py b/application/core/settings.py
index b069340..d8d0eb3 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -60,6 +60,8 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
+ BRAVE_SEARCH_API_KEY: Optional[str] = None
+
FLASK_DEBUG_MODE: bool = False
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
new file mode 100644
index 0000000..0cc7bd4
--- /dev/null
+++ b/application/retriever/brave_search.py
@@ -0,0 +1,75 @@
+import json
+from application.retriever.base import BaseRetriever
+from application.core.settings import settings
+from application.llm.llm_creator import LLMCreator
+from application.utils import count_tokens
+from langchain_community.tools import BraveSearch
+
+
+
+class BraveRetSearch(BaseRetriever):
+
+ def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
+ self.question = question
+ self.source = source
+ self.chat_history = chat_history
+ self.prompt = prompt
+ self.chunks = chunks
+ self.gpt_model = gpt_model
+
+ def _get_data(self):
+ if self.chunks == 0:
+ docs = []
+ else:
+ search = BraveSearch.from_api_key(api_key=settings.BRAVE_SEARCH_API_KEY,
+ search_kwargs={"count": int(self.chunks)})
+ results = search.run(self.question)
+ results = json.loads(results)
+
+ docs = []
+ for i in results:
+ try:
+ title = i['title']
+ link = i['link']
+ snippet = i['snippet']
+ docs.append({"text": snippet, "title": title, "link": link})
+            except KeyError:
+ pass
+ if settings.LLM_NAME == "llama.cpp":
+ docs = [docs[0]]
+
+ return docs
+
+ def gen(self):
+ docs = self._get_data()
+
+ # join all page_content together with a newline
+ docs_together = "\n".join([doc["text"] for doc in docs])
+ p_chat_combine = self.prompt.replace("{summaries}", docs_together)
+ messages_combine = [{"role": "system", "content": p_chat_combine}]
+ for doc in docs:
+ yield {"source": doc}
+
+ if len(self.chat_history) > 1:
+ tokens_current_history = 0
+ # count tokens in history
+ self.chat_history.reverse()
+ for i in self.chat_history:
+ if "prompt" in i and "response" in i:
+ tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
+ if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
+ tokens_current_history += tokens_batch
+ messages_combine.append({"role": "user", "content": i["prompt"]})
+ messages_combine.append({"role": "system", "content": i["response"]})
+ messages_combine.append({"role": "user", "content": self.question})
+
+ llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
+
+ completion = llm.gen_stream(model=self.gpt_model,
+ messages=messages_combine)
+ for line in completion:
+ yield {"answer": str(line)}
+
+ def search(self):
+ return self._get_data()
+
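Note: BraveSearch.run() returns a JSON string; _get_data expects a list of
objects carrying title, link and snippet keys (shape inferred from the parsing
loop above, so treat the sample as an assumption):

    import json

    results = json.loads('[{"title": "DocsGPT", "link": "https://example.com", '
                         '"snippet": "DocsGPT is an open-source assistant."}]')
    docs = [{"text": r["snippet"], "title": r["title"], "link": r["link"]}
            for r in results]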
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index 778313c..d662bb0 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -1,5 +1,3 @@
-import json
-import ast
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
@@ -33,7 +31,6 @@ class DuckDuckSearch(BaseRetriever):
elif inside_brackets:
current_item += char
- # Check if there is an unmatched opening bracket at the end
if inside_brackets:
result.append(current_item)
diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py
index 5ec341a..ad07140 100644
--- a/application/retriever/retriever_creator.py
+++ b/application/retriever/retriever_creator.py
@@ -1,12 +1,14 @@
from application.retriever.classic_rag import ClassicRAG
from application.retriever.duckduck_search import DuckDuckSearch
+from application.retriever.brave_search import BraveRetSearch
class RetrieverCreator:
    retrievers = {
'classic': ClassicRAG,
- 'duckduck_search': DuckDuckSearch
+ 'duckduck_search': DuckDuckSearch,
+ 'brave_search': BraveRetSearch
}
@classmethod
From 4b849d720142e081fbace0868abe1ecef9187e68 Mon Sep 17 00:00:00 2001
From: Alex
Date: Tue, 9 Apr 2024 17:20:26 +0100
Subject: [PATCH 6/7] Fix SagemakerAPILLM test
---
tests/llm/test_sagemaker.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py
index f8d02d8..0602f59 100644
--- a/tests/llm/test_sagemaker.py
+++ b/tests/llm/test_sagemaker.py
@@ -54,7 +54,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
def test_gen(self):
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
return_value=self.response) as mock_invoke_endpoint:
- output = self.sagemaker.gen(None, None, self.messages)
+ output = self.sagemaker.gen(None, self.messages)
mock_invoke_endpoint.assert_called_once_with(
EndpointName=self.sagemaker.endpoint,
ContentType='application/json',
@@ -66,7 +66,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
def test_gen_stream(self):
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
return_value=self.response) as mock_invoke_endpoint:
- output = list(self.sagemaker.gen_stream(None, None, self.messages))
+ output = list(self.sagemaker.gen_stream(None, self.messages))
mock_invoke_endpoint.assert_called_once_with(
EndpointName=self.sagemaker.endpoint,
ContentType='application/json',
From 8d7a134cb40502b0bd8474a1ed603da2ced8ac08 Mon Sep 17 00:00:00 2001
From: Alex
Date: Tue, 9 Apr 2024 17:25:08 +0100
Subject: [PATCH 7/7] lint: ruff
---
application/api/answer/routes.py | 2 --
application/api/user/routes.py | 9 +++++----
application/core/settings.py | 2 +-
application/parser/token_func.py | 5 ++++-
application/retriever/classic_rag.py | 15 ++++++++++++---
5 files changed, 22 insertions(+), 11 deletions(-)
diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index 97eb36c..fa0ac4f 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -8,12 +8,10 @@ import traceback
from pymongo import MongoClient
from bson.objectid import ObjectId
-from application.utils import count_tokens
from application.core.settings import settings
-from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.error import bad_request
diff --git a/application/api/user/routes.py b/application/api/user/routes.py
index 3222832..cacfbd7 100644
--- a/application/api/user/routes.py
+++ b/application/api/user/routes.py
@@ -283,10 +283,12 @@ def check_docs():
else:
file_url = urlparse(base_path + vectorstore + "index.faiss")
- if file_url.scheme in ['https'] and file_url.netloc == 'raw.githubusercontent.com' and file_url.path.startswith('/arc53/DocsHUB/main/'):
-
+ if (
+ file_url.scheme in ['https'] and
+ file_url.netloc == 'raw.githubusercontent.com' and
+ file_url.path.startswith('/arc53/DocsHUB/main/')
+ ):
r = requests.get(file_url.geturl())
-
if r.status_code != 200:
return {"status": "null"}
else:
@@ -295,7 +297,6 @@ def check_docs():
with open(vectorstore + "index.faiss", "wb") as f:
f.write(r.content)
- # download the store
r = requests.get(base_path + vectorstore + "index.pkl")
with open(vectorstore + "index.pkl", "wb") as f:
f.write(r.content)
diff --git a/application/core/settings.py b/application/core/settings.py
index d8d0eb3..26c27ed 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -9,7 +9,7 @@ current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__
class Settings(BaseSettings):
LLM_NAME: str = "docsgpt"
- MODEL_NAME: Optional[str] = None # when LLM_NAME is openai, MODEL_NAME can be e.g. gpt-4-turbo-preview or gpt-3.5-turbo
+ MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
diff --git a/application/parser/token_func.py b/application/parser/token_func.py
index 36ae7e5..7511cde 100644
--- a/application/parser/token_func.py
+++ b/application/parser/token_func.py
@@ -22,7 +22,10 @@ def group_documents(documents: List[Document], min_tokens: int, max_tokens: int)
doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
# Check if current group is empty or if the document can be added based on token count and matching metadata
- if current_group is None or (len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and doc_len < min_tokens and current_group.extra_info == doc.extra_info):
+ if (current_group is None or
+ (len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and
+ doc_len < min_tokens and
+ current_group.extra_info == doc.extra_info)):
if current_group is None:
current_group = doc # Use the document directly to retain its metadata
else:
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index a5bf8e3..b5f1eb9 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -1,5 +1,4 @@
import os
-import json
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
@@ -39,9 +38,19 @@ class ClassicRAG(BaseRetriever):
if self.chunks == 0:
docs = []
else:
- docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY)
+ docsearch = VectorCreator.create_vectorstore(
+ settings.VECTOR_STORE,
+ self.vectorstore,
+ settings.EMBEDDINGS_KEY
+ )
docs_temp = docsearch.search(self.question, k=self.chunks)
- docs = [{"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, "text": i.page_content} for i in docs_temp]
+ docs = [
+ {
+ "title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content,
+ "text": i.page_content
+ }
+ for i in docs_temp
+ ]
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]