diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index fa0ac4f..af3c5f0 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -10,14 +10,12 @@ from pymongo import MongoClient from bson.objectid import ObjectId - from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator from application.error import bad_request - logger = logging.getLogger(__name__) mongo = MongoClient(settings.MONGO_URI) @@ -26,20 +24,22 @@ conversations_collection = db["conversations"] vectors_collection = db["vectors"] prompts_collection = db["prompts"] api_key_collection = db["api_keys"] -answer = Blueprint('answer', __name__) +answer = Blueprint("answer", __name__) gpt_model = "" # to have some kind of default behaviour if settings.LLM_NAME == "openai": - gpt_model = 'gpt-3.5-turbo' + gpt_model = "gpt-3.5-turbo" elif settings.LLM_NAME == "anthropic": - gpt_model = 'claude-2' + gpt_model = "claude-2" if settings.MODEL_NAME: # in case there is particular model name configured gpt_model = settings.MODEL_NAME # load the prompts -current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +current_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() @@ -50,7 +50,7 @@ with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r" chat_combine_creative = f.read() with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: - chat_combine_strict = f.read() + chat_combine_strict = f.read() api_key_set = settings.API_KEY is not None embeddings_key_set = settings.EMBEDDINGS_KEY is not None @@ -61,8 +61,6 @@ async def async_generate(chain, question, chat_history): return result - - def run_async_chain(chain, question, chat_history): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -74,17 +72,18 @@ def run_async_chain(chain, question, chat_history): result["answer"] = answer return result + def get_data_from_api_key(api_key): data = api_key_collection.find_one({"key": api_key}) if data is None: return bad_request(401, "Invalid API key") return data - + def get_vectorstore(data): if "active_docs" in data: if data["active_docs"].split("/")[0] == "default": - vectorstore = "" + vectorstore = "" elif data["active_docs"].split("/")[0] == "local": vectorstore = "indexes/" + data["active_docs"] else: @@ -98,52 +97,82 @@ def get_vectorstore(data): def is_azure_configured(): - return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME + return ( + settings.OPENAI_API_BASE + and settings.OPENAI_API_VERSION + and settings.AZURE_DEPLOYMENT_NAME + ) + def save_conversation(conversation_id, question, response, source_log_docs, llm): if conversation_id is not None and conversation_id != "None": conversations_collection.update_one( {"_id": ObjectId(conversation_id)}, - {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}}, + { + "$push": { + "queries": { + "prompt": question, + "response": response, + "sources": source_log_docs, + } + } + }, ) else: # create new conversation # generate summary - messages_summary = [{"role": "assistant", "content": "Summarise following conversation in no more than 3 " - "words, respond ONLY with the summary, use the same " - "language as the system \n\nUser: " + question + "\n\n" + - "AI: " + - response}, - {"role": "user", "content": "Summarise following conversation in no more than 3 words, " - "respond ONLY with the summary, use the same language as the " - "system"}] - - completion = llm.gen(model=gpt_model, - messages=messages_summary, max_tokens=30) + messages_summary = [ + { + "role": "assistant", + "content": "Summarise following conversation in no more than 3 " + "words, respond ONLY with the summary, use the same " + "language as the system \n\nUser: " + + question + + "\n\n" + + "AI: " + + response, + }, + { + "role": "user", + "content": "Summarise following conversation in no more than 3 words, " + "respond ONLY with the summary, use the same language as the " + "system", + }, + ] + + completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30) conversation_id = conversations_collection.insert_one( - {"user": "local", - "date": datetime.datetime.utcnow(), - "name": completion, - "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]} + { + "user": "local", + "date": datetime.datetime.utcnow(), + "name": completion, + "queries": [ + { + "prompt": question, + "response": response, + "sources": source_log_docs, + } + ], + } ).inserted_id return conversation_id + def get_prompt(prompt_id): - if prompt_id == 'default': + if prompt_id == "default": prompt = chat_combine_template - elif prompt_id == 'creative': + elif prompt_id == "creative": prompt = chat_combine_creative - elif prompt_id == 'strict': + elif prompt_id == "strict": prompt = chat_combine_strict else: prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] return prompt -def complete_stream(question, retriever, conversation_id): - - +def complete_stream(question, retriever, conversation_id, user_api_key): + response_full = "" source_log_docs = [] answer = retriever.gen() @@ -155,9 +184,12 @@ def complete_stream(question, retriever, conversation_id): elif "source" in line: source_log_docs.append(line["source"]) - - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm) + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key + ) + conversation_id = save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) # send data.type = "end" to indicate that the stream has ended as json data = json.dumps({"type": "id", "id": str(conversation_id)}) @@ -180,17 +212,17 @@ def stream(): conversation_id = None else: conversation_id = data["conversation_id"] - if 'prompt_id' in data: + if "prompt_id" in data: prompt_id = data["prompt_id"] else: - prompt_id = 'default' - if 'selectedDocs' in data and data['selectedDocs'] is None: + prompt_id = "default" + if "selectedDocs" in data and data["selectedDocs"] is None: chunks = 0 - elif 'chunks' in data: + elif "chunks" in data: chunks = int(data["chunks"]) else: chunks = 2 - + prompt = get_prompt(prompt_id) # check if active_docs is set @@ -198,23 +230,42 @@ def stream(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] elif "active_docs" in data: source = {"active_docs": data["active_docs"]} + user_api_key = None else: source = {} + user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + if ( + source["active_docs"].split("/")[0] == "default" + or source["active_docs"].split("/")[0] == "local" + ): retriever_name = "classic" else: - retriever_name = source['active_docs'] - - retriever = RetrieverCreator.create_retriever(retriever_name, question=question, - source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model - ) + retriever_name = source["active_docs"] + + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + gpt_model=gpt_model, + user_api_key=user_api_key, + ) return Response( - complete_stream(question=question, retriever=retriever, - conversation_id=conversation_id), mimetype="text/event-stream") + complete_stream( + question=question, + retriever=retriever, + conversation_id=conversation_id, + user_api_key=user_api_key, + ), + mimetype="text/event-stream", + ) @answer.route("/api/answer", methods=["POST"]) @@ -230,15 +281,15 @@ def api_answer(): else: conversation_id = data["conversation_id"] print("-" * 5) - if 'prompt_id' in data: + if "prompt_id" in data: prompt_id = data["prompt_id"] else: - prompt_id = 'default' - if 'chunks' in data: + prompt_id = "default" + if "chunks" in data: chunks = int(data["chunks"]) else: chunks = 2 - + prompt = get_prompt(prompt_id) # use try and except to check for exception @@ -247,17 +298,29 @@ def api_answer(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] else: source = {data} + user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + if ( + source["active_docs"].split("/")[0] == "default" + or source["active_docs"].split("/")[0] == "local" + ): retriever_name = "classic" else: - retriever_name = source['active_docs'] - - retriever = RetrieverCreator.create_retriever(retriever_name, question=question, - source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model - ) + retriever_name = source["active_docs"] + + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + gpt_model=gpt_model, + user_api_key=user_api_key, + ) source_log_docs = [] response_full = "" for line in retriever.gen(): @@ -265,12 +328,15 @@ def api_answer(): source_log_docs.append(line["source"]) elif "answer" in line: response_full += line["answer"] - - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - + + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key + ) result = {"answer": response_full, "sources": source_log_docs} - result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm) + result["conversation_id"] = save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) return result except Exception as e: @@ -289,23 +355,35 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] elif "active_docs" in data: source = {"active_docs": data["active_docs"]} + user_api_key = None else: source = {} - if 'chunks' in data: + user_api_key = None + if "chunks" in data: chunks = int(data["chunks"]) else: chunks = 2 - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + if ( + source["active_docs"].split("/")[0] == "default" + or source["active_docs"].split("/")[0] == "local" + ): retriever_name = "classic" else: - retriever_name = source['active_docs'] - - retriever = RetrieverCreator.create_retriever(retriever_name, question=question, - source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model - ) + retriever_name = source["active_docs"] + + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=[], + prompt="default", + chunks=chunks, + gpt_model=gpt_model, + user_api_key=user_api_key, + ) docs = retriever.search() return docs - diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index 6b0d646..4081bcd 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -1,21 +1,29 @@ from application.llm.base import BaseLLM from application.core.settings import settings + class AnthropicLLM(BaseLLM): - def __init__(self, api_key=None): + def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT - self.api_key = api_key or settings.ANTHROPIC_API_KEY # If not provided, use a default from settings + + super().__init__(*args, **kwargs) + self.api_key = ( + api_key or settings.ANTHROPIC_API_KEY + ) # If not provided, use a default from settings + self.user_api_key = user_api_key self.anthropic = Anthropic(api_key=self.api_key) self.HUMAN_PROMPT = HUMAN_PROMPT self.AI_PROMPT = AI_PROMPT - def gen(self, model, messages, max_tokens=300, stream=False, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen( + self, baseself, model, messages, stream=False, max_tokens=300, **kwargs + ): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}" if stream: - return self.gen_stream(model, prompt, max_tokens, **kwargs) + return self.gen_stream(model, prompt, stream, max_tokens, **kwargs) completion = self.anthropic.completions.create( model=model, @@ -25,9 +33,11 @@ class AnthropicLLM(BaseLLM): ) return completion.completion - def gen_stream(self, model, messages, max_tokens=300, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen_stream( + self, baseself, model, messages, stream=True, max_tokens=300, **kwargs + ): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}" stream_response = self.anthropic.completions.create( model=model, @@ -37,4 +47,4 @@ class AnthropicLLM(BaseLLM): ) for completion in stream_response: - yield completion.completion \ No newline at end of file + yield completion.completion diff --git a/application/llm/base.py b/application/llm/base.py index e08a3b0..475b793 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -1,14 +1,28 @@ from abc import ABC, abstractmethod +from application.usage import gen_token_usage, stream_token_usage class BaseLLM(ABC): def __init__(self): - pass + self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0} + + def _apply_decorator(self, method, decorator, *args, **kwargs): + return decorator(method, *args, **kwargs) @abstractmethod - def gen(self, *args, **kwargs): + def _raw_gen(self, model, messages, stream, *args, **kwargs): pass + def gen(self, model, messages, stream=False, *args, **kwargs): + return self._apply_decorator(self._raw_gen, gen_token_usage)( + self, model=model, messages=messages, stream=stream, *args, **kwargs + ) + @abstractmethod - def gen_stream(self, *args, **kwargs): + def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): pass + + def gen_stream(self, model, messages, stream=True, *args, **kwargs): + return self._apply_decorator(self._raw_gen_stream, stream_token_usage)( + self, model=model, messages=messages, stream=stream, *args, **kwargs + ) diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py index d540a91..bca3972 100644 --- a/application/llm/docsgpt_provider.py +++ b/application/llm/docsgpt_provider.py @@ -2,48 +2,43 @@ from application.llm.base import BaseLLM import json import requests -class DocsGPTAPILLM(BaseLLM): - def __init__(self, *args, **kwargs): - self.endpoint = "https://llm.docsgpt.co.uk" +class DocsGPTAPILLM(BaseLLM): + def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.api_key = api_key + self.user_api_key = user_api_key + self.endpoint = "https://llm.docsgpt.co.uk" - def gen(self, model, messages, stream=False, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" response = requests.post( - f"{self.endpoint}/answer", - json={ - "prompt": prompt, - "max_new_tokens": 30 - } + f"{self.endpoint}/answer", json={"prompt": prompt, "max_new_tokens": 30} ) - response_clean = response.json()['a'].replace("###", "") + response_clean = response.json()["a"].replace("###", "") return response_clean - def gen_stream(self, model, messages, stream=True, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" # send prompt to endpoint /stream response = requests.post( f"{self.endpoint}/stream", - json={ - "prompt": prompt, - "max_new_tokens": 256 - }, - stream=True + json={"prompt": prompt, "max_new_tokens": 256}, + stream=True, ) - + for line in response.iter_lines(): if line: - #data = json.loads(line) - data_str = line.decode('utf-8') + # data = json.loads(line) + data_str = line.decode("utf-8") if data_str.startswith("data: "): data = json.loads(data_str[6:]) - yield data['a'] - \ No newline at end of file + yield data["a"] diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py index 554bee2..2fb4a92 100644 --- a/application/llm/huggingface.py +++ b/application/llm/huggingface.py @@ -1,44 +1,68 @@ from application.llm.base import BaseLLM + class HuggingFaceLLM(BaseLLM): - def __init__(self, api_key, llm_name='Arc53/DocsGPT-7B',q=False): + def __init__( + self, + api_key=None, + user_api_key=None, + llm_name="Arc53/DocsGPT-7B", + q=False, + *args, + **kwargs, + ): global hf - + from langchain.llms import HuggingFacePipeline + if q: import torch - from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + pipeline, + BitsAndBytesConfig, + ) + tokenizer = AutoTokenizer.from_pretrained(llm_name) bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16 - ) - model = AutoModelForCausalLM.from_pretrained(llm_name,quantization_config=bnb_config) + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + model = AutoModelForCausalLM.from_pretrained( + llm_name, quantization_config=bnb_config + ) else: from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline + tokenizer = AutoTokenizer.from_pretrained(llm_name) model = AutoModelForCausalLM.from_pretrained(llm_name) - + + super().__init__(*args, **kwargs) + self.api_key = api_key + self.user_api_key = user_api_key pipe = pipeline( - "text-generation", model=model, - tokenizer=tokenizer, max_new_tokens=2000, - device_map="auto", eos_token_id=tokenizer.eos_token_id + "text-generation", + model=model, + tokenizer=tokenizer, + max_new_tokens=2000, + device_map="auto", + eos_token_id=tokenizer.eos_token_id, ) hf = HuggingFacePipeline(pipeline=pipe) - def gen(self, model, messages, stream=False, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" result = hf(prompt) return result.content - def gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.") - diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index be34d4f..25a2f0c 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -1,32 +1,45 @@ from application.llm.base import BaseLLM from application.core.settings import settings + class LlamaCpp(BaseLLM): - def __init__(self, api_key, llm_name=settings.MODEL_PATH, **kwargs): + def __init__( + self, + api_key=None, + user_api_key=None, + llm_name=settings.MODEL_PATH, + *args, + **kwargs, + ): global llama try: from llama_cpp import Llama except ImportError: - raise ImportError("Please install llama_cpp using pip install llama-cpp-python") + raise ImportError( + "Please install llama_cpp using pip install llama-cpp-python" + ) + super().__init__(*args, **kwargs) + self.api_key = api_key + self.user_api_key = user_api_key llama = Llama(model_path=llm_name, n_ctx=2048) - def gen(self, model, messages, stream=False, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" result = llama(prompt, max_tokens=150, echo=False) # import sys # print(result['choices'][0]['text'].split('### Answer \n')[-1], file=sys.stderr) - - return result['choices'][0]['text'].split('### Answer \n')[-1] - def gen_stream(self, model, messages, stream=True, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + return result["choices"][0]["text"].split("### Answer \n")[-1] + + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" result = llama(prompt, max_tokens=150, echo=False, stream=stream) @@ -35,5 +48,5 @@ class LlamaCpp(BaseLLM): # print(list(result), file=sys.stderr) for item in result: - for choice in item['choices']: - yield choice['text'] + for choice in item["choices"]: + yield choice["text"] diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py index b4fdaeb..7960778 100644 --- a/application/llm/llm_creator.py +++ b/application/llm/llm_creator.py @@ -7,22 +7,21 @@ from application.llm.docsgpt_provider import DocsGPTAPILLM from application.llm.premai import PremAILLM - class LLMCreator: llms = { - 'openai': OpenAILLM, - 'azure_openai': AzureOpenAILLM, - 'sagemaker': SagemakerAPILLM, - 'huggingface': HuggingFaceLLM, - 'llama.cpp': LlamaCpp, - 'anthropic': AnthropicLLM, - 'docsgpt': DocsGPTAPILLM, - 'premai': PremAILLM, + "openai": OpenAILLM, + "azure_openai": AzureOpenAILLM, + "sagemaker": SagemakerAPILLM, + "huggingface": HuggingFaceLLM, + "llama.cpp": LlamaCpp, + "anthropic": AnthropicLLM, + "docsgpt": DocsGPTAPILLM, + "premai": PremAILLM, } @classmethod - def create_llm(cls, type, *args, **kwargs): + def create_llm(cls, type, api_key, user_api_key, *args, **kwargs): llm_class = cls.llms.get(type.lower()) if not llm_class: raise ValueError(f"No LLM class found for type {type}") - return llm_class(*args, **kwargs) \ No newline at end of file + return llm_class(api_key, user_api_key, *args, **kwargs) diff --git a/application/llm/openai.py b/application/llm/openai.py index 4b0ed25..b1574dd 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -1,36 +1,53 @@ from application.llm.base import BaseLLM from application.core.settings import settings + class OpenAILLM(BaseLLM): - def __init__(self, api_key): + def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): global openai from openai import OpenAI - + + super().__init__(*args, **kwargs) self.client = OpenAI( - api_key=api_key, - ) + api_key=api_key, + ) self.api_key = api_key + self.user_api_key = user_api_key def _get_openai(self): # Import openai when needed import openai - + return openai - def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): - response = self.client.chat.completions.create(model=model, - messages=messages, - stream=stream, - **kwargs) + def _raw_gen( + self, + baseself, + model, + messages, + stream=False, + engine=settings.AZURE_DEPLOYMENT_NAME, + **kwargs + ): + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) return response.choices[0].message.content - def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): - response = self.client.chat.completions.create(model=model, - messages=messages, - stream=stream, - **kwargs) + def _raw_gen_stream( + self, + baseself, + model, + messages, + stream=True, + engine=settings.AZURE_DEPLOYMENT_NAME, + **kwargs + ): + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) for line in response: # import sys @@ -41,14 +58,17 @@ class OpenAILLM(BaseLLM): class AzureOpenAILLM(OpenAILLM): - def __init__(self, openai_api_key, openai_api_base, openai_api_version, deployment_name): + def __init__( + self, openai_api_key, openai_api_base, openai_api_version, deployment_name + ): super().__init__(openai_api_key) - self.api_base = settings.OPENAI_API_BASE, - self.api_version = settings.OPENAI_API_VERSION, - self.deployment_name = settings.AZURE_DEPLOYMENT_NAME, + self.api_base = (settings.OPENAI_API_BASE,) + self.api_version = (settings.OPENAI_API_VERSION,) + self.deployment_name = (settings.AZURE_DEPLOYMENT_NAME,) from openai import AzureOpenAI + self.client = AzureOpenAI( - api_key=openai_api_key, + api_key=openai_api_key, api_version=settings.OPENAI_API_VERSION, api_base=settings.OPENAI_API_BASE, deployment_name=settings.AZURE_DEPLOYMENT_NAME, diff --git a/application/llm/premai.py b/application/llm/premai.py index 5faa5fe..c250c65 100644 --- a/application/llm/premai.py +++ b/application/llm/premai.py @@ -1,32 +1,37 @@ from application.llm.base import BaseLLM from application.core.settings import settings + class PremAILLM(BaseLLM): - def __init__(self, api_key): + def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): from premai import Prem - - self.client = Prem( - api_key=api_key - ) + + super().__init__(*args, **kwargs) + self.client = Prem(api_key=api_key) self.api_key = api_key + self.user_api_key = user_api_key self.project_id = settings.PREMAI_PROJECT_ID - def gen(self, model, messages, stream=False, **kwargs): - response = self.client.chat.completions.create(model=model, + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): + response = self.client.chat.completions.create( + model=model, project_id=self.project_id, messages=messages, stream=stream, - **kwargs) + **kwargs + ) return response.choices[0].message["content"] - def gen_stream(self, model, messages, stream=True, **kwargs): - response = self.client.chat.completions.create(model=model, + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): + response = self.client.chat.completions.create( + model=model, project_id=self.project_id, messages=messages, stream=stream, - **kwargs) + **kwargs + ) for line in response: if line.choices[0].delta["content"] is not None: diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index b81f638..6394743 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -4,11 +4,10 @@ import json import io - class LineIterator: """ - A helper class for parsing the byte stream input. - + A helper class for parsing the byte stream input. + The output of the model will be in the following format: ``` b'{"outputs": [" a"]}\n' @@ -16,21 +15,21 @@ class LineIterator: b'{"outputs": [" problem"]}\n' ... ``` - - While usually each PayloadPart event from the event stream will contain a byte array + + While usually each PayloadPart event from the event stream will contain a byte array with a full json, this is not guaranteed and some of the json objects may be split across PayloadPart events. For example: ``` {'PayloadPart': {'Bytes': b'{"outputs": '}} {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} ``` - + This class accounts for this by concatenating bytes written via the 'write' function and then exposing a method which will return lines (ending with a '\n' character) within - the buffer via the 'scan_lines' function. It maintains the position of the last read - position to ensure that previous bytes are not exposed again. + the buffer via the 'scan_lines' function. It maintains the position of the last read + position to ensure that previous bytes are not exposed again. """ - + def __init__(self, stream): self.byte_iterator = iter(stream) self.buffer = io.BytesIO() @@ -43,7 +42,7 @@ class LineIterator: while True: self.buffer.seek(self.read_pos) line = self.buffer.readline() - if line and line[-1] == ord('\n'): + if line and line[-1] == ord("\n"): self.read_pos += len(line) return line[:-1] try: @@ -52,33 +51,35 @@ class LineIterator: if self.read_pos < self.buffer.getbuffer().nbytes: continue raise - if 'PayloadPart' not in chunk: - print('Unknown event type:' + chunk) + if "PayloadPart" not in chunk: + print("Unknown event type:" + chunk) continue self.buffer.seek(0, io.SEEK_END) - self.buffer.write(chunk['PayloadPart']['Bytes']) + self.buffer.write(chunk["PayloadPart"]["Bytes"]) + class SagemakerAPILLM(BaseLLM): - def __init__(self, *args, **kwargs): + def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): import boto3 + runtime = boto3.client( - 'runtime.sagemaker', - aws_access_key_id='xxx', - aws_secret_access_key='xxx', - region_name='us-west-2' + "runtime.sagemaker", + aws_access_key_id="xxx", + aws_secret_access_key="xxx", + region_name="us-west-2", ) - - self.endpoint = settings.SAGEMAKER_ENDPOINT + super().__init__(*args, **kwargs) + self.api_key = api_key + self.user_api_key = user_api_key + self.endpoint = settings.SAGEMAKER_ENDPOINT self.runtime = runtime - - def gen(self, model, messages, stream=False, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - # Construct payload for endpoint payload = { @@ -89,25 +90,25 @@ class SagemakerAPILLM(BaseLLM): "temperature": 0.1, "max_new_tokens": 30, "repetition_penalty": 1.03, - "stop": ["", "###"] - } + "stop": ["", "###"], + }, } - body_bytes = json.dumps(payload).encode('utf-8') + body_bytes = json.dumps(payload).encode("utf-8") # Invoke the endpoint - response = self.runtime.invoke_endpoint(EndpointName=self.endpoint, - ContentType='application/json', - Body=body_bytes) - result = json.loads(response['Body'].read().decode()) + response = self.runtime.invoke_endpoint( + EndpointName=self.endpoint, ContentType="application/json", Body=body_bytes + ) + result = json.loads(response["Body"].read().decode()) import sys - print(result[0]['generated_text'], file=sys.stderr) - return result[0]['generated_text'][len(prompt):] - def gen_stream(self, model, messages, stream=True, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + print(result[0]["generated_text"], file=sys.stderr) + return result[0]["generated_text"][len(prompt) :] + + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - # Construct payload for endpoint payload = { @@ -118,22 +119,22 @@ class SagemakerAPILLM(BaseLLM): "temperature": 0.1, "max_new_tokens": 512, "repetition_penalty": 1.03, - "stop": ["", "###"] - } + "stop": ["", "###"], + }, } - body_bytes = json.dumps(payload).encode('utf-8') + body_bytes = json.dumps(payload).encode("utf-8") # Invoke the endpoint - response = self.runtime.invoke_endpoint_with_response_stream(EndpointName=self.endpoint, - ContentType='application/json', - Body=body_bytes) - #result = json.loads(response['Body'].read().decode()) - event_stream = response['Body'] - start_json = b'{' + response = self.runtime.invoke_endpoint_with_response_stream( + EndpointName=self.endpoint, ContentType="application/json", Body=body_bytes + ) + # result = json.loads(response['Body'].read().decode()) + event_stream = response["Body"] + start_json = b"{" for line in LineIterator(event_stream): - if line != b'' and start_json in line: - #print(line) - data = json.loads(line[line.find(start_json):].decode('utf-8')) - if data['token']['text'] not in ["", "###"]: - print(data['token']['text'],end='') - yield data['token']['text'] \ No newline at end of file + if line != b"" and start_json in line: + # print(line) + data = json.loads(line[line.find(start_json) :].decode("utf-8")) + if data["token"]["text"] not in ["", "###"]: + print(data["token"]["text"], end="") + yield data["token"]["text"] diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py index 0cc7bd4..47ca0e7 100644 --- a/application/retriever/brave_search.py +++ b/application/retriever/brave_search.py @@ -6,43 +6,54 @@ from application.utils import count_tokens from langchain_community.tools import BraveSearch - class BraveRetSearch(BaseRetriever): - def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + def __init__( + self, + question, + source, + chat_history, + prompt, + chunks=2, + gpt_model="docsgpt", + user_api_key=None, + ): self.question = question self.source = source self.chat_history = chat_history self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model - + self.user_api_key = user_api_key + def _get_data(self): if self.chunks == 0: docs = [] else: - search = BraveSearch.from_api_key(api_key=settings.BRAVE_SEARCH_API_KEY, - search_kwargs={"count": int(self.chunks)}) + search = BraveSearch.from_api_key( + api_key=settings.BRAVE_SEARCH_API_KEY, + search_kwargs={"count": int(self.chunks)}, + ) results = search.run(self.question) results = json.loads(results) - + docs = [] for i in results: try: - title = i['title'] - link = i['link'] - snippet = i['snippet'] + title = i["title"] + link = i["link"] + snippet = i["snippet"] docs.append({"text": snippet, "title": title, "link": link}) except IndexError: pass if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] - + return docs - + def gen(self): docs = self._get_data() - + # join all page_content together with a newline docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) @@ -56,20 +67,29 @@ class BraveRetSearch(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_batch = count_tokens(i["prompt"]) + count_tokens( + i["response"] + ) + if ( + tokens_current_history + tokens_batch + < settings.TOKENS_MAX_HISTORY + ): tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + ) - completion = llm.gen_stream(model=self.gpt_model, - messages=messages_combine) + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) for line in completion: yield {"answer": str(line)} - + def search(self): return self._get_data() - diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index b5f1eb9..1bce6f8 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -7,21 +7,30 @@ from application.llm.llm_creator import LLMCreator from application.utils import count_tokens - class ClassicRAG(BaseRetriever): - def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + def __init__( + self, + question, + source, + chat_history, + prompt, + chunks=2, + gpt_model="docsgpt", + user_api_key=None, + ): self.question = question self.vectorstore = self._get_vectorstore(source=source) self.chat_history = chat_history self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model + self.user_api_key = user_api_key def _get_vectorstore(self, source): if "active_docs" in source: if source["active_docs"].split("/")[0] == "default": - vectorstore = "" + vectorstore = "" elif source["active_docs"].split("/")[0] == "local": vectorstore = "indexes/" + source["active_docs"] else: @@ -33,32 +42,33 @@ class ClassicRAG(BaseRetriever): vectorstore = os.path.join("application", vectorstore) return vectorstore - def _get_data(self): if self.chunks == 0: docs = [] else: docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, - self.vectorstore, - settings.EMBEDDINGS_KEY + settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY ) docs_temp = docsearch.search(self.question, k=self.chunks) docs = [ { - "title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, - "text": i.page_content - } + "title": ( + i.metadata["title"].split("/")[-1] + if i.metadata + else i.page_content + ), + "text": i.page_content, + } for i in docs_temp ] if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] - + return docs - + def gen(self): docs = self._get_data() - + # join all page_content together with a newline docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) @@ -72,20 +82,29 @@ class ClassicRAG(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_batch = count_tokens(i["prompt"]) + count_tokens( + i["response"] + ) + if ( + tokens_current_history + tokens_batch + < settings.TOKENS_MAX_HISTORY + ): tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + ) - completion = llm.gen_stream(model=self.gpt_model, - messages=messages_combine) + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) for line in completion: yield {"answer": str(line)} - + def search(self): return self._get_data() - diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py index d662bb0..9189298 100644 --- a/application/retriever/duckduck_search.py +++ b/application/retriever/duckduck_search.py @@ -6,16 +6,25 @@ from langchain_community.tools import DuckDuckGoSearchResults from langchain_community.utilities import DuckDuckGoSearchAPIWrapper - class DuckDuckSearch(BaseRetriever): - def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + def __init__( + self, + question, + source, + chat_history, + prompt, + chunks=2, + gpt_model="docsgpt", + user_api_key=None, + ): self.question = question self.source = source self.chat_history = chat_history self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model + self.user_api_key = user_api_key def _parse_lang_string(self, input_string): result = [] @@ -30,12 +39,12 @@ class DuckDuckSearch(BaseRetriever): current_item = "" elif inside_brackets: current_item += char - + if inside_brackets: result.append(current_item) - + return result - + def _get_data(self): if self.chunks == 0: docs = [] @@ -44,7 +53,7 @@ class DuckDuckSearch(BaseRetriever): search = DuckDuckGoSearchResults(api_wrapper=wrapper) results = search.run(self.question) results = self._parse_lang_string(results) - + docs = [] for i in results: try: @@ -56,12 +65,12 @@ class DuckDuckSearch(BaseRetriever): pass if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] - + return docs - + def gen(self): docs = self._get_data() - + # join all page_content together with a newline docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) @@ -75,20 +84,29 @@ class DuckDuckSearch(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_batch = count_tokens(i["prompt"]) + count_tokens( + i["response"] + ) + if ( + tokens_current_history + tokens_batch + < settings.TOKENS_MAX_HISTORY + ): tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + ) - completion = llm.gen_stream(model=self.gpt_model, - messages=messages_combine) + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) for line in completion: yield {"answer": str(line)} - + def search(self): return self._get_data() - diff --git a/application/usage.py b/application/usage.py new file mode 100644 index 0000000..1b26e9d --- /dev/null +++ b/application/usage.py @@ -0,0 +1,49 @@ +import sys +from pymongo import MongoClient +from datetime import datetime +from application.core.settings import settings +from application.utils import count_tokens + +mongo = MongoClient(settings.MONGO_URI) +db = mongo["docsgpt"] +usage_collection = db["token_usage"] + + +def update_token_usage(user_api_key, token_usage): + if "pytest" in sys.modules: + return + usage_data = { + "api_key": user_api_key, + "prompt_tokens": token_usage["prompt_tokens"], + "generated_tokens": token_usage["generated_tokens"], + "timestamp": datetime.now(), + } + usage_collection.insert_one(usage_data) + + +def gen_token_usage(func): + def wrapper(self, model, messages, stream, **kwargs): + for message in messages: + self.token_usage["prompt_tokens"] += count_tokens(message["content"]) + result = func(self, model, messages, stream, **kwargs) + self.token_usage["generated_tokens"] += count_tokens(result) + update_token_usage(self.user_api_key, self.token_usage) + return result + + return wrapper + + +def stream_token_usage(func): + def wrapper(self, model, messages, stream, **kwargs): + for message in messages: + self.token_usage["prompt_tokens"] += count_tokens(message["content"]) + batch = [] + result = func(self, model, messages, stream, **kwargs) + for r in result: + batch.append(r) + yield r + for line in batch: + self.token_usage["generated_tokens"] += count_tokens(line) + update_token_usage(self.user_api_key, self.token_usage) + + return wrapper