Merge pull request #926 from siiddhantt/feature

Feature: Logging token usage info to MongoDB
Alex committed 4 weeks ago · commit 8873428b4b

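This change threads the end user's API key from the HTTP request down into every LLM provider, and each generation call now logs its prompt and generated token counts to the token_usage collection in MongoDB. A minimal sketch of the resulting call flow, assuming an OpenAI-style provider is configured; the "user-123" key and the sample message are placeholders, not values from this PR:

from application.core.settings import settings
from application.llm.llm_creator import LLMCreator

# create_llm now receives both the server-side key and the end user's key;
# the user key is what each usage record is attributed to.
llm = LLMCreator.create_llm(
    settings.LLM_NAME, api_key=settings.API_KEY, user_api_key="user-123"
)

# BaseLLM.gen() wraps the provider's _raw_gen with gen_token_usage, so this
# call also inserts a usage document into db["token_usage"].
answer = llm.gen(
    model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello"}]
)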
@@ -10,14 +10,12 @@ from pymongo import MongoClient
from bson.objectid import ObjectId
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.error import bad_request

logger = logging.getLogger(__name__)

mongo = MongoClient(settings.MONGO_URI)
@@ -26,20 +24,22 @@ conversations_collection = db["conversations"]
vectors_collection = db["vectors"]
prompts_collection = db["prompts"]
api_key_collection = db["api_keys"]

answer = Blueprint("answer", __name__)

gpt_model = ""
# to have some kind of default behaviour
if settings.LLM_NAME == "openai":
    gpt_model = "gpt-3.5-turbo"
elif settings.LLM_NAME == "anthropic":
    gpt_model = "claude-2"

if settings.MODEL_NAME:  # in case there is particular model name configured
    gpt_model = settings.MODEL_NAME

# load the prompts
current_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
    chat_combine_template = f.read()
@@ -50,7 +50,7 @@ with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
    chat_combine_creative = f.read()

with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
    chat_combine_strict = f.read()

api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
@@ -61,8 +61,6 @@ async def async_generate(chain, question, chat_history):
    return result


def run_async_chain(chain, question, chat_history):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
@@ -74,17 +72,18 @@ def run_async_chain(chain, question, chat_history):
        result["answer"] = answer
    return result


def get_data_from_api_key(api_key):
    data = api_key_collection.find_one({"key": api_key})
    if data is None:
        return bad_request(401, "Invalid API key")
    return data


def get_vectorstore(data):
    if "active_docs" in data:
        if data["active_docs"].split("/")[0] == "default":
            vectorstore = ""
        elif data["active_docs"].split("/")[0] == "local":
            vectorstore = "indexes/" + data["active_docs"]
        else:
@@ -98,52 +97,82 @@ def get_vectorstore(data):


def is_azure_configured():
    return (
        settings.OPENAI_API_BASE
        and settings.OPENAI_API_VERSION
        and settings.AZURE_DEPLOYMENT_NAME
    )


def save_conversation(conversation_id, question, response, source_log_docs, llm):
    if conversation_id is not None and conversation_id != "None":
        conversations_collection.update_one(
            {"_id": ObjectId(conversation_id)},
            {
                "$push": {
                    "queries": {
                        "prompt": question,
                        "response": response,
                        "sources": source_log_docs,
                    }
                }
            },
        )
    else:
        # create new conversation
        # generate summary
        messages_summary = [
            {
                "role": "assistant",
                "content": "Summarise following conversation in no more than 3 "
                "words, respond ONLY with the summary, use the same "
                "language as the system \n\nUser: "
                + question
                + "\n\n"
                + "AI: "
                + response,
            },
            {
                "role": "user",
                "content": "Summarise following conversation in no more than 3 words, "
                "respond ONLY with the summary, use the same language as the "
                "system",
            },
        ]

        completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
        conversation_id = conversations_collection.insert_one(
            {
                "user": "local",
                "date": datetime.datetime.utcnow(),
                "name": completion,
                "queries": [
                    {
                        "prompt": question,
                        "response": response,
                        "sources": source_log_docs,
                    }
                ],
            }
        ).inserted_id
    return conversation_id


def get_prompt(prompt_id):
    if prompt_id == "default":
        prompt = chat_combine_template
    elif prompt_id == "creative":
        prompt = chat_combine_creative
    elif prompt_id == "strict":
        prompt = chat_combine_strict
    else:
        prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
    return prompt


def complete_stream(question, retriever, conversation_id, user_api_key):
    response_full = ""
    source_log_docs = []
    answer = retriever.gen()
@@ -155,9 +184,12 @@ def complete_stream(question, retriever, conversation_id):
        elif "source" in line:
            source_log_docs.append(line["source"])

    llm = LLMCreator.create_llm(
        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
    )
    conversation_id = save_conversation(
        conversation_id, question, response_full, source_log_docs, llm
    )

    # send data.type = "end" to indicate that the stream has ended as json
    data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -180,17 +212,17 @@ def stream():
        conversation_id = None
    else:
        conversation_id = data["conversation_id"]

    if "prompt_id" in data:
        prompt_id = data["prompt_id"]
    else:
        prompt_id = "default"

    if "selectedDocs" in data and data["selectedDocs"] is None:
        chunks = 0
    elif "chunks" in data:
        chunks = int(data["chunks"])
    else:
        chunks = 2

    prompt = get_prompt(prompt_id)

    # check if active_docs is set
@@ -198,23 +230,42 @@ def stream():
    if "api_key" in data:
        data_key = get_data_from_api_key(data["api_key"])
        source = {"active_docs": data_key["source"]}
        user_api_key = data["api_key"]
    elif "active_docs" in data:
        source = {"active_docs": data["active_docs"]}
        user_api_key = None
    else:
        source = {}
        user_api_key = None

    if (
        source["active_docs"].split("/")[0] == "default"
        or source["active_docs"].split("/")[0] == "local"
    ):
        retriever_name = "classic"
    else:
        retriever_name = source["active_docs"]

    retriever = RetrieverCreator.create_retriever(
        retriever_name,
        question=question,
        source=source,
        chat_history=history,
        prompt=prompt,
        chunks=chunks,
        gpt_model=gpt_model,
        user_api_key=user_api_key,
    )

    return Response(
        complete_stream(
            question=question,
            retriever=retriever,
            conversation_id=conversation_id,
            user_api_key=user_api_key,
        ),
        mimetype="text/event-stream",
    )


@answer.route("/api/answer", methods=["POST"])
@@ -230,15 +281,15 @@ def api_answer():
    else:
        conversation_id = data["conversation_id"]
    print("-" * 5)

    if "prompt_id" in data:
        prompt_id = data["prompt_id"]
    else:
        prompt_id = "default"

    if "chunks" in data:
        chunks = int(data["chunks"])
    else:
        chunks = 2

    prompt = get_prompt(prompt_id)

    # use try and except to check for exception
@@ -247,17 +298,29 @@ def api_answer():
        if "api_key" in data:
            data_key = get_data_from_api_key(data["api_key"])
            source = {"active_docs": data_key["source"]}
            user_api_key = data["api_key"]
        else:
            source = {data}
            user_api_key = None

        if (
            source["active_docs"].split("/")[0] == "default"
            or source["active_docs"].split("/")[0] == "local"
        ):
            retriever_name = "classic"
        else:
            retriever_name = source["active_docs"]

        retriever = RetrieverCreator.create_retriever(
            retriever_name,
            question=question,
            source=source,
            chat_history=history,
            prompt=prompt,
            chunks=chunks,
            gpt_model=gpt_model,
            user_api_key=user_api_key,
        )
        source_log_docs = []
        response_full = ""
        for line in retriever.gen():
@@ -265,12 +328,15 @@ def api_answer():
                source_log_docs.append(line["source"])
            elif "answer" in line:
                response_full += line["answer"]

        llm = LLMCreator.create_llm(
            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
        )

        result = {"answer": response_full, "sources": source_log_docs}
        result["conversation_id"] = save_conversation(
            conversation_id, question, response_full, source_log_docs, llm
        )

        return result
    except Exception as e:
@@ -289,23 +355,35 @@ def api_search():
    if "api_key" in data:
        data_key = get_data_from_api_key(data["api_key"])
        source = {"active_docs": data_key["source"]}
        user_api_key = data["api_key"]
    elif "active_docs" in data:
        source = {"active_docs": data["active_docs"]}
        user_api_key = None
    else:
        source = {}
        user_api_key = None

    if "chunks" in data:
        chunks = int(data["chunks"])
    else:
        chunks = 2

    if (
        source["active_docs"].split("/")[0] == "default"
        or source["active_docs"].split("/")[0] == "local"
    ):
        retriever_name = "classic"
    else:
        retriever_name = source["active_docs"]

    retriever = RetrieverCreator.create_retriever(
        retriever_name,
        question=question,
        source=source,
        chat_history=[],
        prompt="default",
        chunks=chunks,
        gpt_model=gpt_model,
        user_api_key=user_api_key,
    )
    docs = retriever.search()
    return docs

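Downstream of these route changes, a request that carries an api_key is both resolved to its source documents and attributed in the usage log. A rough client-side sketch of exercising the streaming route follows; the URL, port, and the question/history field shapes are assumptions, since only the api_key handling appears in this diff:

import requests

# assumed local backend URL and payload shape; the diff only shows the handler body
resp = requests.post(
    "http://localhost:7091/stream",
    json={
        "question": "What does this repo do?",
        "history": [],
        "api_key": "user-123",  # forwarded as user_api_key and logged with token usage
    },
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))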
@@ -1,21 +1,29 @@
from application.llm.base import BaseLLM
from application.core.settings import settings


class AnthropicLLM(BaseLLM):

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT

        super().__init__(*args, **kwargs)
        self.api_key = (
            api_key or settings.ANTHROPIC_API_KEY
        )  # If not provided, use a default from settings
        self.user_api_key = user_api_key
        self.anthropic = Anthropic(api_key=self.api_key)
        self.HUMAN_PROMPT = HUMAN_PROMPT
        self.AI_PROMPT = AI_PROMPT

    def _raw_gen(
        self, baseself, model, messages, stream=False, max_tokens=300, **kwargs
    ):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Context \n {context} \n ### Question \n {user_question}"
        if stream:
            return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)

        completion = self.anthropic.completions.create(
            model=model,
@@ -25,9 +33,11 @@ class AnthropicLLM(BaseLLM):
        )
        return completion.completion

    def _raw_gen_stream(
        self, baseself, model, messages, stream=True, max_tokens=300, **kwargs
    ):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Context \n {context} \n ### Question \n {user_question}"
        stream_response = self.anthropic.completions.create(
            model=model,
@@ -37,4 +47,4 @@ class AnthropicLLM(BaseLLM):
        )

        for completion in stream_response:
            yield completion.completion

@@ -1,14 +1,28 @@
from abc import ABC, abstractmethod
from application.usage import gen_token_usage, stream_token_usage


class BaseLLM(ABC):

    def __init__(self):
        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}

    def _apply_decorator(self, method, decorator, *args, **kwargs):
        return decorator(method, *args, **kwargs)

    @abstractmethod
    def _raw_gen(self, model, messages, stream, *args, **kwargs):
        pass

    def gen(self, model, messages, stream=False, *args, **kwargs):
        return self._apply_decorator(self._raw_gen, gen_token_usage)(
            self, model=model, messages=messages, stream=stream, *args, **kwargs
        )

    @abstractmethod
    def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
        pass

    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
        return self._apply_decorator(self._raw_gen_stream, stream_token_usage)(
            self, model=model, messages=messages, stream=stream, *args, **kwargs
        )

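The BaseLLM refactor is a template-method pattern: subclasses implement _raw_gen/_raw_gen_stream, while the public gen/gen_stream wrap them with the usage decorators at call time. A minimal, self-contained sketch of the same wrapping idea; EchoLLM and the word-count "tokenizer" are illustrative only, not part of the PR:

from abc import ABC, abstractmethod


def usage_counter(func):
    # stand-in for gen_token_usage: wrap the raw call and tally a crude token count
    def wrapper(self, model, messages, stream, **kwargs):
        result = func(self, model, messages, stream, **kwargs)
        self.token_usage["generated_tokens"] += len(result.split())
        return result

    return wrapper


class MiniBaseLLM(ABC):
    def __init__(self):
        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}

    @abstractmethod
    def _raw_gen(self, baseself, model, messages, stream, **kwargs):
        pass

    def gen(self, model, messages, stream=False, **kwargs):
        # decorate the bound method at call time, like BaseLLM._apply_decorator
        wrapped = usage_counter(self._raw_gen)
        return wrapped(self, model=model, messages=messages, stream=stream, **kwargs)


class EchoLLM(MiniBaseLLM):
    def _raw_gen(self, baseself, model, messages, stream, **kwargs):
        return messages[-1]["content"]


llm = EchoLLM()
print(llm.gen(model="echo", messages=[{"role": "user", "content": "hello world"}]))
print(llm.token_usage)  # {'prompt_tokens': 0, 'generated_tokens': 2}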
@@ -2,48 +2,43 @@ from application.llm.base import BaseLLM
import json
import requests


class DocsGPTAPILLM(BaseLLM):

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.api_key = api_key
        self.user_api_key = user_api_key
        self.endpoint = "https://llm.docsgpt.co.uk"

    def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

        response = requests.post(
            f"{self.endpoint}/answer", json={"prompt": prompt, "max_new_tokens": 30}
        )
        response_clean = response.json()["a"].replace("###", "")

        return response_clean

    def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

        # send prompt to endpoint /stream
        response = requests.post(
            f"{self.endpoint}/stream",
            json={"prompt": prompt, "max_new_tokens": 256},
            stream=True,
        )

        for line in response.iter_lines():
            if line:
                # data = json.loads(line)
                data_str = line.decode("utf-8")
                if data_str.startswith("data: "):
                    data = json.loads(data_str[6:])
                    yield data["a"]

@@ -1,44 +1,68 @@
from application.llm.base import BaseLLM


class HuggingFaceLLM(BaseLLM):

    def __init__(
        self,
        api_key=None,
        user_api_key=None,
        llm_name="Arc53/DocsGPT-7B",
        q=False,
        *args,
        **kwargs,
    ):
        global hf

        from langchain.llms import HuggingFacePipeline

        if q:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                pipeline,
                BitsAndBytesConfig,
            )

            tokenizer = AutoTokenizer.from_pretrained(llm_name)
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            model = AutoModelForCausalLM.from_pretrained(
                llm_name, quantization_config=bnb_config
            )
        else:
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

            tokenizer = AutoTokenizer.from_pretrained(llm_name)
            model = AutoModelForCausalLM.from_pretrained(llm_name)

        super().__init__(*args, **kwargs)
        self.api_key = api_key
        self.user_api_key = user_api_key
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=2000,
            device_map="auto",
            eos_token_id=tokenizer.eos_token_id,
        )
        hf = HuggingFacePipeline(pipeline=pipe)

    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

        result = hf(prompt)

        return result.content

    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
        raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")

@@ -1,32 +1,45 @@
from application.llm.base import BaseLLM
from application.core.settings import settings


class LlamaCpp(BaseLLM):

    def __init__(
        self,
        api_key=None,
        user_api_key=None,
        llm_name=settings.MODEL_PATH,
        *args,
        **kwargs,
    ):
        global llama
        try:
            from llama_cpp import Llama
        except ImportError:
            raise ImportError(
                "Please install llama_cpp using pip install llama-cpp-python"
            )

        super().__init__(*args, **kwargs)
        self.api_key = api_key
        self.user_api_key = user_api_key
        llama = Llama(model_path=llm_name, n_ctx=2048)

    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

        result = llama(prompt, max_tokens=150, echo=False)

        # import sys
        # print(result['choices'][0]['text'].split('### Answer \n')[-1], file=sys.stderr)

        return result["choices"][0]["text"].split("### Answer \n")[-1]

    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

        result = llama(prompt, max_tokens=150, echo=False, stream=stream)
@@ -35,5 +48,5 @@ class LlamaCpp(BaseLLM):
        # print(list(result), file=sys.stderr)

        for item in result:
            for choice in item["choices"]:
                yield choice["text"]

@@ -7,22 +7,21 @@ from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM


class LLMCreator:
    llms = {
        "openai": OpenAILLM,
        "azure_openai": AzureOpenAILLM,
        "sagemaker": SagemakerAPILLM,
        "huggingface": HuggingFaceLLM,
        "llama.cpp": LlamaCpp,
        "anthropic": AnthropicLLM,
        "docsgpt": DocsGPTAPILLM,
        "premai": PremAILLM,
    }

    @classmethod
    def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
        llm_class = cls.llms.get(type.lower())
        if not llm_class:
            raise ValueError(f"No LLM class found for type {type}")
        return llm_class(api_key, user_api_key, *args, **kwargs)

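Because create_llm now forwards api_key and user_api_key positionally, every provider registered in llms has to accept them as its first two constructor arguments and call super().__init__() so that token_usage is initialised. A hypothetical custom provider, not part of this PR, would follow the same shape:

from application.llm.base import BaseLLM


class MyCustomLLM(BaseLLM):  # hypothetical provider, for illustration only
    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        super().__init__(*args, **kwargs)  # initialises self.token_usage
        self.api_key = api_key
        self.user_api_key = user_api_key  # read by the usage decorators

    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
        return "stub answer"

    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
        yield "stub answer"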
@@ -1,36 +1,53 @@
from application.llm.base import BaseLLM
from application.core.settings import settings


class OpenAILLM(BaseLLM):

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        global openai
        from openai import OpenAI

        super().__init__(*args, **kwargs)
        self.client = OpenAI(
            api_key=api_key,
        )
        self.api_key = api_key
        self.user_api_key = user_api_key

    def _get_openai(self):
        # Import openai when needed
        import openai

        return openai

    def _raw_gen(
        self,
        baseself,
        model,
        messages,
        stream=False,
        engine=settings.AZURE_DEPLOYMENT_NAME,
        **kwargs
    ):
        response = self.client.chat.completions.create(
            model=model, messages=messages, stream=stream, **kwargs
        )

        return response.choices[0].message.content

    def _raw_gen_stream(
        self,
        baseself,
        model,
        messages,
        stream=True,
        engine=settings.AZURE_DEPLOYMENT_NAME,
        **kwargs
    ):
        response = self.client.chat.completions.create(
            model=model, messages=messages, stream=stream, **kwargs
        )

        for line in response:
            # import sys
@@ -41,14 +58,17 @@ class OpenAILLM(BaseLLM):
class AzureOpenAILLM(OpenAILLM):

    def __init__(
        self, openai_api_key, openai_api_base, openai_api_version, deployment_name
    ):
        super().__init__(openai_api_key)
        self.api_base = (settings.OPENAI_API_BASE,)
        self.api_version = (settings.OPENAI_API_VERSION,)
        self.deployment_name = (settings.AZURE_DEPLOYMENT_NAME,)
        from openai import AzureOpenAI

        self.client = AzureOpenAI(
            api_key=openai_api_key,
            api_version=settings.OPENAI_API_VERSION,
            api_base=settings.OPENAI_API_BASE,
            deployment_name=settings.AZURE_DEPLOYMENT_NAME,

@@ -1,32 +1,37 @@
from application.llm.base import BaseLLM
from application.core.settings import settings


class PremAILLM(BaseLLM):

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        from premai import Prem

        super().__init__(*args, **kwargs)
        self.client = Prem(api_key=api_key)
        self.api_key = api_key
        self.user_api_key = user_api_key
        self.project_id = settings.PREMAI_PROJECT_ID

    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
        response = self.client.chat.completions.create(
            model=model,
            project_id=self.project_id,
            messages=messages,
            stream=stream,
            **kwargs
        )

        return response.choices[0].message["content"]

    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
        response = self.client.chat.completions.create(
            model=model,
            project_id=self.project_id,
            messages=messages,
            stream=stream,
            **kwargs
        )

        for line in response:
            if line.choices[0].delta["content"] is not None:

@@ -4,11 +4,10 @@ import json
import io


class LineIterator:
    """
    A helper class for parsing the byte stream input.
    The output of the model will be in the following format:
    ```
    b'{"outputs": [" a"]}\n'
@@ -16,21 +15,21 @@ class LineIterator:
    b'{"outputs": [" problem"]}\n'
    ...
    ```

    While usually each PayloadPart event from the event stream will contain a byte array
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    This class accounts for this by concatenating bytes written via the 'write' function
    and then exposing a method which will return lines (ending with a '\n' character) within
    the buffer via the 'scan_lines' function. It maintains the position of the last read
    position to ensure that previous bytes are not exposed again.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
@@ -43,7 +42,7 @@ class LineIterator:
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
@@ -52,33 +51,35 @@ class LineIterator:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                print("Unknown event type:" + chunk)
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])


class SagemakerAPILLM(BaseLLM):

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        import boto3

        runtime = boto3.client(
            "runtime.sagemaker",
            aws_access_key_id="xxx",
            aws_secret_access_key="xxx",
            region_name="us-west-2",
        )

        super().__init__(*args, **kwargs)
        self.api_key = api_key
        self.user_api_key = user_api_key
        self.endpoint = settings.SAGEMAKER_ENDPOINT
        self.runtime = runtime

    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

        # Construct payload for endpoint
        payload = {
@@ -89,25 +90,25 @@ class SagemakerAPILLM(BaseLLM):
                "temperature": 0.1,
                "max_new_tokens": 30,
                "repetition_penalty": 1.03,
                "stop": ["</s>", "###"],
            },
        }
        body_bytes = json.dumps(payload).encode("utf-8")

        # Invoke the endpoint
        response = self.runtime.invoke_endpoint(
            EndpointName=self.endpoint, ContentType="application/json", Body=body_bytes
        )
        result = json.loads(response["Body"].read().decode())
        import sys

        print(result[0]["generated_text"], file=sys.stderr)
        return result[0]["generated_text"][len(prompt) :]

    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
        context = messages[0]["content"]
        user_question = messages[-1]["content"]
        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

        # Construct payload for endpoint
        payload = {
@@ -118,22 +119,22 @@ class SagemakerAPILLM(BaseLLM):
                "temperature": 0.1,
                "max_new_tokens": 512,
                "repetition_penalty": 1.03,
                "stop": ["</s>", "###"],
            },
        }
        body_bytes = json.dumps(payload).encode("utf-8")

        # Invoke the endpoint
        response = self.runtime.invoke_endpoint_with_response_stream(
            EndpointName=self.endpoint, ContentType="application/json", Body=body_bytes
        )
        # result = json.loads(response['Body'].read().decode())
        event_stream = response["Body"]
        start_json = b"{"
        for line in LineIterator(event_stream):
            if line != b"" and start_json in line:
                # print(line)
                data = json.loads(line[line.find(start_json) :].decode("utf-8"))
                if data["token"]["text"] not in ["</s>", "###"]:
                    print(data["token"]["text"], end="")
                    yield data["token"]["text"]

@@ -6,43 +6,54 @@ from application.utils import count_tokens
from langchain_community.tools import BraveSearch


class BraveRetSearch(BaseRetriever):

    def __init__(
        self,
        question,
        source,
        chat_history,
        prompt,
        chunks=2,
        gpt_model="docsgpt",
        user_api_key=None,
    ):
        self.question = question
        self.source = source
        self.chat_history = chat_history
        self.prompt = prompt
        self.chunks = chunks
        self.gpt_model = gpt_model
        self.user_api_key = user_api_key

    def _get_data(self):
        if self.chunks == 0:
            docs = []
        else:
            search = BraveSearch.from_api_key(
                api_key=settings.BRAVE_SEARCH_API_KEY,
                search_kwargs={"count": int(self.chunks)},
            )
            results = search.run(self.question)
            results = json.loads(results)

            docs = []
            for i in results:
                try:
                    title = i["title"]
                    link = i["link"]
                    snippet = i["snippet"]
                    docs.append({"text": snippet, "title": title, "link": link})
                except IndexError:
                    pass
        if settings.LLM_NAME == "llama.cpp":
            docs = [docs[0]]

        return docs

    def gen(self):
        docs = self._get_data()

        # join all page_content together with a newline
        docs_together = "\n".join([doc["text"] for doc in docs])
        p_chat_combine = self.prompt.replace("{summaries}", docs_together)
@@ -56,20 +67,29 @@ class BraveRetSearch(BaseRetriever):
            self.chat_history.reverse()
            for i in self.chat_history:
                if "prompt" in i and "response" in i:
                    tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                        i["response"]
                    )
                    if (
                        tokens_current_history + tokens_batch
                        < settings.TOKENS_MAX_HISTORY
                    ):
                        tokens_current_history += tokens_batch
                        messages_combine.append(
                            {"role": "user", "content": i["prompt"]}
                        )
                        messages_combine.append(
                            {"role": "system", "content": i["response"]}
                        )
        messages_combine.append({"role": "user", "content": self.question})

        llm = LLMCreator.create_llm(
            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
        )

        completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
        for line in completion:
            yield {"answer": str(line)}

    def search(self):
        return self._get_data()

@@ -7,21 +7,30 @@ from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens


class ClassicRAG(BaseRetriever):

    def __init__(
        self,
        question,
        source,
        chat_history,
        prompt,
        chunks=2,
        gpt_model="docsgpt",
        user_api_key=None,
    ):
        self.question = question
        self.vectorstore = self._get_vectorstore(source=source)
        self.chat_history = chat_history
        self.prompt = prompt
        self.chunks = chunks
        self.gpt_model = gpt_model
        self.user_api_key = user_api_key

    def _get_vectorstore(self, source):
        if "active_docs" in source:
            if source["active_docs"].split("/")[0] == "default":
                vectorstore = ""
            elif source["active_docs"].split("/")[0] == "local":
                vectorstore = "indexes/" + source["active_docs"]
            else:
@@ -33,32 +42,33 @@ class ClassicRAG(BaseRetriever):
        vectorstore = os.path.join("application", vectorstore)
        return vectorstore

    def _get_data(self):
        if self.chunks == 0:
            docs = []
        else:
            docsearch = VectorCreator.create_vectorstore(
                settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
            )
            docs_temp = docsearch.search(self.question, k=self.chunks)
            docs = [
                {
                    "title": (
                        i.metadata["title"].split("/")[-1]
                        if i.metadata
                        else i.page_content
                    ),
                    "text": i.page_content,
                }
                for i in docs_temp
            ]
        if settings.LLM_NAME == "llama.cpp":
            docs = [docs[0]]

        return docs

    def gen(self):
        docs = self._get_data()

        # join all page_content together with a newline
        docs_together = "\n".join([doc["text"] for doc in docs])
        p_chat_combine = self.prompt.replace("{summaries}", docs_together)
@@ -72,20 +82,29 @@ class ClassicRAG(BaseRetriever):
            self.chat_history.reverse()
            for i in self.chat_history:
                if "prompt" in i and "response" in i:
                    tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                        i["response"]
                    )
                    if (
                        tokens_current_history + tokens_batch
                        < settings.TOKENS_MAX_HISTORY
                    ):
                        tokens_current_history += tokens_batch
                        messages_combine.append(
                            {"role": "user", "content": i["prompt"]}
                        )
                        messages_combine.append(
                            {"role": "system", "content": i["response"]}
                        )
        messages_combine.append({"role": "user", "content": self.question})

        llm = LLMCreator.create_llm(
            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
        )

        completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
        for line in completion:
            yield {"answer": str(line)}

    def search(self):
        return self._get_data()

@@ -6,16 +6,25 @@ from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper


class DuckDuckSearch(BaseRetriever):

    def __init__(
        self,
        question,
        source,
        chat_history,
        prompt,
        chunks=2,
        gpt_model="docsgpt",
        user_api_key=None,
    ):
        self.question = question
        self.source = source
        self.chat_history = chat_history
        self.prompt = prompt
        self.chunks = chunks
        self.gpt_model = gpt_model
        self.user_api_key = user_api_key

    def _parse_lang_string(self, input_string):
        result = []
@@ -30,12 +39,12 @@ class DuckDuckSearch(BaseRetriever):
                current_item = ""
            elif inside_brackets:
                current_item += char

        if inside_brackets:
            result.append(current_item)

        return result

    def _get_data(self):
        if self.chunks == 0:
            docs = []
@@ -44,7 +53,7 @@ class DuckDuckSearch(BaseRetriever):
            search = DuckDuckGoSearchResults(api_wrapper=wrapper)
            results = search.run(self.question)
            results = self._parse_lang_string(results)

            docs = []
            for i in results:
                try:
@@ -56,12 +65,12 @@ class DuckDuckSearch(BaseRetriever):
                    pass
        if settings.LLM_NAME == "llama.cpp":
            docs = [docs[0]]

        return docs

    def gen(self):
        docs = self._get_data()

        # join all page_content together with a newline
        docs_together = "\n".join([doc["text"] for doc in docs])
        p_chat_combine = self.prompt.replace("{summaries}", docs_together)
@@ -75,20 +84,29 @@ class DuckDuckSearch(BaseRetriever):
            self.chat_history.reverse()
            for i in self.chat_history:
                if "prompt" in i and "response" in i:
                    tokens_batch = count_tokens(i["prompt"]) + count_tokens(
                        i["response"]
                    )
                    if (
                        tokens_current_history + tokens_batch
                        < settings.TOKENS_MAX_HISTORY
                    ):
                        tokens_current_history += tokens_batch
                        messages_combine.append(
                            {"role": "user", "content": i["prompt"]}
                        )
                        messages_combine.append(
                            {"role": "system", "content": i["response"]}
                        )
        messages_combine.append({"role": "user", "content": self.question})

        llm = LLMCreator.create_llm(
            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
        )

        completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
        for line in completion:
            yield {"answer": str(line)}

    def search(self):
        return self._get_data()

@@ -0,0 +1,49 @@
import sys
from pymongo import MongoClient
from datetime import datetime

from application.core.settings import settings
from application.utils import count_tokens

mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
usage_collection = db["token_usage"]


def update_token_usage(user_api_key, token_usage):
    if "pytest" in sys.modules:
        return
    usage_data = {
        "api_key": user_api_key,
        "prompt_tokens": token_usage["prompt_tokens"],
        "generated_tokens": token_usage["generated_tokens"],
        "timestamp": datetime.now(),
    }
    usage_collection.insert_one(usage_data)


def gen_token_usage(func):
    def wrapper(self, model, messages, stream, **kwargs):
        for message in messages:
            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
        result = func(self, model, messages, stream, **kwargs)
        self.token_usage["generated_tokens"] += count_tokens(result)
        update_token_usage(self.user_api_key, self.token_usage)
        return result

    return wrapper


def stream_token_usage(func):
    def wrapper(self, model, messages, stream, **kwargs):
        for message in messages:
            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
        batch = []
        result = func(self, model, messages, stream, **kwargs)
        for r in result:
            batch.append(r)
            yield r
        for line in batch:
            self.token_usage["generated_tokens"] += count_tokens(line)
        update_token_usage(self.user_api_key, self.token_usage)

    return wrapper
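
Each wrapped call inserts one usage document per request into token_usage. A hedged sketch of reading the data back, for example totalling tokens per user API key; the aggregation pipeline is illustrative and not part of this PR:

from pymongo import MongoClient
from application.core.settings import settings

mongo = MongoClient(settings.MONGO_URI)
usage_collection = mongo["docsgpt"]["token_usage"]

# Sum prompt/generated tokens and count requests per user API key.
pipeline = [
    {
        "$group": {
            "_id": "$api_key",
            "prompt_tokens": {"$sum": "$prompt_tokens"},
            "generated_tokens": {"$sum": "$generated_tokens"},
            "requests": {"$sum": 1},
        }
    }
]
for row in usage_collection.aggregate(pipeline):
    print(row)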