From 391f6861733d78e589db221f967408f9fa95893a Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 9 Apr 2024 14:02:33 +0100 Subject: [PATCH 1/7] Update application files and fix LLM models, create new retriever class --- application/api/answer/routes.py | 245 +++++++-------------- application/llm/anthropic.py | 4 +- application/llm/docsgpt_provider.py | 4 +- application/llm/huggingface.py | 4 +- application/llm/llama_cpp.py | 4 +- application/llm/openai.py | 4 +- application/llm/premai.py | 4 +- application/llm/sagemaker.py | 4 +- application/retriever/__init__.py | 0 application/retriever/base.py | 10 + application/retriever/classic_rag.py | 83 +++++++ application/retriever/retriever_creator.py | 15 ++ application/utils.py | 6 + 13 files changed, 202 insertions(+), 185 deletions(-) create mode 100644 application/retriever/__init__.py create mode 100644 application/retriever/base.py create mode 100644 application/retriever/classic_rag.py create mode 100644 application/retriever/retriever_creator.py create mode 100644 application/utils.py diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index b122eac..1b3c9b9 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -8,13 +8,14 @@ import traceback from pymongo import MongoClient from bson.objectid import ObjectId -from transformers import GPT2TokenizerFast +from application.utils import count_tokens from application.core.settings import settings from application.vectorstore.vector_creator import VectorCreator from application.llm.llm_creator import LLMCreator +from application.retriever.retriever_creator import RetrieverCreator from application.error import bad_request @@ -62,9 +63,6 @@ async def async_generate(chain, question, chat_history): return result -def count_tokens(string): - tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') - return len(tokenizer(string)['input_ids']) def run_async_chain(chain, question, chat_history): @@ -104,61 +102,11 @@ def get_vectorstore(data): def is_azure_configured(): return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME - -def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id, chunks=2): - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - if prompt_id == 'default': - prompt = chat_combine_template - elif prompt_id == 'creative': - prompt = chat_combine_creative - elif prompt_id == 'strict': - prompt = chat_combine_strict - else: - prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] - - if chunks == 0: - docs = [] - else: - docs = docsearch.search(question, k=chunks) - if settings.LLM_NAME == "llama.cpp": - docs = [docs[0]] - # join all page_content together with a newline - docs_together = "\n".join([doc.page_content for doc in docs]) - p_chat_combine = prompt.replace("{summaries}", docs_together) - messages_combine = [{"role": "system", "content": p_chat_combine}] - source_log_docs = [] - for doc in docs: - if doc.metadata: - source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}) - else: - source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - - if len(chat_history) > 1: - tokens_current_history = 0 - # count tokens in history - chat_history.reverse() - for i in chat_history: - if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: - tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) - messages_combine.append({"role": "user", "content": question}) - - response_full = "" - completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME, - messages=messages_combine) - for line in completion: - data = json.dumps({"answer": str(line)}) - response_full += str(line) - yield f"data: {data}\n\n" - - # save conversation to database +def save_conversation(conversation_id, question, response, source_log_docs, llm): if conversation_id is not None: conversations_collection.update_one( {"_id": ObjectId(conversation_id)}, - {"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}}, + {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}}, ) else: @@ -168,20 +116,50 @@ def complete_stream(question, docsearch, chat_history, prompt_id, conversation_i "words, respond ONLY with the summary, use the same " "language as the system \n\nUser: " + question + "\n\n" + "AI: " + - response_full}, + response}, {"role": "user", "content": "Summarise following conversation in no more than 3 words, " "respond ONLY with the summary, use the same language as the " "system"}] - completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME, + completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30) conversation_id = conversations_collection.insert_one( {"user": "local", "date": datetime.datetime.utcnow(), "name": completion, - "queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]} + "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]} ).inserted_id +def get_prompt(prompt_id): + if prompt_id == 'default': + prompt = chat_combine_template + elif prompt_id == 'creative': + prompt = chat_combine_creative + elif prompt_id == 'strict': + prompt = chat_combine_strict + else: + prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] + return prompt + + +def complete_stream(question, retriever, conversation_id): + + + response_full = "" + source_log_docs = [] + answer = retriever.gen() + for line in answer: + if "answer" in line: + response_full += str(line["answer"]) + data = json.dumps(line) + yield f"data: {data}\n\n" + elif "source" in line: + source_log_docs.append(line["source"]) + + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm) + # send data.type = "end" to indicate that the stream has ended as json data = json.dumps({"type": "id", "id": str(conversation_id)}) yield f"data: {data}\n\n" @@ -213,25 +191,26 @@ def stream(): chunks = int(data["chunks"]) else: chunks = 2 + + prompt = get_prompt(prompt_id) # check if active_docs is set if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) - vectorstore = get_vectorstore({"active_docs": data_key["source"]}) + source = {"active_docs": data_key["source"]} elif "active_docs" in data: - vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) + source = {"active_docs": data["active_docs"]} else: - vectorstore = "" - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) + source = {} + + retriever = RetrieverCreator.create_retriever("classic", question=question, + source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model + ) return Response( - complete_stream(question, docsearch, - chat_history=history, - prompt_id=prompt_id, - conversation_id=conversation_id, - chunks=chunks), mimetype="text/event-stream" - ) + complete_stream(question=question, retriever=retriever, + conversation_id=conversation_id), mimetype="text/event-stream") @answer.route("/api/answer", methods=["POST"]) @@ -255,110 +234,35 @@ def api_answer(): chunks = int(data["chunks"]) else: chunks = 2 - - if prompt_id == 'default': - prompt = chat_combine_template - elif prompt_id == 'creative': - prompt = chat_combine_creative - elif prompt_id == 'strict': - prompt = chat_combine_strict - else: - prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] + + prompt = get_prompt(prompt_id) # use try and except to check for exception try: # check if the vectorstore is set if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) - vectorstore = get_vectorstore({"active_docs": data_key["source"]}) + source = {"active_docs": data_key["source"]} else: - vectorstore = get_vectorstore(data) - # loading the index and the store and the prompt template - # Note if you have used other embeddings than OpenAI, you need to change the embeddings - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) - + source = {data} + retriever = RetrieverCreator.create_retriever("classic", question=question, + source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model + ) + source_log_docs = [] + response_full = "" + for line in retriever.gen(): + if "source" in line: + source_log_docs.append(line["source"]) + elif "answer" in line: + response_full += line["answer"] + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + result = {"answer": response_full, "sources": source_log_docs} + result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm) - - if chunks == 0: - docs = [] - else: - docs = docsearch.search(question, k=chunks) - # join all page_content together with a newline - docs_together = "\n".join([doc.page_content for doc in docs]) - p_chat_combine = prompt.replace("{summaries}", docs_together) - messages_combine = [{"role": "system", "content": p_chat_combine}] - source_log_docs = [] - for doc in docs: - if doc.metadata: - source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}) - else: - source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - # join all page_content together with a newline - - - if len(history) > 1: - tokens_current_history = 0 - # count tokens in history - history.reverse() - for i in history: - if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: - tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) - messages_combine.append({"role": "user", "content": question}) - - - completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME, - messages=messages_combine) - - - result = {"answer": completion, "sources": source_log_docs} - logger.debug(result) - - # generate conversationId - if conversation_id is not None: - conversations_collection.update_one( - {"_id": ObjectId(conversation_id)}, - {"$push": {"queries": {"prompt": question, - "response": result["answer"], "sources": result['sources']}}}, - ) - - else: - # create new conversation - # generate summary - messages_summary = [ - {"role": "assistant", "content": "Summarise following conversation in no more than 3 words, " - "respond ONLY with the summary, use the same language as the system \n\n" - "User: " + question + "\n\n" + "AI: " + result["answer"]}, - {"role": "user", "content": "Summarise following conversation in no more than 3 words, " - "respond ONLY with the summary, use the same language as the system"} - ] - - completion = llm.gen( - model=gpt_model, - engine=settings.AZURE_DEPLOYMENT_NAME, - messages=messages_summary, - max_tokens=30 - ) - conversation_id = conversations_collection.insert_one( - {"user": "local", - "date": datetime.datetime.utcnow(), - "name": completion, - "queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]} - ).inserted_id - - result["conversation_id"] = str(conversation_id) - - # mock result - # result = { - # "answer": "The answer is 42", - # "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"] - # } return result except Exception as e: # print whole traceback @@ -375,20 +279,20 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) - vectorstore = data_key["source"] + source = {"active_docs": data_key["source"]} elif "active_docs" in data: - vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) + source = {"active_docs": data["active_docs"]} else: - vectorstore = "" + source = {} if 'chunks' in data: chunks = int(data["chunks"]) else: chunks = 2 - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) - if chunks == 0: - docs = [] - else: - docs = docsearch.search(question, k=chunks) + + retriever = RetrieverCreator.create_retriever("classic", question=question, + source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model + ) + docs = retriever.search() source_log_docs = [] for doc in docs: @@ -396,6 +300,5 @@ def api_search(): source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}) else: source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - #yield f"data:{data}\n\n" return source_log_docs diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index a64d71e..6b0d646 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -10,7 +10,7 @@ class AnthropicLLM(BaseLLM): self.HUMAN_PROMPT = HUMAN_PROMPT self.AI_PROMPT = AI_PROMPT - def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs): + def gen(self, model, messages, max_tokens=300, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Context \n {context} \n ### Question \n {user_question}" @@ -25,7 +25,7 @@ class AnthropicLLM(BaseLLM): ) return completion.completion - def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs): + def gen_stream(self, model, messages, max_tokens=300, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Context \n {context} \n ### Question \n {user_question}" diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py index e0c5dba..d540a91 100644 --- a/application/llm/docsgpt_provider.py +++ b/application/llm/docsgpt_provider.py @@ -8,7 +8,7 @@ class DocsGPTAPILLM(BaseLLM): self.endpoint = "https://llm.docsgpt.co.uk" - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -24,7 +24,7 @@ class DocsGPTAPILLM(BaseLLM): return response_clean - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py index ef3b1fb..554bee2 100644 --- a/application/llm/huggingface.py +++ b/application/llm/huggingface.py @@ -29,7 +29,7 @@ class HuggingFaceLLM(BaseLLM): ) hf = HuggingFacePipeline(pipeline=pipe) - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -38,7 +38,7 @@ class HuggingFaceLLM(BaseLLM): return result.content - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.") diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index f18d437..be34d4f 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -12,7 +12,7 @@ class LlamaCpp(BaseLLM): llama = Llama(model_path=llm_name, n_ctx=2048) - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -24,7 +24,7 @@ class LlamaCpp(BaseLLM): return result['choices'][0]['text'].split('### Answer \n')[-1] - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/llm/openai.py b/application/llm/openai.py index a132399..4b0ed25 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -18,7 +18,7 @@ class OpenAILLM(BaseLLM): return openai - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): response = self.client.chat.completions.create(model=model, messages=messages, stream=stream, @@ -26,7 +26,7 @@ class OpenAILLM(BaseLLM): return response.choices[0].message.content - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): response = self.client.chat.completions.create(model=model, messages=messages, stream=stream, diff --git a/application/llm/premai.py b/application/llm/premai.py index 4bc8a89..5faa5fe 100644 --- a/application/llm/premai.py +++ b/application/llm/premai.py @@ -12,7 +12,7 @@ class PremAILLM(BaseLLM): self.api_key = api_key self.project_id = settings.PREMAI_PROJECT_ID - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): response = self.client.chat.completions.create(model=model, project_id=self.project_id, messages=messages, @@ -21,7 +21,7 @@ class PremAILLM(BaseLLM): return response.choices[0].message["content"] - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): response = self.client.chat.completions.create(model=model, project_id=self.project_id, messages=messages, diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 84ae09a..b81f638 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -74,7 +74,7 @@ class SagemakerAPILLM(BaseLLM): self.runtime = runtime - def gen(self, model, engine, messages, stream=False, **kwargs): + def gen(self, model, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -103,7 +103,7 @@ class SagemakerAPILLM(BaseLLM): print(result[0]['generated_text'], file=sys.stderr) return result[0]['generated_text'][len(prompt):] - def gen_stream(self, model, engine, messages, stream=True, **kwargs): + def gen_stream(self, model, messages, stream=True, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/retriever/__init__.py b/application/retriever/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/application/retriever/base.py b/application/retriever/base.py new file mode 100644 index 0000000..3bfaa5e --- /dev/null +++ b/application/retriever/base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + + +class BaseRetriever(ABC): + def __init__(self): + pass + + @abstractmethod + def gen(self, *args, **kwargs): + pass diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py new file mode 100644 index 0000000..dc75794 --- /dev/null +++ b/application/retriever/classic_rag.py @@ -0,0 +1,83 @@ +import os +import json +from application.retriever.base import BaseRetriever +from application.core.settings import settings +from application.vectorstore.vector_creator import VectorCreator +from application.llm.llm_creator import LLMCreator + +from application.utils import count_tokens + + + +class ClassicRAG(BaseRetriever): + + def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + self.question = question + self.vectorstore = self._get_vectorstore(source=source) + self.chat_history = chat_history + self.prompt = prompt + self.chunks = chunks + self.gpt_model = gpt_model + + def _get_vectorstore(self, source): + if "active_docs" in source: + if source["active_docs"].split("/")[0] == "default": + vectorstore = "" + elif source["active_docs"].split("/")[0] == "local": + vectorstore = "indexes/" + source["active_docs"] + else: + vectorstore = "vectors/" + source["active_docs"] + if source["active_docs"] == "default": + vectorstore = "" + else: + vectorstore = "" + vectorstore = os.path.join("application", vectorstore) + return vectorstore + + + def _get_data(self): + if self.chunks == 0: + docs = [] + else: + docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY) + docs = docsearch.search(self.question, k=self.chunks) + if settings.LLM_NAME == "llama.cpp": + docs = [docs[0]] + return docs + + def gen(self): + docs = self._get_data() + + # join all page_content together with a newline + docs_together = "\n".join([doc.page_content for doc in docs]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + if doc.metadata: + yield {"source": {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}} + else: + yield {"source": {"title": doc.page_content, "text": doc.page_content}} + + if len(self.chat_history) > 1: + tokens_current_history = 0 + # count tokens in history + self.chat_history.reverse() + for i in self.chat_history: + if "prompt" in i and "response" in i: + tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) + if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_current_history += tokens_batch + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append({"role": "user", "content": self.question}) + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + completion = llm.gen_stream(model=self.gpt_model, + messages=messages_combine) + for line in completion: + yield {"answer": str(line)} + + def search(self): + return self._get_data() + diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py new file mode 100644 index 0000000..9255f4e --- /dev/null +++ b/application/retriever/retriever_creator.py @@ -0,0 +1,15 @@ +from application.retriever.classic_rag import ClassicRAG + + + +class RetrieverCreator: + retievers = { + 'classic': ClassicRAG, + } + + @classmethod + def create_retriever(cls, type, *args, **kwargs): + retiever_class = cls.retievers.get(type.lower()) + if not retiever_class: + raise ValueError(f"No retievers class found for type {type}") + return retiever_class(*args, **kwargs) \ No newline at end of file diff --git a/application/utils.py b/application/utils.py new file mode 100644 index 0000000..ac98efc --- /dev/null +++ b/application/utils.py @@ -0,0 +1,6 @@ +from transformers import GPT2TokenizerFast + + +def count_tokens(string): + tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') + return len(tokenizer(string)['input_ids']) \ No newline at end of file From 1e26943c3e2cf1ccb4526a9686d79968e63ba6bc Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 9 Apr 2024 15:45:24 +0100 Subject: [PATCH 2/7] Update application files, fix LLM models, and create new retriever class --- application/api/answer/routes.py | 12 +-- application/requirements.txt | 1 + application/retriever/base.py | 4 + application/retriever/classic_rag.py | 11 ++- application/retriever/duckduck_search.py | 97 ++++++++++++++++++++++ application/retriever/retriever_creator.py | 2 + 6 files changed, 112 insertions(+), 15 deletions(-) create mode 100644 application/retriever/duckduck_search.py diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 1b3c9b9..11e6c4a 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -103,7 +103,7 @@ def is_azure_configured(): return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME def save_conversation(conversation_id, question, response, source_log_docs, llm): - if conversation_id is not None: + if conversation_id is not None and conversation_id != "None": conversations_collection.update_one( {"_id": ObjectId(conversation_id)}, {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}}, @@ -129,6 +129,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm) "name": completion, "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]} ).inserted_id + return conversation_id def get_prompt(prompt_id): if prompt_id == 'default': @@ -293,12 +294,5 @@ def api_search(): source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model ) docs = retriever.search() - - source_log_docs = [] - for doc in docs: - if doc.metadata: - source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}) - else: - source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - return source_log_docs + return docs diff --git a/application/requirements.txt b/application/requirements.txt index 0874a7c..c984534 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -3,6 +3,7 @@ boto3==1.34.6 celery==5.3.6 dataclasses_json==0.6.3 docx2txt==0.8 +duckduckgo-search=5.3.0 EbookLib==0.18 elasticsearch==8.12.0 escodegen==1.0.11 diff --git a/application/retriever/base.py b/application/retriever/base.py index 3bfaa5e..4a37e81 100644 --- a/application/retriever/base.py +++ b/application/retriever/base.py @@ -8,3 +8,7 @@ class BaseRetriever(ABC): @abstractmethod def gen(self, *args, **kwargs): pass + + @abstractmethod + def search(self, *args, **kwargs): + pass diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index dc75794..a5bf8e3 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -40,23 +40,22 @@ class ClassicRAG(BaseRetriever): docs = [] else: docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY) - docs = docsearch.search(self.question, k=self.chunks) + docs_temp = docsearch.search(self.question, k=self.chunks) + docs = [{"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, "text": i.page_content} for i in docs_temp] if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] + return docs def gen(self): docs = self._get_data() # join all page_content together with a newline - docs_together = "\n".join([doc.page_content for doc in docs]) + docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) messages_combine = [{"role": "system", "content": p_chat_combine}] for doc in docs: - if doc.metadata: - yield {"source": {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}} - else: - yield {"source": {"title": doc.page_content, "text": doc.page_content}} + yield {"source": doc} if len(self.chat_history) > 1: tokens_current_history = 0 diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py new file mode 100644 index 0000000..778313c --- /dev/null +++ b/application/retriever/duckduck_search.py @@ -0,0 +1,97 @@ +import json +import ast +from application.retriever.base import BaseRetriever +from application.core.settings import settings +from application.llm.llm_creator import LLMCreator +from application.utils import count_tokens +from langchain_community.tools import DuckDuckGoSearchResults +from langchain_community.utilities import DuckDuckGoSearchAPIWrapper + + + +class DuckDuckSearch(BaseRetriever): + + def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + self.question = question + self.source = source + self.chat_history = chat_history + self.prompt = prompt + self.chunks = chunks + self.gpt_model = gpt_model + + def _parse_lang_string(self, input_string): + result = [] + current_item = "" + inside_brackets = False + for char in input_string: + if char == "[": + inside_brackets = True + elif char == "]": + inside_brackets = False + result.append(current_item) + current_item = "" + elif inside_brackets: + current_item += char + + # Check if there is an unmatched opening bracket at the end + if inside_brackets: + result.append(current_item) + + return result + + def _get_data(self): + if self.chunks == 0: + docs = [] + else: + wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks) + search = DuckDuckGoSearchResults(api_wrapper=wrapper) + results = search.run(self.question) + results = self._parse_lang_string(results) + + docs = [] + for i in results: + try: + text = i.split("title:")[0] + title = i.split("title:")[1].split("link:")[0] + link = i.split("link:")[1] + docs.append({"text": text, "title": title, "link": link}) + except IndexError: + pass + if settings.LLM_NAME == "llama.cpp": + docs = [docs[0]] + + return docs + + def gen(self): + docs = self._get_data() + + # join all page_content together with a newline + docs_together = "\n".join([doc["text"] for doc in docs]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + yield {"source": doc} + + if len(self.chat_history) > 1: + tokens_current_history = 0 + # count tokens in history + self.chat_history.reverse() + for i in self.chat_history: + if "prompt" in i and "response" in i: + tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) + if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_current_history += tokens_batch + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append({"role": "user", "content": self.question}) + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + completion = llm.gen_stream(model=self.gpt_model, + messages=messages_combine) + for line in completion: + yield {"answer": str(line)} + + def search(self): + return self._get_data() + diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py index 9255f4e..892d63a 100644 --- a/application/retriever/retriever_creator.py +++ b/application/retriever/retriever_creator.py @@ -1,10 +1,12 @@ from application.retriever.classic_rag import ClassicRAG +from application.retriever.duckduck_search import DuckDuckSearch class RetrieverCreator: retievers = { 'classic': ClassicRAG, + 'duckduck': DuckDuckSearch } @classmethod From 19494685ba72a338ce71b3a4fa068f669bd12574 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 9 Apr 2024 16:38:42 +0100 Subject: [PATCH 3/7] Update application files, fix LLM models, and create new retriever class --- application/api/answer/routes.py | 25 ++++- application/api/user/routes.py | 14 +++ application/core/settings.py | 1 + application/requirements.txt | 2 +- application/retriever/retriever_creator.py | 2 +- frontend/src/conversation/conversationApi.ts | 97 ++++++-------------- 6 files changed, 67 insertions(+), 74 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 11e6c4a..97eb36c 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -205,7 +205,12 @@ def stream(): else: source = {} - retriever = RetrieverCreator.create_retriever("classic", question=question, + if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + retriever_name = "classic" + else: + retriever_name = source['active_docs'] + + retriever = RetrieverCreator.create_retriever(retriever_name, question=question, source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model ) @@ -247,7 +252,12 @@ def api_answer(): else: source = {data} - retriever = RetrieverCreator.create_retriever("classic", question=question, + if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + retriever_name = "classic" + else: + retriever_name = source['active_docs'] + + retriever = RetrieverCreator.create_retriever(retriever_name, question=question, source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model ) source_log_docs = [] @@ -290,9 +300,14 @@ def api_search(): else: chunks = 2 - retriever = RetrieverCreator.create_retriever("classic", question=question, - source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model - ) + if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + retriever_name = "classic" + else: + retriever_name = source['active_docs'] + + retriever = RetrieverCreator.create_retriever(retriever_name, question=question, + source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model + ) docs = retriever.search() return docs diff --git a/application/api/user/routes.py b/application/api/user/routes.py index e80ec52..b159d6c 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -237,6 +237,20 @@ def combined_json(): for index in data_remote: index["location"] = "remote" data.append(index) + if 'duckduck_search' in settings.RETRIEVERS_ENABLED: + data.append( + { + "name": "DuckDuckGo Search", + "language": "en", + "version": "", + "description": "duckduck_search", + "fullName": "DuckDuckGo Search", + "date": "duckduck_search", + "docLink": "duckduck_search", + "model": settings.EMBEDDINGS_NAME, + "location": "custom", + } + ) return jsonify(data) diff --git a/application/core/settings.py b/application/core/settings.py index 7eac3cb..b069340 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -18,6 +18,7 @@ class Settings(BaseSettings): TOKENS_MAX_HISTORY: int = 150 UPLOAD_FOLDER: str = "inputs" VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" + RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search API_URL: str = "http://localhost:7091" # backend url for celery worker diff --git a/application/requirements.txt b/application/requirements.txt index c984534..4642525 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -3,7 +3,7 @@ boto3==1.34.6 celery==5.3.6 dataclasses_json==0.6.3 docx2txt==0.8 -duckduckgo-search=5.3.0 +duckduckgo-search==5.3.0 EbookLib==0.18 elasticsearch==8.12.0 escodegen==1.0.11 diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py index 892d63a..5ec341a 100644 --- a/application/retriever/retriever_creator.py +++ b/application/retriever/retriever_creator.py @@ -6,7 +6,7 @@ from application.retriever.duckduck_search import DuckDuckSearch class RetrieverCreator: retievers = { 'classic': ClassicRAG, - 'duckduck': DuckDuckSearch + 'duckduck_search': DuckDuckSearch } @classmethod diff --git a/frontend/src/conversation/conversationApi.ts b/frontend/src/conversation/conversationApi.ts index e3a8219..0e49572 100644 --- a/frontend/src/conversation/conversationApi.ts +++ b/frontend/src/conversation/conversationApi.ts @@ -3,6 +3,33 @@ import { Doc } from '../preferences/preferenceApi'; const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com'; +function getDocPath(selectedDocs: Doc | null): string { + let docPath = 'default'; + + if (selectedDocs) { + let namePath = selectedDocs.name; + if (selectedDocs.language === namePath) { + namePath = '.project'; + } + if (selectedDocs.location === 'local') { + docPath = 'local' + '/' + selectedDocs.name + '/'; + } else if (selectedDocs.location === 'remote') { + docPath = + selectedDocs.language + + '/' + + namePath + + '/' + + selectedDocs.version + + '/' + + selectedDocs.model + + '/'; + } else if (selectedDocs.location === 'custom') { + docPath = selectedDocs.docLink; + } + } + + return docPath; +} export function fetchAnswerApi( question: string, signal: AbortSignal, @@ -28,27 +55,7 @@ export function fetchAnswerApi( title: any; } > { - let docPath = 'default'; - - if (selectedDocs) { - let namePath = selectedDocs.name; - if (selectedDocs.language === namePath) { - namePath = '.project'; - } - if (selectedDocs.location === 'local') { - docPath = 'local' + '/' + selectedDocs.name + '/'; - } else if (selectedDocs.location === 'remote') { - docPath = - selectedDocs.language + - '/' + - namePath + - '/' + - selectedDocs.version + - '/' + - selectedDocs.model + - '/'; - } - } + const docPath = getDocPath(selectedDocs); //in history array remove all keys except prompt and response history = history.map((item) => { return { prompt: item.prompt, response: item.response }; @@ -98,27 +105,7 @@ export function fetchAnswerSteaming( chunks: string, onEvent: (event: MessageEvent) => void, ): Promise { - let docPath = 'default'; - - if (selectedDocs) { - let namePath = selectedDocs.name; - if (selectedDocs.language === namePath) { - namePath = '.project'; - } - if (selectedDocs.location === 'local') { - docPath = 'local' + '/' + selectedDocs.name + '/'; - } else if (selectedDocs.location === 'remote') { - docPath = - selectedDocs.language + - '/' + - namePath + - '/' + - selectedDocs.version + - '/' + - selectedDocs.model + - '/'; - } - } + const docPath = getDocPath(selectedDocs); history = history.map((item) => { return { prompt: item.prompt, response: item.response }; @@ -195,31 +182,7 @@ export function searchEndpoint( history: Array = [], chunks: string, ) { - /* - "active_docs": "default", - "question": "Summarise", - "conversation_id": null, - "history": "[]" */ - let docPath = 'default'; - if (selectedDocs) { - let namePath = selectedDocs.name; - if (selectedDocs.language === namePath) { - namePath = '.project'; - } - if (selectedDocs.location === 'local') { - docPath = 'local' + '/' + selectedDocs.name + '/'; - } else if (selectedDocs.location === 'remote') { - docPath = - selectedDocs.language + - '/' + - namePath + - '/' + - selectedDocs.version + - '/' + - selectedDocs.model + - '/'; - } - } + const docPath = getDocPath(selectedDocs); const body = { question: question, From 7a02df558849c43f9e257a7d5bdca5c74ed5b0fc Mon Sep 17 00:00:00 2001 From: Pavel Date: Tue, 9 Apr 2024 19:56:07 +0400 Subject: [PATCH 4/7] Multiple uploads --- application/api/user/routes.py | 56 +++++++++++++++++++++------------- application/worker.py | 36 +++++++++++++++++++--- frontend/src/upload/Upload.tsx | 2 +- 3 files changed, 67 insertions(+), 27 deletions(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index e80ec52..7e5462b 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1,5 +1,6 @@ import os import uuid +import shutil from flask import Blueprint, request, jsonify from urllib.parse import urlparse import requests @@ -136,30 +137,43 @@ def upload_file(): return {"status": "no name"} job_name = secure_filename(request.form["name"]) # check if the post request has the file part - if "file" not in request.files: - print("No file part") - return {"status": "no file"} - file = request.files["file"] - if file.filename == "": + files = request.files.getlist("file") + + if not files or all(file.filename == '' for file in files): return {"status": "no file name"} - if file: - filename = secure_filename(file.filename) - # save dir - save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name) - # create dir if not exists - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - file.save(os.path.join(save_dir, filename)) - task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt", ".docx", - ".csv", ".epub", ".html", ".mdx"], - job_name, filename, user) - # task id - task_id = task.id - return {"status": "ok", "task_id": task_id} + # Directory where files will be saved + save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name) + os.makedirs(save_dir, exist_ok=True) + + if len(files) > 1: + # Multiple files; prepare them for zip + temp_dir = os.path.join(save_dir, "temp") + os.makedirs(temp_dir, exist_ok=True) + + for file in files: + filename = secure_filename(file.filename) + file.save(os.path.join(temp_dir, filename)) + + # Use shutil.make_archive to zip the temp directory + zip_path = shutil.make_archive(base_name=os.path.join(save_dir, job_name), format='zip', root_dir=temp_dir) + final_filename = os.path.basename(zip_path) + + # Clean up the temporary directory after zipping + shutil.rmtree(temp_dir) else: - return {"status": "error"} + # Single file + file = files[0] + final_filename = secure_filename(file.filename) + file_path = os.path.join(save_dir, final_filename) + file.save(file_path) + + # Call ingest with the single file or zipped file + task = ingest.delay(settings.UPLOAD_FOLDER, [".rst", ".md", ".pdf", ".txt", ".docx", + ".csv", ".epub", ".html", ".mdx"], + job_name, final_filename, user) + + return {"status": "ok", "task_id": task.id} @user.route("/api/remote", methods=["POST"]) def upload_remote(): diff --git a/application/worker.py b/application/worker.py index 3891fde..eb28242 100644 --- a/application/worker.py +++ b/application/worker.py @@ -36,6 +36,32 @@ current_dir = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) +def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5): + """ + Recursively extract zip files with a limit on recursion depth. + + Args: + zip_path (str): Path to the zip file to be extracted. + extract_to (str): Destination path for extracted files. + current_depth (int): Current depth of recursion. + max_depth (int): Maximum allowed depth of recursion to prevent infinite loops. + """ + if current_depth > max_depth: + print(f"Reached maximum recursion depth of {max_depth}") + return + + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(extract_to) + os.remove(zip_path) # Remove the zip file after extracting + + # Check for nested zip files and extract them + for root, dirs, files in os.walk(extract_to): + for file in files: + if file.endswith(".zip"): + # If a nested zip file is found, extract it recursively + file_path = os.path.join(root, file) + extract_zip_recursive(file_path, root, current_depth + 1, max_depth) + # Define the main function for ingesting and processing documents. def ingest_worker(self, directory, formats, name_job, filename, user): @@ -66,9 +92,11 @@ def ingest_worker(self, directory, formats, name_job, filename, user): token_check = True min_tokens = 150 max_tokens = 1250 - full_path = directory + "/" + user + "/" + name_job + recursion_depth = 2 + full_path = os.path.join(directory, user, name_job) import sys + print(full_path, file=sys.stderr) # check if API_URL env variable is set file_data = {"name": name_job, "file": filename, "user": user} @@ -81,14 +109,12 @@ def ingest_worker(self, directory, formats, name_job, filename, user): if not os.path.exists(full_path): os.makedirs(full_path) - with open(full_path + "/" + filename, "wb") as f: + with open(os.path.join(full_path, filename), "wb") as f: f.write(file) # check if file is .zip and extract it if filename.endswith(".zip"): - with zipfile.ZipFile(full_path + "/" + filename, "r") as zip_ref: - zip_ref.extractall(full_path) - os.remove(full_path + "/" + filename) + extract_zip_recursive(os.path.join(full_path, filename), full_path, 0, recursion_depth) self.update_state(state="PROGRESS", meta={"current": 1}) diff --git a/frontend/src/upload/Upload.tsx b/frontend/src/upload/Upload.tsx index 39c2a09..3ae2178 100644 --- a/frontend/src/upload/Upload.tsx +++ b/frontend/src/upload/Upload.tsx @@ -201,7 +201,7 @@ export default function Upload({ const { getRootProps, getInputProps, isDragActive } = useDropzone({ onDrop, - multiple: false, + multiple: true, onDragEnter: doNothing, onDragOver: doNothing, onDragLeave: doNothing, From e03e185d30ded4215ed3045994902ab8c1936fec Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 9 Apr 2024 17:11:09 +0100 Subject: [PATCH 5/7] Add Brave Search retriever and update application files --- application/api/user/routes.py | 14 ++++ application/core/settings.py | 2 + application/retriever/brave_search.py | 75 ++++++++++++++++++++++ application/retriever/duckduck_search.py | 3 - application/retriever/retriever_creator.py | 4 +- 5 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 application/retriever/brave_search.py diff --git a/application/api/user/routes.py b/application/api/user/routes.py index b159d6c..3222832 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -251,6 +251,20 @@ def combined_json(): "location": "custom", } ) + if 'brave_search' in settings.RETRIEVERS_ENABLED: + data.append( + { + "name": "Brave Search", + "language": "en", + "version": "", + "description": "brave_search", + "fullName": "Brave Search", + "date": "brave_search", + "docLink": "brave_search", + "model": settings.EMBEDDINGS_NAME, + "location": "custom", + } + ) return jsonify(data) diff --git a/application/core/settings.py b/application/core/settings.py index b069340..d8d0eb3 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -60,6 +60,8 @@ class Settings(BaseSettings): QDRANT_PATH: Optional[str] = None QDRANT_DISTANCE_FUNC: str = "Cosine" + BRAVE_SEARCH_API_KEY: Optional[str] = None + FLASK_DEBUG_MODE: bool = False diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py new file mode 100644 index 0000000..0cc7bd4 --- /dev/null +++ b/application/retriever/brave_search.py @@ -0,0 +1,75 @@ +import json +from application.retriever.base import BaseRetriever +from application.core.settings import settings +from application.llm.llm_creator import LLMCreator +from application.utils import count_tokens +from langchain_community.tools import BraveSearch + + + +class BraveRetSearch(BaseRetriever): + + def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + self.question = question + self.source = source + self.chat_history = chat_history + self.prompt = prompt + self.chunks = chunks + self.gpt_model = gpt_model + + def _get_data(self): + if self.chunks == 0: + docs = [] + else: + search = BraveSearch.from_api_key(api_key=settings.BRAVE_SEARCH_API_KEY, + search_kwargs={"count": int(self.chunks)}) + results = search.run(self.question) + results = json.loads(results) + + docs = [] + for i in results: + try: + title = i['title'] + link = i['link'] + snippet = i['snippet'] + docs.append({"text": snippet, "title": title, "link": link}) + except IndexError: + pass + if settings.LLM_NAME == "llama.cpp": + docs = [docs[0]] + + return docs + + def gen(self): + docs = self._get_data() + + # join all page_content together with a newline + docs_together = "\n".join([doc["text"] for doc in docs]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + yield {"source": doc} + + if len(self.chat_history) > 1: + tokens_current_history = 0 + # count tokens in history + self.chat_history.reverse() + for i in self.chat_history: + if "prompt" in i and "response" in i: + tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) + if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_current_history += tokens_batch + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append({"role": "user", "content": self.question}) + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + completion = llm.gen_stream(model=self.gpt_model, + messages=messages_combine) + for line in completion: + yield {"answer": str(line)} + + def search(self): + return self._get_data() + diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py index 778313c..d662bb0 100644 --- a/application/retriever/duckduck_search.py +++ b/application/retriever/duckduck_search.py @@ -1,5 +1,3 @@ -import json -import ast from application.retriever.base import BaseRetriever from application.core.settings import settings from application.llm.llm_creator import LLMCreator @@ -33,7 +31,6 @@ class DuckDuckSearch(BaseRetriever): elif inside_brackets: current_item += char - # Check if there is an unmatched opening bracket at the end if inside_brackets: result.append(current_item) diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py index 5ec341a..ad07140 100644 --- a/application/retriever/retriever_creator.py +++ b/application/retriever/retriever_creator.py @@ -1,12 +1,14 @@ from application.retriever.classic_rag import ClassicRAG from application.retriever.duckduck_search import DuckDuckSearch +from application.retriever.brave_search import BraveRetSearch class RetrieverCreator: retievers = { 'classic': ClassicRAG, - 'duckduck_search': DuckDuckSearch + 'duckduck_search': DuckDuckSearch, + 'brave_search': BraveRetSearch } @classmethod From 4b849d720142e081fbace0868abe1ecef9187e68 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 9 Apr 2024 17:20:26 +0100 Subject: [PATCH 6/7] Fix SagemakerAPILLM test --- tests/llm/test_sagemaker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py index f8d02d8..0602f59 100644 --- a/tests/llm/test_sagemaker.py +++ b/tests/llm/test_sagemaker.py @@ -54,7 +54,7 @@ class TestSagemakerAPILLM(unittest.TestCase): def test_gen(self): with patch.object(self.sagemaker.runtime, 'invoke_endpoint', return_value=self.response) as mock_invoke_endpoint: - output = self.sagemaker.gen(None, None, self.messages) + output = self.sagemaker.gen(None, self.messages) mock_invoke_endpoint.assert_called_once_with( EndpointName=self.sagemaker.endpoint, ContentType='application/json', @@ -66,7 +66,7 @@ class TestSagemakerAPILLM(unittest.TestCase): def test_gen_stream(self): with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', return_value=self.response) as mock_invoke_endpoint: - output = list(self.sagemaker.gen_stream(None, None, self.messages)) + output = list(self.sagemaker.gen_stream(None, self.messages)) mock_invoke_endpoint.assert_called_once_with( EndpointName=self.sagemaker.endpoint, ContentType='application/json', From 8d7a134cb40502b0bd8474a1ed603da2ced8ac08 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 9 Apr 2024 17:25:08 +0100 Subject: [PATCH 7/7] lint: ruff --- application/api/answer/routes.py | 2 -- application/api/user/routes.py | 9 +++++---- application/core/settings.py | 2 +- application/parser/token_func.py | 5 ++++- application/retriever/classic_rag.py | 15 ++++++++++++--- 5 files changed, 22 insertions(+), 11 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 97eb36c..fa0ac4f 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -8,12 +8,10 @@ import traceback from pymongo import MongoClient from bson.objectid import ObjectId -from application.utils import count_tokens from application.core.settings import settings -from application.vectorstore.vector_creator import VectorCreator from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator from application.error import bad_request diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 3222832..cacfbd7 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -283,10 +283,12 @@ def check_docs(): else: file_url = urlparse(base_path + vectorstore + "index.faiss") - if file_url.scheme in ['https'] and file_url.netloc == 'raw.githubusercontent.com' and file_url.path.startswith('/arc53/DocsHUB/main/'): - + if ( + file_url.scheme in ['https'] and + file_url.netloc == 'raw.githubusercontent.com' and + file_url.path.startswith('/arc53/DocsHUB/main/') + ): r = requests.get(file_url.geturl()) - if r.status_code != 200: return {"status": "null"} else: @@ -295,7 +297,6 @@ def check_docs(): with open(vectorstore + "index.faiss", "wb") as f: f.write(r.content) - # download the store r = requests.get(base_path + vectorstore + "index.pkl") with open(vectorstore + "index.pkl", "wb") as f: f.write(r.content) diff --git a/application/core/settings.py b/application/core/settings.py index d8d0eb3..26c27ed 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -9,7 +9,7 @@ current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__ class Settings(BaseSettings): LLM_NAME: str = "docsgpt" - MODEL_NAME: Optional[str] = None # when LLM_NAME is openai, MODEL_NAME can be e.g. gpt-4-turbo-preview or gpt-3.5-turbo + MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2" CELERY_BROKER_URL: str = "redis://localhost:6379/0" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" diff --git a/application/parser/token_func.py b/application/parser/token_func.py index 36ae7e5..7511cde 100644 --- a/application/parser/token_func.py +++ b/application/parser/token_func.py @@ -22,7 +22,10 @@ def group_documents(documents: List[Document], min_tokens: int, max_tokens: int) doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text)) # Check if current group is empty or if the document can be added based on token count and matching metadata - if current_group is None or (len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and doc_len < min_tokens and current_group.extra_info == doc.extra_info): + if (current_group is None or + (len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and + doc_len < min_tokens and + current_group.extra_info == doc.extra_info)): if current_group is None: current_group = doc # Use the document directly to retain its metadata else: diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index a5bf8e3..b5f1eb9 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,5 +1,4 @@ import os -import json from application.retriever.base import BaseRetriever from application.core.settings import settings from application.vectorstore.vector_creator import VectorCreator @@ -39,9 +38,19 @@ class ClassicRAG(BaseRetriever): if self.chunks == 0: docs = [] else: - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY) + docsearch = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, + self.vectorstore, + settings.EMBEDDINGS_KEY + ) docs_temp = docsearch.search(self.question, k=self.chunks) - docs = [{"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, "text": i.page_content} for i in docs_temp] + docs = [ + { + "title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, + "text": i.page_content + } + for i in docs_temp + ] if settings.LLM_NAME == "llama.cpp": docs = [docs[0]]