fix: api_key capturing + pytest errors

pull/926/head
Siddhant Rai 2 months ago
parent 60a670ce29
commit 77991896b4

@ -10,14 +10,12 @@ from pymongo import MongoClient
from bson.objectid import ObjectId from bson.objectid import ObjectId
from application.core.settings import settings from application.core.settings import settings
from application.llm.llm_creator import LLMCreator from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator from application.retriever.retriever_creator import RetrieverCreator
from application.error import bad_request from application.error import bad_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
mongo = MongoClient(settings.MONGO_URI) mongo = MongoClient(settings.MONGO_URI)
@ -26,20 +24,22 @@ conversations_collection = db["conversations"]
vectors_collection = db["vectors"] vectors_collection = db["vectors"]
prompts_collection = db["prompts"] prompts_collection = db["prompts"]
api_key_collection = db["api_keys"] api_key_collection = db["api_keys"]
answer = Blueprint('answer', __name__) answer = Blueprint("answer", __name__)
gpt_model = "" gpt_model = ""
# to have some kind of default behaviour # to have some kind of default behaviour
if settings.LLM_NAME == "openai": if settings.LLM_NAME == "openai":
gpt_model = 'gpt-3.5-turbo' gpt_model = "gpt-3.5-turbo"
elif settings.LLM_NAME == "anthropic": elif settings.LLM_NAME == "anthropic":
gpt_model = 'claude-2' gpt_model = "claude-2"
if settings.MODEL_NAME: # in case there is particular model name configured if settings.MODEL_NAME: # in case there is particular model name configured
gpt_model = settings.MODEL_NAME gpt_model = settings.MODEL_NAME
# load the prompts # load the prompts
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read() chat_combine_template = f.read()
@ -50,7 +50,7 @@ with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r"
chat_combine_creative = f.read() chat_combine_creative = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
chat_combine_strict = f.read() chat_combine_strict = f.read()
api_key_set = settings.API_KEY is not None api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None embeddings_key_set = settings.EMBEDDINGS_KEY is not None
@ -61,8 +61,6 @@ async def async_generate(chain, question, chat_history):
return result return result
def run_async_chain(chain, question, chat_history): def run_async_chain(chain, question, chat_history):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
@ -74,17 +72,18 @@ def run_async_chain(chain, question, chat_history):
result["answer"] = answer result["answer"] = answer
return result return result
def get_data_from_api_key(api_key): def get_data_from_api_key(api_key):
data = api_key_collection.find_one({"key": api_key}) data = api_key_collection.find_one({"key": api_key})
if data is None: if data is None:
return bad_request(401, "Invalid API key") return bad_request(401, "Invalid API key")
return data return data
def get_vectorstore(data): def get_vectorstore(data):
if "active_docs" in data: if "active_docs" in data:
if data["active_docs"].split("/")[0] == "default": if data["active_docs"].split("/")[0] == "default":
vectorstore = "" vectorstore = ""
elif data["active_docs"].split("/")[0] == "local": elif data["active_docs"].split("/")[0] == "local":
vectorstore = "indexes/" + data["active_docs"] vectorstore = "indexes/" + data["active_docs"]
else: else:
@ -98,52 +97,82 @@ def get_vectorstore(data):
def is_azure_configured(): def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
def save_conversation(conversation_id, question, response, source_log_docs, llm): def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None and conversation_id != "None": if conversation_id is not None and conversation_id != "None":
conversations_collection.update_one( conversations_collection.update_one(
{"_id": ObjectId(conversation_id)}, {"_id": ObjectId(conversation_id)},
{"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}}, {
"$push": {
"queries": {
"prompt": question,
"response": response,
"sources": source_log_docs,
}
}
},
) )
else: else:
# create new conversation # create new conversation
# generate summary # generate summary
messages_summary = [{"role": "assistant", "content": "Summarise following conversation in no more than 3 " messages_summary = [
"words, respond ONLY with the summary, use the same " {
"language as the system \n\nUser: " + question + "\n\n" + "role": "assistant",
"AI: " + "content": "Summarise following conversation in no more than 3 "
response}, "words, respond ONLY with the summary, use the same "
{"role": "user", "content": "Summarise following conversation in no more than 3 words, " "language as the system \n\nUser: "
"respond ONLY with the summary, use the same language as the " + question
"system"}] + "\n\n"
+ "AI: "
completion = llm.gen(model=gpt_model, + response,
messages=messages_summary, max_tokens=30) },
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system",
},
]
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
conversation_id = conversations_collection.insert_one( conversation_id = conversations_collection.insert_one(
{"user": "local", {
"date": datetime.datetime.utcnow(), "user": "local",
"name": completion, "date": datetime.datetime.utcnow(),
"queries": [{"prompt": question, "response": response, "sources": source_log_docs}]} "name": completion,
"queries": [
{
"prompt": question,
"response": response,
"sources": source_log_docs,
}
],
}
).inserted_id ).inserted_id
return conversation_id return conversation_id
def get_prompt(prompt_id): def get_prompt(prompt_id):
if prompt_id == 'default': if prompt_id == "default":
prompt = chat_combine_template prompt = chat_combine_template
elif prompt_id == 'creative': elif prompt_id == "creative":
prompt = chat_combine_creative prompt = chat_combine_creative
elif prompt_id == 'strict': elif prompt_id == "strict":
prompt = chat_combine_strict prompt = chat_combine_strict
else: else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt return prompt
def complete_stream(question, retriever, conversation_id): def complete_stream(question, retriever, conversation_id, user_api_key):
response_full = "" response_full = ""
source_log_docs = [] source_log_docs = []
answer = retriever.gen() answer = retriever.gen()
@ -155,9 +184,10 @@ def complete_stream(question, retriever, conversation_id):
elif "source" in line: elif "source" in line:
source_log_docs.append(line["source"]) source_log_docs.append(line["source"])
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) conversation_id = save_conversation(
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm) conversation_id, question, response_full, source_log_docs, llm
)
# send data.type = "end" to indicate that the stream has ended as json # send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)}) data = json.dumps({"type": "id", "id": str(conversation_id)})
@ -180,17 +210,17 @@ def stream():
conversation_id = None conversation_id = None
else: else:
conversation_id = data["conversation_id"] conversation_id = data["conversation_id"]
if 'prompt_id' in data: if "prompt_id" in data:
prompt_id = data["prompt_id"] prompt_id = data["prompt_id"]
else: else:
prompt_id = 'default' prompt_id = "default"
if 'selectedDocs' in data and data['selectedDocs'] is None: if "selectedDocs" in data and data["selectedDocs"] is None:
chunks = 0 chunks = 0
elif 'chunks' in data: elif "chunks" in data:
chunks = int(data["chunks"]) chunks = int(data["chunks"])
else: else:
chunks = 2 chunks = 2
prompt = get_prompt(prompt_id) prompt = get_prompt(prompt_id)
# check if active_docs is set # check if active_docs is set
@ -198,23 +228,42 @@ def stream():
if "api_key" in data: if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"]) data_key = get_data_from_api_key(data["api_key"])
source = {"active_docs": data_key["source"]} source = {"active_docs": data_key["source"]}
user_api_key = data["api_key"]
elif "active_docs" in data: elif "active_docs" in data:
source = {"active_docs": data["active_docs"]} source = {"active_docs": data["active_docs"]}
user_api_key = None
else: else:
source = {} source = {}
user_api_key = None
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": if (
source["active_docs"].split("/")[0] == "default"
or source["active_docs"].split("/")[0] == "local"
):
retriever_name = "classic" retriever_name = "classic"
else: else:
retriever_name = source['active_docs'] retriever_name = source["active_docs"]
retriever = RetrieverCreator.create_retriever(retriever_name, question=question, retriever = RetrieverCreator.create_retriever(
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model retriever_name,
) question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
gpt_model=gpt_model,
api_key=user_api_key,
)
return Response( return Response(
complete_stream(question=question, retriever=retriever, complete_stream(
conversation_id=conversation_id), mimetype="text/event-stream") question=question,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
),
mimetype="text/event-stream",
)
@answer.route("/api/answer", methods=["POST"]) @answer.route("/api/answer", methods=["POST"])
@ -230,15 +279,15 @@ def api_answer():
else: else:
conversation_id = data["conversation_id"] conversation_id = data["conversation_id"]
print("-" * 5) print("-" * 5)
if 'prompt_id' in data: if "prompt_id" in data:
prompt_id = data["prompt_id"] prompt_id = data["prompt_id"]
else: else:
prompt_id = 'default' prompt_id = "default"
if 'chunks' in data: if "chunks" in data:
chunks = int(data["chunks"]) chunks = int(data["chunks"])
else: else:
chunks = 2 chunks = 2
prompt = get_prompt(prompt_id) prompt = get_prompt(prompt_id)
# use try and except to check for exception # use try and except to check for exception
@ -247,17 +296,29 @@ def api_answer():
if "api_key" in data: if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"]) data_key = get_data_from_api_key(data["api_key"])
source = {"active_docs": data_key["source"]} source = {"active_docs": data_key["source"]}
user_api_key = data["api_key"]
else: else:
source = {data} source = {data}
user_api_key = None
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": if (
source["active_docs"].split("/")[0] == "default"
or source["active_docs"].split("/")[0] == "local"
):
retriever_name = "classic" retriever_name = "classic"
else: else:
retriever_name = source['active_docs'] retriever_name = source["active_docs"]
retriever = RetrieverCreator.create_retriever(retriever_name, question=question, retriever = RetrieverCreator.create_retriever(
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model retriever_name,
) question=question,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
gpt_model=gpt_model,
api_key=user_api_key,
)
source_log_docs = [] source_log_docs = []
response_full = "" response_full = ""
for line in retriever.gen(): for line in retriever.gen():
@ -265,12 +326,13 @@ def api_answer():
source_log_docs.append(line["source"]) source_log_docs.append(line["source"])
elif "answer" in line: elif "answer" in line:
response_full += line["answer"] response_full += line["answer"]
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
result = {"answer": response_full, "sources": source_log_docs} result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm) result["conversation_id"] = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
return result return result
except Exception as e: except Exception as e:
@ -289,23 +351,35 @@ def api_search():
if "api_key" in data: if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"]) data_key = get_data_from_api_key(data["api_key"])
source = {"active_docs": data_key["source"]} source = {"active_docs": data_key["source"]}
user_api_key = data["api_key"]
elif "active_docs" in data: elif "active_docs" in data:
source = {"active_docs": data["active_docs"]} source = {"active_docs": data["active_docs"]}
user_api_key = None
else: else:
source = {} source = {}
if 'chunks' in data: user_api_key = None
if "chunks" in data:
chunks = int(data["chunks"]) chunks = int(data["chunks"])
else: else:
chunks = 2 chunks = 2
if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": if (
source["active_docs"].split("/")[0] == "default"
or source["active_docs"].split("/")[0] == "local"
):
retriever_name = "classic" retriever_name = "classic"
else: else:
retriever_name = source['active_docs'] retriever_name = source["active_docs"]
retriever = RetrieverCreator.create_retriever(retriever_name, question=question, retriever = RetrieverCreator.create_retriever(
source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model retriever_name,
) question=question,
source=source,
chat_history=[],
prompt="default",
chunks=chunks,
gpt_model=gpt_model,
api_key=user_api_key,
)
docs = retriever.search() docs = retriever.search()
return docs return docs

