Update application files, fix LLM models, and create a new retriever class

Alex 2024-04-09 14:02:33 +01:00
parent e07df29ab9
commit 391f686173
13 changed files with 202 additions and 185 deletions

View File

@@ -8,13 +8,14 @@ import traceback
from pymongo import MongoClient
from bson.objectid import ObjectId
from transformers import GPT2TokenizerFast
from application.utils import count_tokens
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.error import bad_request
@@ -62,9 +63,6 @@ async def async_generate(chain, question, chat_history):
return result
def count_tokens(string):
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
return len(tokenizer(string)['input_ids'])
def run_async_chain(chain, question, chat_history):
@@ -104,61 +102,11 @@ def get_vectorstore(data):
def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id, chunks=2):
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
if prompt_id == 'default':
prompt = chat_combine_template
elif prompt_id == 'creative':
prompt = chat_combine_creative
elif prompt_id == 'strict':
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
# join all page_content together with a newline
docs_together = "\n".join([doc.page_content for doc in docs])
p_chat_combine = prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
source_log_docs = []
for doc in docs:
if doc.metadata:
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
if len(chat_history) > 1:
tokens_current_history = 0
# count tokens in history
chat_history.reverse()
for i in chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": question})
response_full = ""
completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_combine)
for line in completion:
data = json.dumps({"answer": str(line)})
response_full += str(line)
yield f"data: {data}\n\n"
# save conversation to database
def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}},
{"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}},
)
else:
@@ -168,20 +116,50 @@ def complete_stream(question, docsearch, chat_history, prompt_id, conversation_i
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: " + question + "\n\n" +
"AI: " +
response_full},
response},
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system"}]
completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
completion = llm.gen(model=gpt_model,
messages=messages_summary, max_tokens=30)
conversation_id = conversations_collection.insert_one(
{"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]}
"queries": [{"prompt": question, "response": response, "sources": source_log_docs}]}
).inserted_id
def get_prompt(prompt_id):
if prompt_id == 'default':
prompt = chat_combine_template
elif prompt_id == 'creative':
prompt = chat_combine_creative
elif prompt_id == 'strict':
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt
def complete_stream(question, retriever, conversation_id):
response_full = ""
source_log_docs = []
answer = retriever.gen()
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps(line)
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
# send a final {"type": "id"} event with the conversation id to signal that the stream has ended
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
@@ -213,25 +191,26 @@ def stream():
chunks = int(data["chunks"])
else:
chunks = 2
prompt = get_prompt(prompt_id)
# check if active_docs is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
source = {"active_docs": data_key["source"]}
elif "active_docs" in data:
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
source = {"active_docs": data["active_docs"]}
else:
vectorstore = ""
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
source = {}
retriever = RetrieverCreator.create_retriever("classic", question=question,
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
)
return Response(
complete_stream(question, docsearch,
chat_history=history,
prompt_id=prompt_id,
conversation_id=conversation_id,
chunks=chunks), mimetype="text/event-stream"
)
complete_stream(question=question, retriever=retriever,
conversation_id=conversation_id), mimetype="text/event-stream")
@answer.route("/api/answer", methods=["POST"])
@@ -255,110 +234,35 @@ def api_answer():
chunks = int(data["chunks"])
else:
chunks = 2
if prompt_id == 'default':
prompt = chat_combine_template
elif prompt_id == 'creative':
prompt = chat_combine_creative
elif prompt_id == 'strict':
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
prompt = get_prompt(prompt_id)
# use try and except to check for exception
try:
# check if the vectorstore is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
source = {"active_docs": data_key["source"]}
else:
vectorstore = get_vectorstore(data)
# loading the index and the store and the prompt template
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
source = data
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
# join all page_content together with a newline
docs_together = "\n".join([doc.page_content for doc in docs])
p_chat_combine = prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
retriever = RetrieverCreator.create_retriever("classic", question=question,
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
)
source_log_docs = []
for doc in docs:
if doc.metadata:
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
# stream the retriever output, collecting sources and accumulating the full answer
response_full = ""
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
if len(history) > 1:
tokens_current_history = 0
# count tokens in history
history.reverse()
for i in history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": question})
completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_combine)
result = {"answer": completion, "sources": source_log_docs}
logger.debug(result)
# generate conversationId
if conversation_id is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{"$push": {"queries": {"prompt": question,
"response": result["answer"], "sources": result['sources']}}},
)
else:
# create new conversation
# generate summary
messages_summary = [
{"role": "assistant", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the system \n\n"
"User: " + question + "\n\n" + "AI: " + result["answer"]},
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the system"}
]
completion = llm.gen(
model=gpt_model,
engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_summary,
max_tokens=30
)
conversation_id = conversations_collection.insert_one(
{"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]}
).inserted_id
result["conversation_id"] = str(conversation_id)
# mock result
# result = {
# "answer": "The answer is 42",
# "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"]
# }
return result
except Exception as e:
# print whole traceback
@@ -375,20 +279,20 @@ def api_search():
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = data_key["source"]
source = {"active_docs": data_key["source"]}
elif "active_docs" in data:
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
source = {"active_docs": data["active_docs"]}
else:
vectorstore = ""
source = {}
if 'chunks' in data:
chunks = int(data["chunks"])
else:
chunks = 2
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
retriever = RetrieverCreator.create_retriever("classic", question=question,
source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
)
docs = retriever.search()
source_log_docs = []
for doc in docs:
@@ -396,6 +300,5 @@ def api_search():
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
#yield f"data:{data}\n\n"
return source_log_docs

View File

@@ -10,7 +10,7 @@ class AnthropicLLM(BaseLLM):
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs):
def gen(self, model, messages, max_tokens=300, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
@@ -25,7 +25,7 @@ class AnthropicLLM(BaseLLM):
)
return completion.completion
def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs):
def gen_stream(self, model, messages, max_tokens=300, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Context \n {context} \n ### Question \n {user_question}"

View File

@@ -8,7 +8,7 @@ class DocsGPTAPILLM(BaseLLM):
self.endpoint = "https://llm.docsgpt.co.uk"
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -24,7 +24,7 @@ class DocsGPTAPILLM(BaseLLM):
return response_clean
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

View File

@@ -29,7 +29,7 @@ class HuggingFaceLLM(BaseLLM):
)
hf = HuggingFacePipeline(pipeline=pipe)
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -38,7 +38,7 @@ class HuggingFaceLLM(BaseLLM):
return result.content
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")

View File

@@ -12,7 +12,7 @@ class LlamaCpp(BaseLLM):
llama = Llama(model_path=llm_name, n_ctx=2048)
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -24,7 +24,7 @@ class LlamaCpp(BaseLLM):
return result['choices'][0]['text'].split('### Answer \n')[-1]
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

View File

@@ -18,7 +18,7 @@ class OpenAILLM(BaseLLM):
return openai
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
response = self.client.chat.completions.create(model=model,
messages=messages,
stream=stream,
@@ -26,7 +26,7 @@ class OpenAILLM(BaseLLM):
return response.choices[0].message.content
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs):
response = self.client.chat.completions.create(model=model,
messages=messages,
stream=stream,

View File

@@ -12,7 +12,7 @@ class PremAILLM(BaseLLM):
self.api_key = api_key
self.project_id = settings.PREMAI_PROJECT_ID
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
response = self.client.chat.completions.create(model=model,
project_id=self.project_id,
messages=messages,
@@ -21,7 +21,7 @@ class PremAILLM(BaseLLM):
return response.choices[0].message["content"]
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
response = self.client.chat.completions.create(model=model,
project_id=self.project_id,
messages=messages,

View File

@@ -74,7 +74,7 @@ class SagemakerAPILLM(BaseLLM):
self.runtime = runtime
def gen(self, model, engine, messages, stream=False, **kwargs):
def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -103,7 +103,7 @@ class SagemakerAPILLM(BaseLLM):
print(result[0]['generated_text'], file=sys.stderr)
return result[0]['generated_text'][len(prompt):]
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

View File

View File

@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod


class BaseRetriever(ABC):
    def __init__(self):
        pass

    @abstractmethod
    def gen(self, *args, **kwargs):
        pass

View File

@@ -0,0 +1,83 @@
import os
import json
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens


class ClassicRAG(BaseRetriever):

    def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
        self.question = question
        self.vectorstore = self._get_vectorstore(source=source)
        self.chat_history = chat_history
        self.prompt = prompt
        self.chunks = chunks
        self.gpt_model = gpt_model

    def _get_vectorstore(self, source):
        if "active_docs" in source:
            if source["active_docs"].split("/")[0] == "default":
                vectorstore = ""
            elif source["active_docs"].split("/")[0] == "local":
                vectorstore = "indexes/" + source["active_docs"]
            else:
                vectorstore = "vectors/" + source["active_docs"]
            if source["active_docs"] == "default":
                vectorstore = ""
        else:
            vectorstore = ""
        vectorstore = os.path.join("application", vectorstore)
        return vectorstore

    def _get_data(self):
        if self.chunks == 0:
            docs = []
        else:
            docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY)
            docs = docsearch.search(self.question, k=self.chunks)
        if settings.LLM_NAME == "llama.cpp":
            docs = [docs[0]]
        return docs

    def gen(self):
        docs = self._get_data()

        # join all page_content together with a newline
        docs_together = "\n".join([doc.page_content for doc in docs])
        p_chat_combine = self.prompt.replace("{summaries}", docs_together)
        messages_combine = [{"role": "system", "content": p_chat_combine}]
        for doc in docs:
            if doc.metadata:
                yield {"source": {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}}
            else:
                yield {"source": {"title": doc.page_content, "text": doc.page_content}}

        if len(self.chat_history) > 1:
            tokens_current_history = 0
            # count tokens in history
            self.chat_history.reverse()
            for i in self.chat_history:
                if "prompt" in i and "response" in i:
                    tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
                    if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                        tokens_current_history += tokens_batch
                        messages_combine.append({"role": "user", "content": i["prompt"]})
                        messages_combine.append({"role": "system", "content": i["response"]})
        messages_combine.append({"role": "user", "content": self.question})

        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)

        completion = llm.gen_stream(model=self.gpt_model,
                                    messages=messages_combine)
        for line in completion:
            yield {"answer": str(line)}

    def search(self):
        return self._get_data()
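To make the moving parts concrete, a minimal sketch of driving ClassicRAG directly, outside the Flask routes. All argument values are placeholders, and it assumes the vector store, embeddings key, and LLM settings are already configured in application.core.settings:

# Illustrative only; these values are not part of this commit.
retriever = ClassicRAG(
    question="How do I install the app?",
    source={"active_docs": "local/my-docs/"},
    chat_history=[],
    prompt="Answer using this context:\n{summaries}",
    chunks=2,
)

answer = ""
for event in retriever.gen():   # yields {"source": ...} items first, then {"answer": ...} chunks
    if "answer" in event:
        answer += event["answer"]

docs = retriever.search()       # retrieval only, no LLM call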

View File

@@ -0,0 +1,15 @@
from application.retriever.classic_rag import ClassicRAG


class RetrieverCreator:
    retrievers = {
        'classic': ClassicRAG,
    }

    @classmethod
    def create_retriever(cls, type, *args, **kwargs):
        retriever_class = cls.retrievers.get(type.lower())
        if not retriever_class:
            raise ValueError(f"No retriever class found for type {type}")
        return retriever_class(*args, **kwargs)
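The routes above obtain the same object through this factory; a short sketch of that call, with the same placeholder values as in the ClassicRAG example, and of how another strategy could be plugged in later (the "hybrid" name is purely hypothetical):

# Hypothetical, self-contained use of the factory.
from application.retriever.retriever_creator import RetrieverCreator

retriever = RetrieverCreator.create_retriever(
    "classic",
    question="How do I install the app?",
    source={"active_docs": "default"},
    chat_history=[],
    prompt="Answer using this context:\n{summaries}",
    chunks=2,
    gpt_model="docsgpt",
)

# Additional strategies would only need a BaseRetriever subclass and a mapping entry,
# e.g. (hypothetical): RetrieverCreator.retrievers["hybrid"] = HybridRAG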

application/utils.py Normal file
View File

@@ -0,0 +1,6 @@
from transformers import GPT2TokenizerFast


def count_tokens(string):
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    return len(tokenizer(string)['input_ids'])
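One caveat on the extracted helper: count_tokens instantiates the GPT-2 tokenizer on every call, and the retriever's history-trimming loop calls it twice per history entry. A possible refinement, not part of this commit, caches the tokenizer once per process:

from functools import lru_cache

from transformers import GPT2TokenizerFast


@lru_cache(maxsize=1)
def _get_tokenizer():
    # load the tokenizer once instead of on every count_tokens() call
    return GPT2TokenizerFast.from_pretrained('gpt2')


def count_tokens(string):
    return len(_get_tokenizer()(string)['input_ids'])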