class BraveRetSearch(BaseRetriever):
    """Retriever that grounds an LLM answer in Brave web-search results.

    The question is sent to Brave Search, the returned snippets are
    substituted into the system prompt, and the answer is streamed from
    the LLM configured in ``settings``.
    """

    def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
        """Store the query context for later retrieval and generation.

        Args:
            question: User question; used as the search query and as the
                final user message sent to the LLM.
            source: Stored on the instance; not read by this class.
            chat_history: Sequence of dicts. Entries holding both a
                "prompt" and a "response" key are replayed to the LLM.
            prompt: System prompt template; the literal "{summaries}" is
                replaced with the joined search snippets.
            chunks: Number of search results to request. A value of 0
                (or less) disables the search entirely.
            gpt_model: Model name forwarded to the LLM's ``gen_stream``.
        """
        self.question = question
        self.source = source
        self.chat_history = chat_history
        self.prompt = prompt
        self.chunks = chunks
        self.gpt_model = gpt_model

    def _get_data(self):
        """Run the web search and normalise the results.

        Returns:
            A list of ``{"text", "title", "link"}`` dicts, one per usable
            search result. Empty when searching is disabled or nothing
            usable came back.
        """
        # Coerce before comparing so that a string such as "0" also
        # disables the search instead of requesting zero results.
        count = int(self.chunks)
        if count <= 0:
            return []

        search = BraveSearch.from_api_key(
            api_key=settings.BRAVE_SEARCH_API_KEY,
            search_kwargs={"count": count},
        )
        results = json.loads(search.run(self.question))

        docs = []
        for result in results:
            try:
                doc = {
                    "text": result["snippet"],
                    "title": result["title"],
                    "link": result["link"],
                }
            except (KeyError, TypeError):
                # Best effort: skip entries that are missing a field or are
                # not mappings. A missing dict key raises KeyError, which the
                # previous ``except IndexError`` never caught.
                continue
            docs.append(doc)

        if settings.LLM_NAME == "llama.cpp":
            # Keep at most one document for llama.cpp. Slicing (rather than
            # indexing ``docs[0]``) stays safe when there are no results.
            docs = docs[:1]

        return docs

    def gen(self):
        """Stream the retrieved sources followed by the LLM answer.

        Yields:
            ``{"source": doc}`` for every retrieved document, then
            ``{"answer": str(chunk)}`` for every chunk streamed by the LLM.
        """
        docs = self._get_data()

        # Join all snippets with newlines and inject them into the prompt.
        docs_together = "\n".join(doc["text"] for doc in docs)
        p_chat_combine = self.prompt.replace("{summaries}", docs_together)
        messages_combine = [{"role": "system", "content": p_chat_combine}]
        for doc in docs:
            yield {"source": doc}

        # History is only replayed when it holds more than one entry.
        if len(self.chat_history) > 1:
            tokens_current_history = 0
            # Walk newest-first so recent turns win the token budget.
            # ``reversed`` leaves the caller's list untouched; the previous
            # in-place ``reverse()`` flipped it on every call.
            for entry in reversed(self.chat_history):
                if "prompt" in entry and "response" in entry:
                    tokens_batch = count_tokens(entry["prompt"]) + count_tokens(entry["response"])
                    if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                        tokens_current_history += tokens_batch
                        messages_combine.append({"role": "user", "content": entry["prompt"]})
                        # NOTE(review): past responses are sent with the
                        # "system" role; kept as-is - confirm whether
                        # "assistant" was intended.
                        messages_combine.append({"role": "system", "content": entry["response"]})
        messages_combine.append({"role": "user", "content": self.question})

        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)

        completion = llm.gen_stream(model=self.gpt_model,
                                    messages=messages_combine)
        for line in completion:
            yield {"answer": str(line)}

    def search(self):
        """Return the normalised search results without calling the LLM."""
        return self._get_data()
unmatched opening bracket at the end if inside_brackets: result.append(current_item) diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py index 5ec341a..ad07140 100644 --- a/application/retriever/retriever_creator.py +++ b/application/retriever/retriever_creator.py @@ -1,12 +1,14 @@ from application.retriever.classic_rag import ClassicRAG from application.retriever.duckduck_search import DuckDuckSearch +from application.retriever.brave_search import BraveRetSearch class RetrieverCreator: retievers = { 'classic': ClassicRAG, - 'duckduck_search': DuckDuckSearch + 'duckduck_search': DuckDuckSearch, + 'brave_search': BraveRetSearch } @classmethod