@ -15,7 +15,9 @@ class AnthropicLLM(BaseLLM):
self.HUMAN_PROMPT = HUMAN_PROMPT self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT self.AI_PROMPT = AI_PROMPT
def _raw_gen(self, model, messages, max_tokens=300, stream=False, **kwargs): def _raw_gen(
self, baseself, model, messages, max_tokens=300, stream=False, **kwargs
):
context = messages[0]["content"] context = messages[0]["content"]
user_question = messages[-1]["content"] user_question = messages[-1]["content"]
prompt = f"### Context \n {context} \n ### Question \n {user_question}" prompt = f"### Context \n {context} \n ### Question \n {user_question}"
@ -30,7 +32,7 @@ class AnthropicLLM(BaseLLM):
) )
return completion.completion return completion.completion
def _raw_gen_stream(self, model, messages, max_tokens=300, **kwargs): def _raw_gen_stream(self, baseself, model, messages, max_tokens=300, **kwargs):
context = messages[0]["content"] context = messages[0]["content"]
user_question = messages[-1]["content"] user_question = messages[-1]["content"]
prompt = f"### Context \n {context} \n ### Question \n {user_question}" prompt = f"### Context \n {context} \n ### Question \n {user_question}"

@ -5,7 +5,7 @@ import requests
class DocsGPTAPILLM(BaseLLM): class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key, *args, **kwargs): def __init__(self, api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.api_key = api_key self.api_key = api_key
self.endpoint = "https://llm.docsgpt.co.uk" self.endpoint = "https://llm.docsgpt.co.uk"

@ -3,7 +3,9 @@ from application.llm.base import BaseLLM
class HuggingFaceLLM(BaseLLM): class HuggingFaceLLM(BaseLLM):
def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs): def __init__(
self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs
):
global hf global hf
from langchain.llms import HuggingFacePipeline from langchain.llms import HuggingFacePipeline
@ -45,7 +47,7 @@ class HuggingFaceLLM(BaseLLM):
) )
hf = HuggingFacePipeline(pipeline=pipe) hf = HuggingFacePipeline(pipeline=pipe)
def _raw_gen(self, model, messages, stream=False, **kwargs): def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"] context = messages[0]["content"]
user_question = messages[-1]["content"] user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@ -54,6 +56,6 @@ class HuggingFaceLLM(BaseLLM):
return result.content return result.content
def _raw_gen_stream(self, model, messages, stream=True, **kwargs): def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.") raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")

