diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index fa0ac4f..b95dd14 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -10,14 +10,12 @@ from pymongo import MongoClient from bson.objectid import ObjectId - from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator from application.error import bad_request - logger = logging.getLogger(__name__) mongo = MongoClient(settings.MONGO_URI) @@ -26,20 +24,22 @@ conversations_collection = db["conversations"] vectors_collection = db["vectors"] prompts_collection = db["prompts"] api_key_collection = db["api_keys"] -answer = Blueprint('answer', __name__) +answer = Blueprint("answer", __name__) gpt_model = "" # to have some kind of default behaviour if settings.LLM_NAME == "openai": - gpt_model = 'gpt-3.5-turbo' + gpt_model = "gpt-3.5-turbo" elif settings.LLM_NAME == "anthropic": - gpt_model = 'claude-2' + gpt_model = "claude-2" if settings.MODEL_NAME: # in case there is particular model name configured gpt_model = settings.MODEL_NAME # load the prompts -current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +current_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() @@ -50,7 +50,7 @@ with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r" chat_combine_creative = f.read() with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: - chat_combine_strict = f.read() + chat_combine_strict = f.read() api_key_set = settings.API_KEY is not None embeddings_key_set = settings.EMBEDDINGS_KEY is not None @@ -61,8 +61,6 @@ async def async_generate(chain, question, chat_history): return result - - def run_async_chain(chain, question, chat_history): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -74,17 +72,18 @@ def run_async_chain(chain, question, chat_history): result["answer"] = answer return result + def get_data_from_api_key(api_key): data = api_key_collection.find_one({"key": api_key}) if data is None: return bad_request(401, "Invalid API key") return data - + def get_vectorstore(data): if "active_docs" in data: if data["active_docs"].split("/")[0] == "default": - vectorstore = "" + vectorstore = "" elif data["active_docs"].split("/")[0] == "local": vectorstore = "indexes/" + data["active_docs"] else: @@ -98,52 +97,82 @@ def get_vectorstore(data): def is_azure_configured(): - return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME + return ( + settings.OPENAI_API_BASE + and settings.OPENAI_API_VERSION + and settings.AZURE_DEPLOYMENT_NAME + ) + def save_conversation(conversation_id, question, response, source_log_docs, llm): if conversation_id is not None and conversation_id != "None": conversations_collection.update_one( {"_id": ObjectId(conversation_id)}, - {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}}, + { + "$push": { + "queries": { + "prompt": question, + "response": response, + "sources": source_log_docs, + } + } + }, ) else: # create new conversation # generate summary - messages_summary = [{"role": "assistant", "content": "Summarise following conversation in no more than 3 " - "words, respond ONLY with the summary, use the same " - 
"language as the system \n\nUser: " + question + "\n\n" + - "AI: " + - response}, - {"role": "user", "content": "Summarise following conversation in no more than 3 words, " - "respond ONLY with the summary, use the same language as the " - "system"}] - - completion = llm.gen(model=gpt_model, - messages=messages_summary, max_tokens=30) + messages_summary = [ + { + "role": "assistant", + "content": "Summarise following conversation in no more than 3 " + "words, respond ONLY with the summary, use the same " + "language as the system \n\nUser: " + + question + + "\n\n" + + "AI: " + + response, + }, + { + "role": "user", + "content": "Summarise following conversation in no more than 3 words, " + "respond ONLY with the summary, use the same language as the " + "system", + }, + ] + + completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30) conversation_id = conversations_collection.insert_one( - {"user": "local", - "date": datetime.datetime.utcnow(), - "name": completion, - "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]} + { + "user": "local", + "date": datetime.datetime.utcnow(), + "name": completion, + "queries": [ + { + "prompt": question, + "response": response, + "sources": source_log_docs, + } + ], + } ).inserted_id return conversation_id + def get_prompt(prompt_id): - if prompt_id == 'default': + if prompt_id == "default": prompt = chat_combine_template - elif prompt_id == 'creative': + elif prompt_id == "creative": prompt = chat_combine_creative - elif prompt_id == 'strict': + elif prompt_id == "strict": prompt = chat_combine_strict else: prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] return prompt -def complete_stream(question, retriever, conversation_id): - - +def complete_stream(question, retriever, conversation_id, user_api_key): + response_full = "" source_log_docs = [] answer = retriever.gen() @@ -155,9 +184,10 @@ def complete_stream(question, retriever, conversation_id): elif "source" in line: source_log_docs.append(line["source"]) - - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key) + conversation_id = save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) # send data.type = "end" to indicate that the stream has ended as json data = json.dumps({"type": "id", "id": str(conversation_id)}) @@ -180,17 +210,17 @@ def stream(): conversation_id = None else: conversation_id = data["conversation_id"] - if 'prompt_id' in data: + if "prompt_id" in data: prompt_id = data["prompt_id"] else: - prompt_id = 'default' - if 'selectedDocs' in data and data['selectedDocs'] is None: + prompt_id = "default" + if "selectedDocs" in data and data["selectedDocs"] is None: chunks = 0 - elif 'chunks' in data: + elif "chunks" in data: chunks = int(data["chunks"]) else: chunks = 2 - + prompt = get_prompt(prompt_id) # check if active_docs is set @@ -198,23 +228,42 @@ def stream(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] elif "active_docs" in data: source = {"active_docs": data["active_docs"]} + user_api_key = None else: source = {} + user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + if ( + 
source["active_docs"].split("/")[0] == "default" + or source["active_docs"].split("/")[0] == "local" + ): retriever_name = "classic" else: - retriever_name = source['active_docs'] - - retriever = RetrieverCreator.create_retriever(retriever_name, question=question, - source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model - ) + retriever_name = source["active_docs"] + + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + gpt_model=gpt_model, + api_key=user_api_key, + ) return Response( - complete_stream(question=question, retriever=retriever, - conversation_id=conversation_id), mimetype="text/event-stream") + complete_stream( + question=question, + retriever=retriever, + conversation_id=conversation_id, + user_api_key=user_api_key, + ), + mimetype="text/event-stream", + ) @answer.route("/api/answer", methods=["POST"]) @@ -230,15 +279,15 @@ def api_answer(): else: conversation_id = data["conversation_id"] print("-" * 5) - if 'prompt_id' in data: + if "prompt_id" in data: prompt_id = data["prompt_id"] else: - prompt_id = 'default' - if 'chunks' in data: + prompt_id = "default" + if "chunks" in data: chunks = int(data["chunks"]) else: chunks = 2 - + prompt = get_prompt(prompt_id) # use try and except to check for exception @@ -247,17 +296,29 @@ def api_answer(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] else: source = {data} + user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + if ( + source["active_docs"].split("/")[0] == "default" + or source["active_docs"].split("/")[0] == "local" + ): retriever_name = "classic" else: - retriever_name = source['active_docs'] - - retriever = RetrieverCreator.create_retriever(retriever_name, question=question, - source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model - ) + retriever_name = source["active_docs"] + + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + gpt_model=gpt_model, + api_key=user_api_key, + ) source_log_docs = [] response_full = "" for line in retriever.gen(): @@ -265,12 +326,13 @@ def api_answer(): source_log_docs.append(line["source"]) elif "answer" in line: response_full += line["answer"] - - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key) result = {"answer": response_full, "sources": source_log_docs} - result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm) + result["conversation_id"] = save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) return result except Exception as e: @@ -289,23 +351,35 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] elif "active_docs" in data: source = {"active_docs": data["active_docs"]} + user_api_key = None else: source = {} - if 'chunks' in data: + user_api_key = None + if "chunks" in data: chunks = int(data["chunks"]) else: chunks = 2 - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + 
if ( + source["active_docs"].split("/")[0] == "default" + or source["active_docs"].split("/")[0] == "local" + ): retriever_name = "classic" else: - retriever_name = source['active_docs'] - - retriever = RetrieverCreator.create_retriever(retriever_name, question=question, - source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model - ) + retriever_name = source["active_docs"] + + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=[], + prompt="default", + chunks=chunks, + gpt_model=gpt_model, + api_key=user_api_key, + ) docs = retriever.search() return docs - diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index b3fde3d..31e4563 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -15,7 +15,9 @@ class AnthropicLLM(BaseLLM): self.HUMAN_PROMPT = HUMAN_PROMPT self.AI_PROMPT = AI_PROMPT - def _raw_gen(self, model, messages, max_tokens=300, stream=False, **kwargs): + def _raw_gen( + self, baseself, model, messages, max_tokens=300, stream=False, **kwargs + ): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}" @@ -30,7 +32,7 @@ class AnthropicLLM(BaseLLM): ) return completion.completion - def _raw_gen_stream(self, model, messages, max_tokens=300, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, max_tokens=300, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}" diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py index ffe1e31..6dc1d6c 100644 --- a/application/llm/docsgpt_provider.py +++ b/application/llm/docsgpt_provider.py @@ -5,7 +5,7 @@ import requests class DocsGPTAPILLM(BaseLLM): - def __init__(self, api_key, *args, **kwargs): + def __init__(self, api_key=None, *args, **kwargs): super().__init__(*args, **kwargs) self.api_key = api_key self.endpoint = "https://llm.docsgpt.co.uk" diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py index b1118ed..0baaceb 100644 --- a/application/llm/huggingface.py +++ b/application/llm/huggingface.py @@ -3,7 +3,9 @@ from application.llm.base import BaseLLM class HuggingFaceLLM(BaseLLM): - def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs): + def __init__( + self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs + ): global hf from langchain.llms import HuggingFacePipeline @@ -45,7 +47,7 @@ class HuggingFaceLLM(BaseLLM): ) hf = HuggingFacePipeline(pipeline=pipe) - def _raw_gen(self, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -54,6 +56,6 @@ class HuggingFaceLLM(BaseLLM): return result.content - def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.") diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index 896e66f..ebaefca 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -4,7 +4,7 @@ from application.core.settings import settings class 
LlamaCpp(BaseLLM): - def __init__(self, api_key, llm_name=settings.MODEL_PATH, *args, **kwargs): + def __init__(self, api_key=None, llm_name=settings.MODEL_PATH, *args, **kwargs): global llama try: from llama_cpp import Llama @@ -17,7 +17,7 @@ class LlamaCpp(BaseLLM): self.api_key = api_key llama = Llama(model_path=llm_name, n_ctx=2048) - def _raw_gen(self, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -29,7 +29,7 @@ class LlamaCpp(BaseLLM): return result["choices"][0]["text"].split("### Answer \n")[-1] - def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/llm/openai.py b/application/llm/openai.py index c741404..b0fd4c5 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -4,7 +4,7 @@ from application.core.settings import settings class OpenAILLM(BaseLLM): - def __init__(self, api_key, *args, **kwargs): + def __init__(self, api_key=None, *args, **kwargs): global openai from openai import OpenAI @@ -22,6 +22,7 @@ class OpenAILLM(BaseLLM): def _raw_gen( self, + baseself, model, messages, stream=False, @@ -36,6 +37,7 @@ class OpenAILLM(BaseLLM): def _raw_gen_stream( self, + baseself, model, messages, stream=True, diff --git a/application/llm/premai.py b/application/llm/premai.py index 203ff4d..cdb8063 100644 --- a/application/llm/premai.py +++ b/application/llm/premai.py @@ -4,7 +4,7 @@ from application.core.settings import settings class PremAILLM(BaseLLM): - def __init__(self, api_key, *args, **kwargs): + def __init__(self, api_key=None, *args, **kwargs): from premai import Prem super().__init__(*args, **kwargs) @@ -12,7 +12,7 @@ class PremAILLM(BaseLLM): self.api_key = api_key self.project_id = settings.PREMAI_PROJECT_ID - def _raw_gen(self, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): response = self.client.chat.completions.create( model=model, project_id=self.project_id, @@ -23,7 +23,7 @@ class PremAILLM(BaseLLM): return response.choices[0].message["content"] - def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): response = self.client.chat.completions.create( model=model, project_id=self.project_id, diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 807bfa2..579eec6 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -60,7 +60,7 @@ class LineIterator: class SagemakerAPILLM(BaseLLM): - def __init__(self, api_key, *args, **kwargs): + def __init__(self, api_key=None, *args, **kwargs): import boto3 runtime = boto3.client( @@ -75,7 +75,7 @@ class SagemakerAPILLM(BaseLLM): self.endpoint = settings.SAGEMAKER_ENDPOINT self.runtime = runtime - def _raw_gen(self, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -104,7 
+104,7 @@ class SagemakerAPILLM(BaseLLM): print(result[0]["generated_text"], file=sys.stderr) return result[0]["generated_text"][len(prompt) :] - def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py index 0cc7bd4..d0c81b6 100644 --- a/application/retriever/brave_search.py +++ b/application/retriever/brave_search.py @@ -6,43 +6,54 @@ from application.utils import count_tokens from langchain_community.tools import BraveSearch - class BraveRetSearch(BaseRetriever): - def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + def __init__( + self, + question, + source, + chat_history, + prompt, + chunks=2, + gpt_model="docsgpt", + api_key=None, + ): self.question = question self.source = source self.chat_history = chat_history self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model - + self.api_key = api_key + def _get_data(self): if self.chunks == 0: docs = [] else: - search = BraveSearch.from_api_key(api_key=settings.BRAVE_SEARCH_API_KEY, - search_kwargs={"count": int(self.chunks)}) + search = BraveSearch.from_api_key( + api_key=settings.BRAVE_SEARCH_API_KEY, + search_kwargs={"count": int(self.chunks)}, + ) results = search.run(self.question) results = json.loads(results) - + docs = [] for i in results: try: - title = i['title'] - link = i['link'] - snippet = i['snippet'] + title = i["title"] + link = i["link"] + snippet = i["snippet"] docs.append({"text": snippet, "title": title, "link": link}) except IndexError: pass if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] - + return docs - + def gen(self): docs = self._get_data() - + # join all page_content together with a newline docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) @@ -56,20 +67,27 @@ class BraveRetSearch(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_batch = count_tokens(i["prompt"]) + count_tokens( + i["response"] + ) + if ( + tokens_current_history + tokens_batch + < settings.TOKENS_MAX_HISTORY + ): tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key) - completion = llm.gen_stream(model=self.gpt_model, - messages=messages_combine) + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) for line in completion: yield {"answer": str(line)} - + def search(self): return self._get_data() - diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index b5f1eb9..138971f 100644 --- a/application/retriever/classic_rag.py 
+++ b/application/retriever/classic_rag.py @@ -7,21 +7,30 @@ from application.llm.llm_creator import LLMCreator from application.utils import count_tokens - class ClassicRAG(BaseRetriever): - def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + def __init__( + self, + question, + source, + chat_history, + prompt, + chunks=2, + gpt_model="docsgpt", + api_key=None, + ): self.question = question self.vectorstore = self._get_vectorstore(source=source) self.chat_history = chat_history self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model + self.api_key = api_key def _get_vectorstore(self, source): if "active_docs" in source: if source["active_docs"].split("/")[0] == "default": - vectorstore = "" + vectorstore = "" elif source["active_docs"].split("/")[0] == "local": vectorstore = "indexes/" + source["active_docs"] else: @@ -33,32 +42,33 @@ class ClassicRAG(BaseRetriever): vectorstore = os.path.join("application", vectorstore) return vectorstore - def _get_data(self): if self.chunks == 0: docs = [] else: docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, - self.vectorstore, - settings.EMBEDDINGS_KEY + settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY ) docs_temp = docsearch.search(self.question, k=self.chunks) docs = [ { - "title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, - "text": i.page_content - } + "title": ( + i.metadata["title"].split("/")[-1] + if i.metadata + else i.page_content + ), + "text": i.page_content, + } for i in docs_temp ] if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] - + return docs - + def gen(self): docs = self._get_data() - + # join all page_content together with a newline docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) @@ -72,20 +82,27 @@ class ClassicRAG(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_batch = count_tokens(i["prompt"]) + count_tokens( + i["response"] + ) + if ( + tokens_current_history + tokens_batch + < settings.TOKENS_MAX_HISTORY + ): tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key) - completion = llm.gen_stream(model=self.gpt_model, - messages=messages_combine) + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) for line in completion: yield {"answer": str(line)} - + def search(self): return self._get_data() - diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py index d662bb0..f61c082 100644 --- a/application/retriever/duckduck_search.py +++ b/application/retriever/duckduck_search.py @@ -6,16 +6,25 @@ from langchain_community.tools import DuckDuckGoSearchResults from langchain_community.utilities import DuckDuckGoSearchAPIWrapper - class DuckDuckSearch(BaseRetriever): - def 
__init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + def __init__( + self, + question, + source, + chat_history, + prompt, + chunks=2, + gpt_model="docsgpt", + api_key=None, + ): self.question = question self.source = source self.chat_history = chat_history self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model + self.api_key = api_key def _parse_lang_string(self, input_string): result = [] @@ -30,12 +39,12 @@ class DuckDuckSearch(BaseRetriever): current_item = "" elif inside_brackets: current_item += char - + if inside_brackets: result.append(current_item) - + return result - + def _get_data(self): if self.chunks == 0: docs = [] @@ -44,7 +53,7 @@ class DuckDuckSearch(BaseRetriever): search = DuckDuckGoSearchResults(api_wrapper=wrapper) results = search.run(self.question) results = self._parse_lang_string(results) - + docs = [] for i in results: try: @@ -56,12 +65,12 @@ class DuckDuckSearch(BaseRetriever): pass if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] - + return docs - + def gen(self): docs = self._get_data() - + # join all page_content together with a newline docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) @@ -75,20 +84,27 @@ class DuckDuckSearch(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_batch = count_tokens(i["prompt"]) + count_tokens( + i["response"] + ) + if ( + tokens_current_history + tokens_batch + < settings.TOKENS_MAX_HISTORY + ): tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key) - completion = llm.gen_stream(model=self.gpt_model, - messages=messages_combine) + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) for line in completion: yield {"answer": str(line)} - + def search(self): return self._get_data() -
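
Below is a minimal usage sketch (not part of the diff) of how the per-request api_key introduced here is expected to flow: a handler takes the key from the request payload, passes it to RetrieverCreator.create_retriever, and the retriever forwards it to LLMCreator.create_llm in place of settings.API_KEY. The payload values, the "classic" retriever choice, and the model name below are illustrative assumptions, not taken from the diff.

from application.retriever.retriever_creator import RetrieverCreator

# Hypothetical request payload; in the real handlers this comes from request.get_json().
data = {
    "question": "What does this repo do?",
    "active_docs": "local/my-docs",
    "api_key": "user-issued-key",  # omit this field to fall back to api_key=None
}

user_api_key = data.get("api_key")  # None when the request carries no key

retriever = RetrieverCreator.create_retriever(
    "classic",                        # "default"/"local" sources map to the classic retriever
    question=data["question"],
    source={"active_docs": data["active_docs"]},
    chat_history=[],
    prompt="default",
    chunks=2,
    gpt_model="gpt-3.5-turbo",
    api_key=user_api_key,             # new keyword argument threaded through by this change
)

# gen() yields {"answer": ...} and {"source": ...} items; internally the retriever now calls
# LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key) rather than using settings.API_KEY.
for line in retriever.gen():
    if "answer" in line:
        print(line["answer"], end="")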