mirror of
https://github.com/arc53/DocsGPT
synced 2024-11-17 21:26:26 +00:00
Update application files and fix LLM models, create new retriever class
This commit is contained in:
parent
e07df29ab9
commit
391f686173
@ -8,13 +8,14 @@ import traceback
|
||||
|
||||
from pymongo import MongoClient
|
||||
from bson.objectid import ObjectId
|
||||
from transformers import GPT2TokenizerFast
|
||||
from application.utils import count_tokens
|
||||
|
||||
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.error import bad_request
|
||||
|
||||
|
||||
@ -62,9 +63,6 @@ async def async_generate(chain, question, chat_history):
|
||||
return result
|
||||
|
||||
|
||||
def count_tokens(string):
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
|
||||
return len(tokenizer(string)['input_ids'])
|
||||
|
||||
|
||||
def run_async_chain(chain, question, chat_history):
|
||||
@ -104,61 +102,11 @@ def get_vectorstore(data):
|
||||
def is_azure_configured():
|
||||
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||
|
||||
|
||||
def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id, chunks=2):
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||
if prompt_id == 'default':
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == 'creative':
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == 'strict':
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
|
||||
if chunks == 0:
|
||||
docs = []
|
||||
else:
|
||||
docs = docsearch.search(question, k=chunks)
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
docs = [docs[0]]
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
||||
p_chat_combine = prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
source_log_docs = []
|
||||
for doc in docs:
|
||||
if doc.metadata:
|
||||
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
|
||||
else:
|
||||
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
|
||||
|
||||
if len(chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
chat_history.reverse()
|
||||
for i in chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
||||
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
||||
tokens_current_history += tokens_batch
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append({"role": "system", "content": i["response"]})
|
||||
messages_combine.append({"role": "user", "content": question})
|
||||
|
||||
response_full = ""
|
||||
completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
messages=messages_combine)
|
||||
for line in completion:
|
||||
data = json.dumps({"answer": str(line)})
|
||||
response_full += str(line)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# save conversation to database
|
||||
def save_conversation(conversation_id, question, response, source_log_docs, llm):
|
||||
if conversation_id is not None:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}},
|
||||
{"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}},
|
||||
)
|
||||
|
||||
else:
|
||||
@ -168,20 +116,50 @@ def complete_stream(question, docsearch, chat_history, prompt_id, conversation_i
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the system \n\nUser: " + question + "\n\n" +
|
||||
"AI: " +
|
||||
response_full},
|
||||
response},
|
||||
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"system"}]
|
||||
|
||||
completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
completion = llm.gen(model=gpt_model,
|
||||
messages=messages_summary, max_tokens=30)
|
||||
conversation_id = conversations_collection.insert_one(
|
||||
{"user": "local",
|
||||
"date": datetime.datetime.utcnow(),
|
||||
"name": completion,
|
||||
"queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]}
|
||||
"queries": [{"prompt": question, "response": response, "sources": source_log_docs}]}
|
||||
).inserted_id
|
||||
|
||||
def get_prompt(prompt_id):
|
||||
if prompt_id == 'default':
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == 'creative':
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == 'strict':
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
return prompt
|
||||
|
||||
|
||||
def complete_stream(question, retriever, conversation_id):
|
||||
|
||||
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
answer = retriever.gen()
|
||||
for line in answer:
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
elif "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
|
||||
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
|
||||
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
yield f"data: {data}\n\n"
|
||||
@ -213,25 +191,26 @@ def stream():
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
# check if active_docs is set
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
|
||||
source = {"active_docs": data_key["source"]}
|
||||
elif "active_docs" in data:
|
||||
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
else:
|
||||
vectorstore = ""
|
||||
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
|
||||
source = {}
|
||||
|
||||
retriever = RetrieverCreator.create_retriever("classic", question=question,
|
||||
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(question, docsearch,
|
||||
chat_history=history,
|
||||
prompt_id=prompt_id,
|
||||
conversation_id=conversation_id,
|
||||
chunks=chunks), mimetype="text/event-stream"
|
||||
)
|
||||
complete_stream(question=question, retriever=retriever,
|
||||
conversation_id=conversation_id), mimetype="text/event-stream")
|
||||
|
||||
|
||||
@answer.route("/api/answer", methods=["POST"])
|
||||
@ -255,110 +234,35 @@ def api_answer():
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
|
||||
if prompt_id == 'default':
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == 'creative':
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == 'strict':
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
# use try and except to check for exception
|
||||
try:
|
||||
# check if the vectorstore is set
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
|
||||
source = {"active_docs": data_key["source"]}
|
||||
else:
|
||||
vectorstore = get_vectorstore(data)
|
||||
# loading the index and the store and the prompt template
|
||||
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
|
||||
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
|
||||
source = {data}
|
||||
|
||||
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||
|
||||
|
||||
|
||||
if chunks == 0:
|
||||
docs = []
|
||||
else:
|
||||
docs = docsearch.search(question, k=chunks)
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
||||
p_chat_combine = prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
retriever = RetrieverCreator.create_retriever("classic", question=question,
|
||||
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
|
||||
)
|
||||
source_log_docs = []
|
||||
for doc in docs:
|
||||
if doc.metadata:
|
||||
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
|
||||
else:
|
||||
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
|
||||
# join all page_content together with a newline
|
||||
response_full = ""
|
||||
for line in retriever.gen():
|
||||
if "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
elif "answer" in line:
|
||||
response_full += line["answer"]
|
||||
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
|
||||
|
||||
if len(history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
history.reverse()
|
||||
for i in history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
||||
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
||||
tokens_current_history += tokens_batch
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append({"role": "system", "content": i["response"]})
|
||||
messages_combine.append({"role": "user", "content": question})
|
||||
|
||||
|
||||
completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
messages=messages_combine)
|
||||
|
||||
|
||||
result = {"answer": completion, "sources": source_log_docs}
|
||||
logger.debug(result)
|
||||
|
||||
# generate conversationId
|
||||
if conversation_id is not None:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{"$push": {"queries": {"prompt": question,
|
||||
"response": result["answer"], "sources": result['sources']}}},
|
||||
)
|
||||
|
||||
else:
|
||||
# create new conversation
|
||||
# generate summary
|
||||
messages_summary = [
|
||||
{"role": "assistant", "content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the system \n\n"
|
||||
"User: " + question + "\n\n" + "AI: " + result["answer"]},
|
||||
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the system"}
|
||||
]
|
||||
|
||||
completion = llm.gen(
|
||||
model=gpt_model,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
messages=messages_summary,
|
||||
max_tokens=30
|
||||
)
|
||||
conversation_id = conversations_collection.insert_one(
|
||||
{"user": "local",
|
||||
"date": datetime.datetime.utcnow(),
|
||||
"name": completion,
|
||||
"queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]}
|
||||
).inserted_id
|
||||
|
||||
result["conversation_id"] = str(conversation_id)
|
||||
|
||||
# mock result
|
||||
# result = {
|
||||
# "answer": "The answer is 42",
|
||||
# "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"]
|
||||
# }
|
||||
return result
|
||||
except Exception as e:
|
||||
# print whole traceback
|
||||
@ -375,20 +279,20 @@ def api_search():
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
vectorstore = data_key["source"]
|
||||
source = {"active_docs": data_key["source"]}
|
||||
elif "active_docs" in data:
|
||||
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
else:
|
||||
vectorstore = ""
|
||||
source = {}
|
||||
if 'chunks' in data:
|
||||
chunks = int(data["chunks"])
|
||||
else:
|
||||
chunks = 2
|
||||
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
|
||||
if chunks == 0:
|
||||
docs = []
|
||||
else:
|
||||
docs = docsearch.search(question, k=chunks)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever("classic", question=question,
|
||||
source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
|
||||
)
|
||||
docs = retriever.search()
|
||||
|
||||
source_log_docs = []
|
||||
for doc in docs:
|
||||
@ -396,6 +300,5 @@ def api_search():
|
||||
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
|
||||
else:
|
||||
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
|
||||
#yield f"data:{data}\n\n"
|
||||
return source_log_docs
|
||||
|
||||
|
@ -10,7 +10,7 @@ class AnthropicLLM(BaseLLM):
|
||||
self.HUMAN_PROMPT = HUMAN_PROMPT
|
||||
self.AI_PROMPT = AI_PROMPT
|
||||
|
||||
def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs):
|
||||
def gen(self, model, messages, max_tokens=300, stream=False, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
|
||||
@ -25,7 +25,7 @@ class AnthropicLLM(BaseLLM):
|
||||
)
|
||||
return completion.completion
|
||||
|
||||
def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs):
|
||||
def gen_stream(self, model, messages, max_tokens=300, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
|
||||
|
@ -8,7 +8,7 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
self.endpoint = "https://llm.docsgpt.co.uk"
|
||||
|
||||
|
||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
||||
def gen(self, model, messages, stream=False, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
@ -24,7 +24,7 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
|
||||
return response_clean
|
||||
|
||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
||||
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
|
@ -29,7 +29,7 @@ class HuggingFaceLLM(BaseLLM):
|
||||
)
|
||||
hf = HuggingFacePipeline(pipeline=pipe)
|
||||
|
||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
||||
def gen(self, model, messages, stream=False, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
@ -38,7 +38,7 @@ class HuggingFaceLLM(BaseLLM):
|
||||
|
||||
return result.content
|
||||
|
||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
||||
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||
|
||||
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")
|
||||
|
||||
|
@ -12,7 +12,7 @@ class LlamaCpp(BaseLLM):
|
||||
|
||||
llama = Llama(model_path=llm_name, n_ctx=2048)
|
||||
|
||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
||||
def gen(self, model, messages, stream=False, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
@ -24,7 +24,7 @@ class LlamaCpp(BaseLLM):
|
||||
|
||||
return result['choices'][0]['text'].split('### Answer \n')[-1]
|
||||
|
||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
||||
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
|
@ -18,7 +18,7 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
return openai
|
||||
|
||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
||||
def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
|
||||
response = self.client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
@ -26,7 +26,7 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
||||
def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
|
||||
response = self.client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
|
@ -12,7 +12,7 @@ class PremAILLM(BaseLLM):
|
||||
self.api_key = api_key
|
||||
self.project_id = settings.PREMAI_PROJECT_ID
|
||||
|
||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
||||
def gen(self, model, messages, stream=False, **kwargs):
|
||||
response = self.client.chat.completions.create(model=model,
|
||||
project_id=self.project_id,
|
||||
messages=messages,
|
||||
@ -21,7 +21,7 @@ class PremAILLM(BaseLLM):
|
||||
|
||||
return response.choices[0].message["content"]
|
||||
|
||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
||||
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||
response = self.client.chat.completions.create(model=model,
|
||||
project_id=self.project_id,
|
||||
messages=messages,
|
||||
|
@ -74,7 +74,7 @@ class SagemakerAPILLM(BaseLLM):
|
||||
self.runtime = runtime
|
||||
|
||||
|
||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
||||
def gen(self, model, messages, stream=False, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
@ -103,7 +103,7 @@ class SagemakerAPILLM(BaseLLM):
|
||||
print(result[0]['generated_text'], file=sys.stderr)
|
||||
return result[0]['generated_text'][len(prompt):]
|
||||
|
||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
||||
def gen_stream(self, model, messages, stream=True, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
|
0
application/retriever/__init__.py
Normal file
0
application/retriever/__init__.py
Normal file
10
application/retriever/base.py
Normal file
10
application/retriever/base.py
Normal file
@ -0,0 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gen(self, *args, **kwargs):
|
||||
pass
|
83
application/retriever/classic_rag.py
Normal file
83
application/retriever/classic_rag.py
Normal file
@ -0,0 +1,83 @@
|
||||
import os
|
||||
import json
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
from application.utils import count_tokens
|
||||
|
||||
|
||||
|
||||
class ClassicRAG(BaseRetriever):
|
||||
|
||||
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
|
||||
self.question = question
|
||||
self.vectorstore = self._get_vectorstore(source=source)
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
self.chunks = chunks
|
||||
self.gpt_model = gpt_model
|
||||
|
||||
def _get_vectorstore(self, source):
|
||||
if "active_docs" in source:
|
||||
if source["active_docs"].split("/")[0] == "default":
|
||||
vectorstore = ""
|
||||
elif source["active_docs"].split("/")[0] == "local":
|
||||
vectorstore = "indexes/" + source["active_docs"]
|
||||
else:
|
||||
vectorstore = "vectors/" + source["active_docs"]
|
||||
if source["active_docs"] == "default":
|
||||
vectorstore = ""
|
||||
else:
|
||||
vectorstore = ""
|
||||
vectorstore = os.path.join("application", vectorstore)
|
||||
return vectorstore
|
||||
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0:
|
||||
docs = []
|
||||
else:
|
||||
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY)
|
||||
docs = docsearch.search(self.question, k=self.chunks)
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
docs = [docs[0]]
|
||||
return docs
|
||||
|
||||
def gen(self):
|
||||
docs = self._get_data()
|
||||
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
for doc in docs:
|
||||
if doc.metadata:
|
||||
yield {"source": {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}}
|
||||
else:
|
||||
yield {"source": {"title": doc.page_content, "text": doc.page_content}}
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
self.chat_history.reverse()
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
||||
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
||||
tokens_current_history += tokens_batch
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append({"role": "system", "content": i["response"]})
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||
|
||||
completion = llm.gen_stream(model=self.gpt_model,
|
||||
messages=messages_combine)
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
def search(self):
|
||||
return self._get_data()
|
||||
|
15
application/retriever/retriever_creator.py
Normal file
15
application/retriever/retriever_creator.py
Normal file
@ -0,0 +1,15 @@
|
||||
from application.retriever.classic_rag import ClassicRAG
|
||||
|
||||
|
||||
|
||||
class RetrieverCreator:
|
||||
retievers = {
|
||||
'classic': ClassicRAG,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_retriever(cls, type, *args, **kwargs):
|
||||
retiever_class = cls.retievers.get(type.lower())
|
||||
if not retiever_class:
|
||||
raise ValueError(f"No retievers class found for type {type}")
|
||||
return retiever_class(*args, **kwargs)
|
6
application/utils.py
Normal file
6
application/utils.py
Normal file
@ -0,0 +1,6 @@
|
||||
from transformers import GPT2TokenizerFast
|
||||
|
||||
|
||||
def count_tokens(string):
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
|
||||
return len(tokenizer(string)['input_ids'])
|
Loading…
Reference in New Issue
Block a user