@ -4,7 +4,7 @@ from application.core.settings import settings
class LlamaCpp(BaseLLM): class LlamaCpp(BaseLLM):
def __init__(self, api_key, llm_name=settings.MODEL_PATH, *args, **kwargs): def __init__(self, api_key=None, llm_name=settings.MODEL_PATH, *args, **kwargs):
global llama global llama
try: try:
from llama_cpp import Llama from llama_cpp import Llama
@ -17,7 +17,7 @@ class LlamaCpp(BaseLLM):
self.api_key = api_key self.api_key = api_key
llama = Llama(model_path=llm_name, n_ctx=2048) llama = Llama(model_path=llm_name, n_ctx=2048)
def _raw_gen(self, model, messages, stream=False, **kwargs): def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"] context = messages[0]["content"]
user_question = messages[-1]["content"] user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@ -29,7 +29,7 @@ class LlamaCpp(BaseLLM):
return result["choices"][0]["text"].split("### Answer \n")[-1] return result["choices"][0]["text"].split("### Answer \n")[-1]
def _raw_gen_stream(self, model, messages, stream=True, **kwargs): def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
context = messages[0]["content"] context = messages[0]["content"]
user_question = messages[-1]["content"] user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

@ -4,7 +4,7 @@ from application.core.settings import settings
class OpenAILLM(BaseLLM): class OpenAILLM(BaseLLM):
def __init__(self, api_key, *args, **kwargs): def __init__(self, api_key=None, *args, **kwargs):
global openai global openai
from openai import OpenAI from openai import OpenAI
@ -22,6 +22,7 @@ class OpenAILLM(BaseLLM):
def _raw_gen( def _raw_gen(
self, self,
baseself,
model, model,
messages, messages,
stream=False, stream=False,
@ -36,6 +37,7 @@ class OpenAILLM(BaseLLM):
def _raw_gen_stream( def _raw_gen_stream(
self, self,
baseself,
model, model,
messages, messages,
stream=True, stream=True,

@ -4,7 +4,7 @@ from application.core.settings import settings
class PremAILLM(BaseLLM): class PremAILLM(BaseLLM):
def __init__(self, api_key, *args, **kwargs): def __init__(self, api_key=None, *args, **kwargs):
from premai import Prem from premai import Prem
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -12,7 +12,7 @@ class PremAILLM(BaseLLM):
self.api_key = api_key self.api_key = api_key
self.project_id = settings.PREMAI_PROJECT_ID self.project_id = settings.PREMAI_PROJECT_ID
def _raw_gen(self, model, messages, stream=False, **kwargs): def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=model, model=model,
project_id=self.project_id, project_id=self.project_id,
@ -23,7 +23,7 @@ class PremAILLM(BaseLLM):
return response.choices[0].message["content"] return response.choices[0].message["content"]
def _raw_gen_stream(self, model, messages, stream=True, **kwargs): def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=model, model=model,
project_id=self.project_id, project_id=self.project_id,

@ -60,7 +60,7 @@ class LineIterator:
class SagemakerAPILLM(BaseLLM): class SagemakerAPILLM(BaseLLM):
def __init__(self, api_key, *args, **kwargs): def __init__(self, api_key=None, *args, **kwargs):
import boto3 import boto3
runtime = boto3.client( runtime = boto3.client(
@ -75,7 +75,7 @@ class SagemakerAPILLM(BaseLLM):
self.endpoint = settings.SAGEMAKER_ENDPOINT self.endpoint = settings.SAGEMAKER_ENDPOINT
self.runtime = runtime self.runtime = runtime
def _raw_gen(self, model, messages, stream=False, **kwargs): def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"] context = messages[0]["content"]
user_question = messages[-1]["content"] user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@ -104,7 +104,7 @@ class SagemakerAPILLM(BaseLLM):
print(result[0]["generated_text"], file=sys.stderr) print(result[0]["generated_text"], file=sys.stderr)
return result[0]["generated_text"][len(prompt) :] return result[0]["generated_text"][len(prompt) :]
def _raw_gen_stream(self, model, messages, stream=True, **kwargs): def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
context = messages[0]["content"] context = messages[0]["content"]
user_question = messages[-1]["content"] user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

@ -6,43 +6,54 @@ from application.utils import count_tokens
from langchain_community.tools import BraveSearch from langchain_community.tools import BraveSearch
class BraveRetSearch(BaseRetriever): class BraveRetSearch(BaseRetriever):
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): def __init__(
self,
question,
source,
chat_history,
prompt,
chunks=2,
gpt_model="docsgpt",
api_key=None,
):
self.question = question self.question = question
self.source = source self.source = source
self.chat_history = chat_history self.chat_history = chat_history
self.prompt = prompt self.prompt = prompt
self.chunks = chunks self.chunks = chunks
self.gpt_model = gpt_model self.gpt_model = gpt_model
self.api_key = api_key
def _get_data(self): def _get_data(self):
if self.chunks == 0: if self.chunks == 0:
docs = [] docs = []
else: else:
search = BraveSearch.from_api_key(api_key=settings.BRAVE_SEARCH_API_KEY, search = BraveSearch.from_api_key(
search_kwargs={"count": int(self.chunks)}) api_key=settings.BRAVE_SEARCH_API_KEY,
search_kwargs={"count": int(self.chunks)},
)
results = search.run(self.question) results = search.run(self.question)
results = json.loads(results) results = json.loads(results)
docs = [] docs = []
for i in results: for i in results:
try: try:
title = i['title'] title = i["title"]
link = i['link'] link = i["link"]
snippet = i['snippet'] snippet = i["snippet"]
docs.append({"text": snippet, "title": title, "link": link}) docs.append({"text": snippet, "title": title, "link": link})
except IndexError: except IndexError:
pass pass
if settings.LLM_NAME == "llama.cpp": if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]] docs = [docs[0]]
return docs return docs
def gen(self): def gen(self):
docs = self._get_data() docs = self._get_data()
# join all page_content together with a newline # join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs]) docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together) p_chat_combine = self.prompt.replace("{summaries}", docs_together)
@ -56,20 +67,27 @@ class BraveRetSearch(BaseRetriever):
self.chat_history.reverse() self.chat_history.reverse()
for i in self.chat_history: for i in self.chat_history:
if "prompt" in i and "response" in i: if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) tokens_batch = count_tokens(i["prompt"]) + count_tokens(
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: i["response"]
)
if (
tokens_current_history + tokens_batch
< settings.TOKENS_MAX_HISTORY
):
tokens_current_history += tokens_batch tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]}) messages_combine.append(
messages_combine.append({"role": "system", "content": i["response"]}) {"role": "user", "content": i["prompt"]}
)
messages_combine.append(
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question}) messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
completion = llm.gen_stream(model=self.gpt_model, completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
messages=messages_combine)
for line in completion: for line in completion:
yield {"answer": str(line)} yield {"answer": str(line)}
def search(self): def search(self):
return self._get_data() return self._get_data()

