Add Brave Search retriever and update application files

pull/915/head
Alex 1 month ago
parent 19494685ba
commit e03e185d30

@ -251,6 +251,20 @@ def combined_json():
"location": "custom",
}
)
if 'brave_search' in settings.RETRIEVERS_ENABLED:
data.append(
{
"name": "Brave Search",
"language": "en",
"version": "",
"description": "brave_search",
"fullName": "Brave Search",
"date": "brave_search",
"docLink": "brave_search",
"model": settings.EMBEDDINGS_NAME,
"location": "custom",
}
)
return jsonify(data)

@ -60,6 +60,8 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
BRAVE_SEARCH_API_KEY: Optional[str] = None
FLASK_DEBUG_MODE: bool = False

@ -0,0 +1,75 @@
import json
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens
from langchain_community.tools import BraveSearch
class BraveRetSearch(BaseRetriever):
    """Retriever that uses Brave Search results as context for the LLM.

    Snippets returned by the ``BraveSearch`` tool are substituted into the
    ``{summaries}`` placeholder of the prompt, and the answer is streamed
    from the LLM configured in ``settings``.
    """

    def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
        """Store the request parameters.

        Args:
            question: The user question; also used as the search query.
            source: Stored on the instance; not read by this class.
            chat_history: Sequence of dicts that may carry ``"prompt"`` and
                ``"response"`` keys from earlier turns.
            prompt: Prompt template containing a ``{summaries}`` placeholder.
            chunks: Number of search results to request; ``0`` disables
                the search entirely.
            gpt_model: Model name forwarded to the LLM's ``gen_stream``.
        """
        self.question = question
        self.source = source
        self.chat_history = chat_history
        self.prompt = prompt
        self.chunks = chunks
        self.gpt_model = gpt_model

    def _get_data(self):
        """Run the Brave search and return the hits as a list of dicts.

        Returns:
            A list of ``{"text", "title", "link"}`` dicts, one per usable
            search hit. Empty when ``chunks`` is 0 or nothing usable came
            back.
        """
        docs = []
        if self.chunks != 0:
            search = BraveSearch.from_api_key(
                api_key=settings.BRAVE_SEARCH_API_KEY,
                search_kwargs={"count": int(self.chunks)},
            )
            results = json.loads(search.run(self.question))
            for item in results:
                try:
                    doc = {
                        "text": item["snippet"],
                        "title": item["title"],
                        "link": item["link"],
                    }
                except (KeyError, IndexError, TypeError):
                    # A missing dict key raises KeyError (not IndexError),
                    # and a non-dict item raises TypeError; skip such hits
                    # instead of aborting the whole retrieval.
                    continue
                docs.append(doc)

        if settings.LLM_NAME == "llama.cpp":
            # Keep at most one document for llama.cpp. Slicing (rather than
            # docs[0]) stays safe when there are no documents at all.
            docs = docs[:1]

        return docs

    def gen(self):
        """Yield the source documents, then the streamed LLM answer.

        Yields:
            ``{"source": doc}`` for every retrieved document, followed by
            ``{"answer": str(chunk)}`` for every chunk of the LLM stream.
        """
        docs = self._get_data()

        # Join the text of every document with newlines and inject it into
        # the prompt template.
        docs_together = "\n".join(doc["text"] for doc in docs)
        p_chat_combine = self.prompt.replace("{summaries}", docs_together)
        messages_combine = [{"role": "system", "content": p_chat_combine}]

        for doc in docs:
            yield {"source": doc}

        if len(self.chat_history) > 1:
            tokens_current_history = 0
            # Walk the history from the last entry backwards so the latest
            # turns get first claim on the token budget. reversed() is used
            # instead of list.reverse() so the caller's list is not mutated.
            # NOTE(review): accepted turns are appended in this reversed
            # order, as before — confirm that ordering is intended.
            for turn in reversed(self.chat_history):
                if "prompt" in turn and "response" in turn:
                    tokens_batch = count_tokens(turn["prompt"]) + count_tokens(turn["response"])
                    if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                        tokens_current_history += tokens_batch
                        messages_combine.append({"role": "user", "content": turn["prompt"]})
                        messages_combine.append({"role": "system", "content": turn["response"]})
        messages_combine.append({"role": "user", "content": self.question})

        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
        completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
        for line in completion:
            yield {"answer": str(line)}

    def search(self):
        """Return the retrieved documents without calling the LLM."""
        return self._get_data()

@ -1,5 +1,3 @@
import json
import ast
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
@ -33,7 +31,6 @@ class DuckDuckSearch(BaseRetriever):
elif inside_brackets:
current_item += char
# Check if there is an unmatched opening bracket at the end
if inside_brackets:
result.append(current_item)

@ -1,12 +1,14 @@
from application.retriever.classic_rag import ClassicRAG
from application.retriever.duckduck_search import DuckDuckSearch
from application.retriever.brave_search import BraveRetSearch
class RetrieverCreator:
retievers = {
'classic': ClassicRAG,
'duckduck_search': DuckDuckSearch
'duckduck_search': DuckDuckSearch,
'brave_search': BraveRetSearch
}
@classmethod

Loading…
Cancel
Save