import os

from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens


class ClassicRAG(BaseRetriever):

    def __init__(
        self,
        question,
        source,
        chat_history,
        prompt,
        chunks=2,
        gpt_model="docsgpt",
        user_api_key=None,
    ):
        self.question = question
        self.vectorstore = self._get_vectorstore(source=source)
        self.chat_history = chat_history
        self.prompt = prompt
        self.chunks = chunks
        self.gpt_model = gpt_model
        self.user_api_key = user_api_key

    def _get_vectorstore(self, source):
        if "active_docs" in source:
            if source["active_docs"].split("/")[0] == "default":
                vectorstore = ""
            elif source["active_docs"].split("/")[0] == "local":
                vectorstore = "indexes/" + source["active_docs"]
            else:
                vectorstore = "vectors/" + source["active_docs"]
            if source["active_docs"] == "default":
                vectorstore = ""
        else:
            vectorstore = ""
        vectorstore = os.path.join("application", vectorstore)
        return vectorstore

    def _get_data(self):
        if self.chunks == 0:
            docs = []
        else:
            docsearch = VectorCreator.create_vectorstore(
                settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
            )
            docs_temp = docsearch.search(self.question, k=self.chunks)
            docs = [
                {
                    "title": (
                        i.metadata["title"].split("/")[-1]
                        if i.metadata
                        else i.page_content
                    ),
                    "text": i.page_content,
                }
                for i in docs_temp
            ]
            if settings.LLM_NAME == "llama.cpp":
                docs = [docs[0]]

        return docs

    def gen(self):
        docs = self._get_data()

        # join all page_content together with a newline
        docs_together = "\n".join([doc["text"] for doc in docs])
        p_chat_combine = self.prompt.replace("{summaries}", docs_together)
        messages_combine = [{"role": "system", "content": p_chat_combine}]
        for doc in docs:
            yield {"source": doc}

        if len(self.chat_history) > 1:
            tokens_current_history = 0
            # count tokens in history
            self.chat_history.reverse()
            for i in self.chat_history:
                if "prompt" in i and "response" in i:
                    tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                        i["response"]
                    )
                    if (
                        tokens_current_history + tokens_batch
                        < settings.TOKENS_MAX_HISTORY
                    ):
                        tokens_current_history += tokens_batch
                        messages_combine.append(
                            {"role": "user", "content": i["prompt"]}
                        )
                        messages_combine.append(
                            {"role": "system", "content": i["response"]}
                        )
        messages_combine.append({"role": "user", "content": self.question})

        llm = LLMCreator.create_llm(
            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
        )

        completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
        for line in completion:
            yield {"answer": str(line)}

    def search(self):
        return self._get_data()
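

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the application settings (VECTOR_STORE, EMBEDDINGS_KEY, LLM_NAME,
# API_KEY, TOKENS_MAX_HISTORY) are already configured, that a local index
# exists at the assumed path "local/example-index", and that the prompt
# template contains the "{summaries}" placeholder expected by gen().
if __name__ == "__main__":
    retriever = ClassicRAG(
        question="What does this project do?",
        source={"active_docs": "local/example-index"},  # hypothetical index path
        chat_history=[],
        prompt="Answer using the following context:\n{summaries}",
        chunks=2,
    )
    # gen() first yields one {"source": ...} dict per retrieved chunk, then
    # streams {"answer": ...} dicts with tokens from the configured LLM.
    for event in retriever.gen():
        print(event)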