@ -7,21 +7,30 @@ from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens from application.utils import count_tokens
class ClassicRAG(BaseRetriever): class ClassicRAG(BaseRetriever):
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): def __init__(
self,
question,
source,
chat_history,
prompt,
chunks=2,
gpt_model="docsgpt",
api_key=None,
):
self.question = question self.question = question
self.vectorstore = self._get_vectorstore(source=source) self.vectorstore = self._get_vectorstore(source=source)
self.chat_history = chat_history self.chat_history = chat_history
self.prompt = prompt self.prompt = prompt
self.chunks = chunks self.chunks = chunks
self.gpt_model = gpt_model self.gpt_model = gpt_model
self.api_key = api_key
def _get_vectorstore(self, source): def _get_vectorstore(self, source):
if "active_docs" in source: if "active_docs" in source:
if source["active_docs"].split("/")[0] == "default": if source["active_docs"].split("/")[0] == "default":
vectorstore = "" vectorstore = ""
elif source["active_docs"].split("/")[0] == "local": elif source["active_docs"].split("/")[0] == "local":
vectorstore = "indexes/" + source["active_docs"] vectorstore = "indexes/" + source["active_docs"]
else: else:
@ -33,32 +42,33 @@ class ClassicRAG(BaseRetriever):
vectorstore = os.path.join("application", vectorstore) vectorstore = os.path.join("application", vectorstore)
return vectorstore return vectorstore
def _get_data(self): def _get_data(self):
if self.chunks == 0: if self.chunks == 0:
docs = [] docs = []
else: else:
docsearch = VectorCreator.create_vectorstore( docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
self.vectorstore,
settings.EMBEDDINGS_KEY
) )
docs_temp = docsearch.search(self.question, k=self.chunks) docs_temp = docsearch.search(self.question, k=self.chunks)
docs = [ docs = [
{ {
"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, "title": (
"text": i.page_content i.metadata["title"].split("/")[-1]
} if i.metadata
else i.page_content
),
"text": i.page_content,
}
for i in docs_temp for i in docs_temp
] ]
if settings.LLM_NAME == "llama.cpp": if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]] docs = [docs[0]]
return docs return docs
def gen(self): def gen(self):
docs = self._get_data() docs = self._get_data()
# join all page_content together with a newline # join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs]) docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together) p_chat_combine = self.prompt.replace("{summaries}", docs_together)
@ -72,20 +82,27 @@ class ClassicRAG(BaseRetriever):
self.chat_history.reverse() self.chat_history.reverse()
for i in self.chat_history: for i in self.chat_history:
if "prompt" in i and "response" in i: if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) tokens_batch = count_tokens(i["prompt"]) + count_tokens(
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: i["response"]
)
if (
tokens_current_history + tokens_batch
< settings.TOKENS_MAX_HISTORY
):
tokens_current_history += tokens_batch tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]}) messages_combine.append(
messages_combine.append({"role": "system", "content": i["response"]}) {"role": "user", "content": i["prompt"]}
)
messages_combine.append(
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question}) messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
completion = llm.gen_stream(model=self.gpt_model, completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
messages=messages_combine)
for line in completion: for line in completion:
yield {"answer": str(line)} yield {"answer": str(line)}
def search(self): def search(self):
return self._get_data() return self._get_data()

@ -6,16 +6,25 @@ from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
class DuckDuckSearch(BaseRetriever): class DuckDuckSearch(BaseRetriever):
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): def __init__(
self,
question,
source,
chat_history,
prompt,
chunks=2,
gpt_model="docsgpt",
api_key=None,
):
self.question = question self.question = question
self.source = source self.source = source
self.chat_history = chat_history self.chat_history = chat_history
self.prompt = prompt self.prompt = prompt
self.chunks = chunks self.chunks = chunks
self.gpt_model = gpt_model self.gpt_model = gpt_model
self.api_key = api_key
def _parse_lang_string(self, input_string): def _parse_lang_string(self, input_string):
result = [] result = []
@ -30,12 +39,12 @@ class DuckDuckSearch(BaseRetriever):
current_item = "" current_item = ""
elif inside_brackets: elif inside_brackets:
current_item += char current_item += char
if inside_brackets: if inside_brackets:
result.append(current_item) result.append(current_item)
return result return result
def _get_data(self): def _get_data(self):
if self.chunks == 0: if self.chunks == 0:
docs = [] docs = []
@ -44,7 +53,7 @@ class DuckDuckSearch(BaseRetriever):
search = DuckDuckGoSearchResults(api_wrapper=wrapper) search = DuckDuckGoSearchResults(api_wrapper=wrapper)
results = search.run(self.question) results = search.run(self.question)
results = self._parse_lang_string(results) results = self._parse_lang_string(results)
docs = [] docs = []
for i in results: for i in results:
try: try:
@ -56,12 +65,12 @@ class DuckDuckSearch(BaseRetriever):
pass pass
if settings.LLM_NAME == "llama.cpp": if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]] docs = [docs[0]]
return docs return docs
def gen(self): def gen(self):
docs = self._get_data() docs = self._get_data()
# join all page_content together with a newline # join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs]) docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together) p_chat_combine = self.prompt.replace("{summaries}", docs_together)
@ -75,20 +84,27 @@ class DuckDuckSearch(BaseRetriever):
self.chat_history.reverse() self.chat_history.reverse()
for i in self.chat_history: for i in self.chat_history:
if "prompt" in i and "response" in i: if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) tokens_batch = count_tokens(i["prompt"]) + count_tokens(
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: i["response"]
)
if (
tokens_current_history + tokens_batch
< settings.TOKENS_MAX_HISTORY
):
tokens_current_history += tokens_batch tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]}) messages_combine.append(
messages_combine.append({"role": "system", "content": i["response"]}) {"role": "user", "content": i["prompt"]}
)
messages_combine.append(
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question}) messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
completion = llm.gen_stream(model=self.gpt_model, completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
messages=messages_combine)
for line in completion: for line in completion:
yield {"answer": str(line)} yield {"answer": str(line)}
def search(self): def search(self):
return self._get_data() return self._get_data()

Loading…
Cancel
Save