From ba796b6be1dde4a405626f06b1db05eaf9492381 Mon Sep 17 00:00:00 2001
From: Siddhant Rai
Date: Mon, 15 Apr 2024 15:03:00 +0530
Subject: [PATCH 01/10] feat: logging token usage to database

---
 application/llm/base.py             |  2 +-
 application/llm/docsgpt_provider.py | 43 +++++++++++-------------
 application/llm/llm_creator.py      | 21 ++++++------
 application/usage.py                | 51 +++++++++++++++++++++++++++++
 4 files changed, 82 insertions(+), 35 deletions(-)
 create mode 100644 application/usage.py

diff --git a/application/llm/base.py b/application/llm/base.py
index e08a3b0..65cb8b1 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 class BaseLLM(ABC):
 
     def __init__(self):
-        pass
+        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
 
     @abstractmethod
     def gen(self, *args, **kwargs):
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index d540a91..a46abaa 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -1,49 +1,46 @@
 from application.llm.base import BaseLLM
 import json
 import requests
+from application.usage import gen_token_usage, stream_token_usage
 
 
-class DocsGPTAPILLM(BaseLLM):
 
-    def __init__(self, *args, **kwargs):
-        self.endpoint = "https://llm.docsgpt.co.uk"
+class DocsGPTAPILLM(BaseLLM):
+    def __init__(self, api_key, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.api_key = api_key
+        self.endpoint = "https://llm.docsgpt.co.uk"
 
+    @gen_token_usage
     def gen(self, model, messages, stream=False, **kwargs):
-        context = messages[0]['content']
-        user_question = messages[-1]['content']
+        context = messages[0]["content"]
+        user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
 
         response = requests.post(
-            f"{self.endpoint}/answer",
-            json={
-                "prompt": prompt,
-                "max_new_tokens": 30
-            }
+            f"{self.endpoint}/answer", json={"prompt": prompt, "max_new_tokens": 30}
         )
-        response_clean = response.json()['a'].replace("###", "")
+        response_clean = response.json()["a"].replace("###", "")
 
         return response_clean
 
+    @stream_token_usage
     def gen_stream(self, model, messages, stream=True, **kwargs):
-        context = messages[0]['content']
-        user_question = messages[-1]['content']
+        context = messages[0]["content"]
+        user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
 
         # send prompt to endpoint /stream
         response = requests.post(
             f"{self.endpoint}/stream",
-            json={
-                "prompt": prompt,
-                "max_new_tokens": 256
-            },
-            stream=True
+            json={"prompt": prompt, "max_new_tokens": 256},
+            stream=True,
         )
-        
+
         for line in response.iter_lines():
             if line:
-                #data = json.loads(line)
-                data_str = line.decode('utf-8')
+                # data = json.loads(line)
+                data_str = line.decode("utf-8")
                 if data_str.startswith("data: "):
                     data = json.loads(data_str[6:])
-                    yield data['a']
-                    
\ No newline at end of file
+                    yield data["a"]
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index b4fdaeb..532f3fc 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -7,22 +7,21 @@ from application.llm.docsgpt_provider import DocsGPTAPILLM
 from application.llm.premai import PremAILLM
 
 
-
 class LLMCreator:
     llms = {
-        'openai': OpenAILLM,
-        'azure_openai': AzureOpenAILLM,
-        'sagemaker': SagemakerAPILLM,
-        'huggingface': HuggingFaceLLM,
-        'llama.cpp': LlamaCpp,
-        'anthropic': AnthropicLLM,
-        'docsgpt': DocsGPTAPILLM,
-        'premai': PremAILLM,
+        "openai": OpenAILLM,
+        "azure_openai": AzureOpenAILLM,
+        "sagemaker": SagemakerAPILLM,
+        "huggingface": HuggingFaceLLM,
+        "llama.cpp": LlamaCpp,
+        "anthropic": AnthropicLLM,
+        "docsgpt": DocsGPTAPILLM,
+        "premai": PremAILLM,
     }
 
     @classmethod
-    def create_llm(cls, type, *args, **kwargs):
+    def create_llm(cls, type, api_key, *args, **kwargs):
         llm_class = cls.llms.get(type.lower())
         if not llm_class:
             raise ValueError(f"No LLM class found for type {type}")
-        return llm_class(*args, **kwargs)
\ No newline at end of file
+        return llm_class(api_key, *args, **kwargs)
diff --git a/application/usage.py b/application/usage.py
new file mode 100644
index 0000000..de854a4
--- /dev/null
+++ b/application/usage.py
@@ -0,0 +1,51 @@
+from pymongo import MongoClient
+from bson.son import SON
+from datetime import datetime
+from application.core.settings import settings
+from application.utils import count_tokens
+
+mongo = MongoClient(settings.MONGO_URI)
+db = mongo["docsgpt"]
+usage_collection = db["token_usage"]
+
+
+def update_token_usage(api_key, token_usage):
+    usage_data = {
+        "api_key": api_key,
+        "prompt_tokens": token_usage["prompt_tokens"],
+        "generated_tokens": token_usage["generated_tokens"],
+        "timestamp": datetime.now(),
+    }
+    usage_collection.insert_one(usage_data)
+
+
+def gen_token_usage(func):
+    def wrapper(self, model, messages, *args, **kwargs):
+        context = messages[0]["content"]
+        user_question = messages[-1]["content"]
+        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+        self.token_usage["prompt_tokens"] += count_tokens(prompt)
+        result = func(self, model, messages, *args, **kwargs)
+        self.token_usage["generated_tokens"] += count_tokens(result)
+        update_token_usage(self.api_key, self.token_usage)
+        return result
+
+    return wrapper
+
+
+def stream_token_usage(func):
+    def wrapper(self, model, messages, *args, **kwargs):
+        context = messages[0]["content"]
+        user_question = messages[-1]["content"]
+        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+        self.token_usage["prompt_tokens"] += count_tokens(prompt)
+        batch = []
+        result = func(self, model, messages, *args, **kwargs)
+        for r in result:
+            batch.append(r)
+            yield r
+        for line in batch:
+            self.token_usage["generated_tokens"] += count_tokens(line)
+        update_token_usage(self.api_key, self.token_usage)
+
+    return wrapper

From 9146827590b2896897cc2bcea9c47486b35b36e9 Mon Sep 17 00:00:00 2001
From: Siddhant Rai
Date: Mon, 15 Apr 2024 15:14:17 +0530
Subject: [PATCH 02/10] fix: removed unused import

---
 application/usage.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/application/usage.py b/application/usage.py
index de854a4..2fc307a 100644
--- a/application/usage.py
+++ b/application/usage.py
@@ -1,5 +1,4 @@
 from pymongo import MongoClient
-from bson.son import SON
 from datetime import datetime
 from application.core.settings import settings
 from application.utils import count_tokens

From 590aa8b43f0f6fab07e4d95c430d4ced9f9388a0 Mon Sep 17 00:00:00 2001
From: Siddhant Rai
Date: Mon, 15 Apr 2024 18:57:28 +0530
Subject: [PATCH 03/10] update: apply decorator to abstract classes

---
 application/llm/base.py             | 18 ++++++++++++++++--
 application/llm/docsgpt_provider.py |  7 ++-----
 application/usage.py                | 20 ++++++++------------
 3 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/application/llm/base.py b/application/llm/base.py
index 65cb8b1..475b793 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -1,14 +1,28 @@
 from abc import ABC, abstractmethod
+from application.usage import gen_token_usage, stream_token_usage
 
 class BaseLLM(ABC):
 
     def __init__(self):
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
 
+    def _apply_decorator(self, method, decorator, *args, **kwargs):
+        return decorator(method, *args, **kwargs)
+
     @abstractmethod
-    def gen(self, *args, **kwargs):
+    def _raw_gen(self, model, messages, stream, *args, **kwargs):
         pass
 
+    def gen(self, model, messages, stream=False, *args, **kwargs):
+        return self._apply_decorator(self._raw_gen, gen_token_usage)(
+            self, model=model, messages=messages, stream=stream, *args, **kwargs
+        )
+
     @abstractmethod
-    def gen_stream(self, *args, **kwargs):
+    def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass
+
+    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
+        return self._apply_decorator(self._raw_gen_stream, stream_token_usage)(
+            self, model=model, messages=messages, stream=stream, *args, **kwargs
+        )
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index a46abaa..ffe1e31 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -1,7 +1,6 @@
 from application.llm.base import BaseLLM
 import json
 import requests
-from application.usage import gen_token_usage, stream_token_usage
 
 
 class DocsGPTAPILLM(BaseLLM):
@@ -11,8 +10,7 @@ class DocsGPTAPILLM(BaseLLM):
         self.api_key = api_key
         self.endpoint = "https://llm.docsgpt.co.uk"
 
-    @gen_token_usage
-    def gen(self, model, messages, stream=False, **kwargs):
+    def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -24,8 +22,7 @@ class DocsGPTAPILLM(BaseLLM):
 
         return response_clean
 
-    @stream_token_usage
-    def gen_stream(self, model, messages, stream=True, **kwargs):
+    def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
diff --git a/application/usage.py b/application/usage.py
index 2fc307a..95cd02f 100644
--- a/application/usage.py
+++ b/application/usage.py
@@ -19,12 +19,10 @@ def update_token_usage(api_key, token_usage):
 
 
 def gen_token_usage(func):
-    def wrapper(self, model, messages, *args, **kwargs):
-        context = messages[0]["content"]
-        user_question = messages[-1]["content"]
-        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        self.token_usage["prompt_tokens"] += count_tokens(prompt)
-        result = func(self, model, messages, *args, **kwargs)
+    def wrapper(self, model, messages, stream, **kwargs):
+        for message in messages:
+            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
+        result = func(self, model, messages, stream, **kwargs)
         self.token_usage["generated_tokens"] += count_tokens(result)
         update_token_usage(self.api_key, self.token_usage)
         return result
@@ -33,13 +31,11 @@ def gen_token_usage(func):
 
 
 def stream_token_usage(func):
-    def wrapper(self, model, messages, *args, **kwargs):
-        context = messages[0]["content"]
-        user_question = messages[-1]["content"]
-        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        self.token_usage["prompt_tokens"] += count_tokens(prompt)
+    def wrapper(self, model, messages, stream,
**kwargs): + for message in messages: + self.token_usage["prompt_tokens"] += count_tokens(message["content"]) batch = [] - result = func(self, model, messages, *args, **kwargs) + result = func(self, model, messages, stream, **kwargs) for r in result: batch.append(r) yield r From c1c69ed22bff86c6dc84c795386ea66044c7c39e Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 15 Apr 2024 19:35:59 +0530 Subject: [PATCH 04/10] fix: pytest issues --- application/llm/anthropic.py | 20 ++++--- application/llm/huggingface.py | 49 ++++++++++------ application/llm/llama_cpp.py | 25 ++++---- application/llm/openai.py | 54 +++++++++++------ application/llm/premai.py | 23 ++++---- application/llm/sagemaker.py | 102 ++++++++++++++++----------------- 6 files changed, 155 insertions(+), 118 deletions(-) diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index 6b0d646..70495f0 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -1,18 +1,22 @@ from application.llm.base import BaseLLM from application.core.settings import settings + class AnthropicLLM(BaseLLM): def __init__(self, api_key=None): from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT - self.api_key = api_key or settings.ANTHROPIC_API_KEY # If not provided, use a default from settings + + self.api_key = ( + api_key or settings.ANTHROPIC_API_KEY + ) # If not provided, use a default from settings self.anthropic = Anthropic(api_key=self.api_key) self.HUMAN_PROMPT = HUMAN_PROMPT self.AI_PROMPT = AI_PROMPT - def gen(self, model, messages, max_tokens=300, stream=False, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen(self, model, messages, max_tokens=300, stream=False, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}" if stream: return self.gen_stream(model, prompt, max_tokens, **kwargs) @@ -25,9 +29,9 @@ class AnthropicLLM(BaseLLM): ) return completion.completion - def gen_stream(self, model, messages, max_tokens=300, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen_stream(self, model, messages, max_tokens=300, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}" stream_response = self.anthropic.completions.create( model=model, @@ -37,4 +41,4 @@ class AnthropicLLM(BaseLLM): ) for completion in stream_response: - yield completion.completion \ No newline at end of file + yield completion.completion diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py index 554bee2..c9e500e 100644 --- a/application/llm/huggingface.py +++ b/application/llm/huggingface.py @@ -1,44 +1,57 @@ from application.llm.base import BaseLLM + class HuggingFaceLLM(BaseLLM): - def __init__(self, api_key, llm_name='Arc53/DocsGPT-7B',q=False): + def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False): global hf - + from langchain.llms import HuggingFacePipeline + if q: import torch - from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + pipeline, + BitsAndBytesConfig, + ) + tokenizer = AutoTokenizer.from_pretrained(llm_name) bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - 
bnb_4bit_compute_dtype=torch.bfloat16 - ) - model = AutoModelForCausalLM.from_pretrained(llm_name,quantization_config=bnb_config) + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + model = AutoModelForCausalLM.from_pretrained( + llm_name, quantization_config=bnb_config + ) else: from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline + tokenizer = AutoTokenizer.from_pretrained(llm_name) model = AutoModelForCausalLM.from_pretrained(llm_name) - + pipe = pipeline( - "text-generation", model=model, - tokenizer=tokenizer, max_new_tokens=2000, - device_map="auto", eos_token_id=tokenizer.eos_token_id + "text-generation", + model=model, + tokenizer=tokenizer, + max_new_tokens=2000, + device_map="auto", + eos_token_id=tokenizer.eos_token_id, ) hf = HuggingFacePipeline(pipeline=pipe) - def gen(self, model, messages, stream=False, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen(self, model, messages, stream=False, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" result = hf(prompt) return result.content - def gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, model, messages, stream=True, **kwargs): raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.") - diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index be34d4f..1512cd7 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -1,6 +1,7 @@ from application.llm.base import BaseLLM from application.core.settings import settings + class LlamaCpp(BaseLLM): def __init__(self, api_key, llm_name=settings.MODEL_PATH, **kwargs): @@ -8,25 +9,27 @@ class LlamaCpp(BaseLLM): try: from llama_cpp import Llama except ImportError: - raise ImportError("Please install llama_cpp using pip install llama-cpp-python") + raise ImportError( + "Please install llama_cpp using pip install llama-cpp-python" + ) llama = Llama(model_path=llm_name, n_ctx=2048) - def gen(self, model, messages, stream=False, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen(self, model, messages, stream=False, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" result = llama(prompt, max_tokens=150, echo=False) # import sys # print(result['choices'][0]['text'].split('### Answer \n')[-1], file=sys.stderr) - - return result['choices'][0]['text'].split('### Answer \n')[-1] - def gen_stream(self, model, messages, stream=True, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + return result["choices"][0]["text"].split("### Answer \n")[-1] + + def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" result = llama(prompt, max_tokens=150, echo=False, stream=stream) @@ -35,5 +38,5 @@ class LlamaCpp(BaseLLM): # print(list(result), file=sys.stderr) for item in result: - for choice in item['choices']: - yield choice['text'] + for choice in item["choices"]: + yield choice["text"] diff --git a/application/llm/openai.py 
b/application/llm/openai.py index 4b0ed25..de29246 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -1,36 +1,49 @@ from application.llm.base import BaseLLM from application.core.settings import settings + class OpenAILLM(BaseLLM): def __init__(self, api_key): global openai from openai import OpenAI - + self.client = OpenAI( - api_key=api_key, - ) + api_key=api_key, + ) self.api_key = api_key def _get_openai(self): # Import openai when needed import openai - + return openai - def gen(self, model, messages, stream=False, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): - response = self.client.chat.completions.create(model=model, - messages=messages, - stream=stream, - **kwargs) + def _raw_gen( + self, + model, + messages, + stream=False, + engine=settings.AZURE_DEPLOYMENT_NAME, + **kwargs + ): + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) return response.choices[0].message.content - def gen_stream(self, model, messages, stream=True, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs): - response = self.client.chat.completions.create(model=model, - messages=messages, - stream=stream, - **kwargs) + def _raw_gen_stream( + self, + model, + messages, + stream=True, + engine=settings.AZURE_DEPLOYMENT_NAME, + **kwargs + ): + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) for line in response: # import sys @@ -41,14 +54,17 @@ class OpenAILLM(BaseLLM): class AzureOpenAILLM(OpenAILLM): - def __init__(self, openai_api_key, openai_api_base, openai_api_version, deployment_name): + def __init__( + self, openai_api_key, openai_api_base, openai_api_version, deployment_name + ): super().__init__(openai_api_key) - self.api_base = settings.OPENAI_API_BASE, - self.api_version = settings.OPENAI_API_VERSION, - self.deployment_name = settings.AZURE_DEPLOYMENT_NAME, + self.api_base = (settings.OPENAI_API_BASE,) + self.api_version = (settings.OPENAI_API_VERSION,) + self.deployment_name = (settings.AZURE_DEPLOYMENT_NAME,) from openai import AzureOpenAI + self.client = AzureOpenAI( - api_key=openai_api_key, + api_key=openai_api_key, api_version=settings.OPENAI_API_VERSION, api_base=settings.OPENAI_API_BASE, deployment_name=settings.AZURE_DEPLOYMENT_NAME, diff --git a/application/llm/premai.py b/application/llm/premai.py index 5faa5fe..c0552ea 100644 --- a/application/llm/premai.py +++ b/application/llm/premai.py @@ -1,32 +1,35 @@ from application.llm.base import BaseLLM from application.core.settings import settings + class PremAILLM(BaseLLM): def __init__(self, api_key): from premai import Prem - - self.client = Prem( - api_key=api_key - ) + + self.client = Prem(api_key=api_key) self.api_key = api_key self.project_id = settings.PREMAI_PROJECT_ID - def gen(self, model, messages, stream=False, **kwargs): - response = self.client.chat.completions.create(model=model, + def _raw_gen(self, model, messages, stream=False, **kwargs): + response = self.client.chat.completions.create( + model=model, project_id=self.project_id, messages=messages, stream=stream, - **kwargs) + **kwargs + ) return response.choices[0].message["content"] - def gen_stream(self, model, messages, stream=True, **kwargs): - response = self.client.chat.completions.create(model=model, + def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + response = self.client.chat.completions.create( + model=model, project_id=self.project_id, messages=messages, stream=stream, - **kwargs) + **kwargs + ) for 
line in response: if line.choices[0].delta["content"] is not None: diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index b81f638..b531020 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -4,11 +4,10 @@ import json import io - class LineIterator: """ - A helper class for parsing the byte stream input. - + A helper class for parsing the byte stream input. + The output of the model will be in the following format: ``` b'{"outputs": [" a"]}\n' @@ -16,21 +15,21 @@ class LineIterator: b'{"outputs": [" problem"]}\n' ... ``` - - While usually each PayloadPart event from the event stream will contain a byte array + + While usually each PayloadPart event from the event stream will contain a byte array with a full json, this is not guaranteed and some of the json objects may be split across PayloadPart events. For example: ``` {'PayloadPart': {'Bytes': b'{"outputs": '}} {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} ``` - + This class accounts for this by concatenating bytes written via the 'write' function and then exposing a method which will return lines (ending with a '\n' character) within - the buffer via the 'scan_lines' function. It maintains the position of the last read - position to ensure that previous bytes are not exposed again. + the buffer via the 'scan_lines' function. It maintains the position of the last read + position to ensure that previous bytes are not exposed again. """ - + def __init__(self, stream): self.byte_iterator = iter(stream) self.buffer = io.BytesIO() @@ -43,7 +42,7 @@ class LineIterator: while True: self.buffer.seek(self.read_pos) line = self.buffer.readline() - if line and line[-1] == ord('\n'): + if line and line[-1] == ord("\n"): self.read_pos += len(line) return line[:-1] try: @@ -52,33 +51,32 @@ class LineIterator: if self.read_pos < self.buffer.getbuffer().nbytes: continue raise - if 'PayloadPart' not in chunk: - print('Unknown event type:' + chunk) + if "PayloadPart" not in chunk: + print("Unknown event type:" + chunk) continue self.buffer.seek(0, io.SEEK_END) - self.buffer.write(chunk['PayloadPart']['Bytes']) + self.buffer.write(chunk["PayloadPart"]["Bytes"]) + class SagemakerAPILLM(BaseLLM): def __init__(self, *args, **kwargs): import boto3 + runtime = boto3.client( - 'runtime.sagemaker', - aws_access_key_id='xxx', - aws_secret_access_key='xxx', - region_name='us-west-2' + "runtime.sagemaker", + aws_access_key_id="xxx", + aws_secret_access_key="xxx", + region_name="us-west-2", ) - - self.endpoint = settings.SAGEMAKER_ENDPOINT + self.endpoint = settings.SAGEMAKER_ENDPOINT self.runtime = runtime - - def gen(self, model, messages, stream=False, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + def _raw_gen(self, model, messages, stream=False, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - # Construct payload for endpoint payload = { @@ -89,25 +87,25 @@ class SagemakerAPILLM(BaseLLM): "temperature": 0.1, "max_new_tokens": 30, "repetition_penalty": 1.03, - "stop": ["", "###"] - } + "stop": ["", "###"], + }, } - body_bytes = json.dumps(payload).encode('utf-8') + body_bytes = json.dumps(payload).encode("utf-8") # Invoke the endpoint - response = self.runtime.invoke_endpoint(EndpointName=self.endpoint, - ContentType='application/json', - Body=body_bytes) - result = json.loads(response['Body'].read().decode()) + response = 
self.runtime.invoke_endpoint( + EndpointName=self.endpoint, ContentType="application/json", Body=body_bytes + ) + result = json.loads(response["Body"].read().decode()) import sys - print(result[0]['generated_text'], file=sys.stderr) - return result[0]['generated_text'][len(prompt):] - def gen_stream(self, model, messages, stream=True, **kwargs): - context = messages[0]['content'] - user_question = messages[-1]['content'] + print(result[0]["generated_text"], file=sys.stderr) + return result[0]["generated_text"][len(prompt) :] + + def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + context = messages[0]["content"] + user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - # Construct payload for endpoint payload = { @@ -118,22 +116,22 @@ class SagemakerAPILLM(BaseLLM): "temperature": 0.1, "max_new_tokens": 512, "repetition_penalty": 1.03, - "stop": ["", "###"] - } + "stop": ["", "###"], + }, } - body_bytes = json.dumps(payload).encode('utf-8') + body_bytes = json.dumps(payload).encode("utf-8") # Invoke the endpoint - response = self.runtime.invoke_endpoint_with_response_stream(EndpointName=self.endpoint, - ContentType='application/json', - Body=body_bytes) - #result = json.loads(response['Body'].read().decode()) - event_stream = response['Body'] - start_json = b'{' + response = self.runtime.invoke_endpoint_with_response_stream( + EndpointName=self.endpoint, ContentType="application/json", Body=body_bytes + ) + # result = json.loads(response['Body'].read().decode()) + event_stream = response["Body"] + start_json = b"{" for line in LineIterator(event_stream): - if line != b'' and start_json in line: - #print(line) - data = json.loads(line[line.find(start_json):].decode('utf-8')) - if data['token']['text'] not in ["", "###"]: - print(data['token']['text'],end='') - yield data['token']['text'] \ No newline at end of file + if line != b"" and start_json in line: + # print(line) + data = json.loads(line[line.find(start_json) :].decode("utf-8")) + if data["token"]["text"] not in ["", "###"]: + print(data["token"]["text"], end="") + yield data["token"]["text"] From 60a670ce29080ac672baa1654ace57852a9fed98 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 15 Apr 2024 19:47:24 +0530 Subject: [PATCH 05/10] fix: changes to llm classes according to base --- application/llm/anthropic.py | 3 ++- application/llm/huggingface.py | 4 +++- application/llm/llama_cpp.py | 4 +++- application/llm/openai.py | 3 ++- application/llm/premai.py | 3 ++- application/llm/sagemaker.py | 4 +++- 6 files changed, 15 insertions(+), 6 deletions(-) diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index 70495f0..b3fde3d 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -4,9 +4,10 @@ from application.core.settings import settings class AnthropicLLM(BaseLLM): - def __init__(self, api_key=None): + def __init__(self, api_key=None, *args, **kwargs): from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT + super().__init__(*args, **kwargs) self.api_key = ( api_key or settings.ANTHROPIC_API_KEY ) # If not provided, use a default from settings diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py index c9e500e..b1118ed 100644 --- a/application/llm/huggingface.py +++ b/application/llm/huggingface.py @@ -3,7 +3,7 @@ from application.llm.base import BaseLLM class HuggingFaceLLM(BaseLLM): - def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False): + def 
__init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs): global hf from langchain.llms import HuggingFacePipeline @@ -33,6 +33,8 @@ class HuggingFaceLLM(BaseLLM): tokenizer = AutoTokenizer.from_pretrained(llm_name) model = AutoModelForCausalLM.from_pretrained(llm_name) + super().__init__(*args, **kwargs) + self.api_key = api_key pipe = pipeline( "text-generation", model=model, diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index 1512cd7..896e66f 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -4,7 +4,7 @@ from application.core.settings import settings class LlamaCpp(BaseLLM): - def __init__(self, api_key, llm_name=settings.MODEL_PATH, **kwargs): + def __init__(self, api_key, llm_name=settings.MODEL_PATH, *args, **kwargs): global llama try: from llama_cpp import Llama @@ -13,6 +13,8 @@ class LlamaCpp(BaseLLM): "Please install llama_cpp using pip install llama-cpp-python" ) + super().__init__(*args, **kwargs) + self.api_key = api_key llama = Llama(model_path=llm_name, n_ctx=2048) def _raw_gen(self, model, messages, stream=False, **kwargs): diff --git a/application/llm/openai.py b/application/llm/openai.py index de29246..c741404 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -4,10 +4,11 @@ from application.core.settings import settings class OpenAILLM(BaseLLM): - def __init__(self, api_key): + def __init__(self, api_key, *args, **kwargs): global openai from openai import OpenAI + super().__init__(*args, **kwargs) self.client = OpenAI( api_key=api_key, ) diff --git a/application/llm/premai.py b/application/llm/premai.py index c0552ea..203ff4d 100644 --- a/application/llm/premai.py +++ b/application/llm/premai.py @@ -4,9 +4,10 @@ from application.core.settings import settings class PremAILLM(BaseLLM): - def __init__(self, api_key): + def __init__(self, api_key, *args, **kwargs): from premai import Prem + super().__init__(*args, **kwargs) self.client = Prem(api_key=api_key) self.api_key = api_key self.project_id = settings.PREMAI_PROJECT_ID diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index b531020..807bfa2 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -60,7 +60,7 @@ class LineIterator: class SagemakerAPILLM(BaseLLM): - def __init__(self, *args, **kwargs): + def __init__(self, api_key, *args, **kwargs): import boto3 runtime = boto3.client( @@ -70,6 +70,8 @@ class SagemakerAPILLM(BaseLLM): region_name="us-west-2", ) + super().__init__(*args, **kwargs) + self.api_key = api_key self.endpoint = settings.SAGEMAKER_ENDPOINT self.runtime = runtime From 77991896b4eab66322d4ca95b057e8a3624728dd Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 15 Apr 2024 22:32:24 +0530 Subject: [PATCH 06/10] fix: api_key capturing + pytest errors --- application/api/answer/routes.py | 218 +++++++++++++++-------- application/llm/anthropic.py | 6 +- application/llm/docsgpt_provider.py | 2 +- application/llm/huggingface.py | 8 +- application/llm/llama_cpp.py | 6 +- application/llm/openai.py | 4 +- application/llm/premai.py | 6 +- application/llm/sagemaker.py | 6 +- application/retriever/brave_search.py | 60 ++++--- application/retriever/classic_rag.py | 61 ++++--- application/retriever/duckduck_search.py | 52 ++++-- 11 files changed, 280 insertions(+), 149 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index fa0ac4f..b95dd14 100644 --- a/application/api/answer/routes.py +++ 
b/application/api/answer/routes.py @@ -10,14 +10,12 @@ from pymongo import MongoClient from bson.objectid import ObjectId - from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator from application.error import bad_request - logger = logging.getLogger(__name__) mongo = MongoClient(settings.MONGO_URI) @@ -26,20 +24,22 @@ conversations_collection = db["conversations"] vectors_collection = db["vectors"] prompts_collection = db["prompts"] api_key_collection = db["api_keys"] -answer = Blueprint('answer', __name__) +answer = Blueprint("answer", __name__) gpt_model = "" # to have some kind of default behaviour if settings.LLM_NAME == "openai": - gpt_model = 'gpt-3.5-turbo' + gpt_model = "gpt-3.5-turbo" elif settings.LLM_NAME == "anthropic": - gpt_model = 'claude-2' + gpt_model = "claude-2" if settings.MODEL_NAME: # in case there is particular model name configured gpt_model = settings.MODEL_NAME # load the prompts -current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +current_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f: chat_combine_template = f.read() @@ -50,7 +50,7 @@ with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r" chat_combine_creative = f.read() with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f: - chat_combine_strict = f.read() + chat_combine_strict = f.read() api_key_set = settings.API_KEY is not None embeddings_key_set = settings.EMBEDDINGS_KEY is not None @@ -61,8 +61,6 @@ async def async_generate(chain, question, chat_history): return result - - def run_async_chain(chain, question, chat_history): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -74,17 +72,18 @@ def run_async_chain(chain, question, chat_history): result["answer"] = answer return result + def get_data_from_api_key(api_key): data = api_key_collection.find_one({"key": api_key}) if data is None: return bad_request(401, "Invalid API key") return data - + def get_vectorstore(data): if "active_docs" in data: if data["active_docs"].split("/")[0] == "default": - vectorstore = "" + vectorstore = "" elif data["active_docs"].split("/")[0] == "local": vectorstore = "indexes/" + data["active_docs"] else: @@ -98,52 +97,82 @@ def get_vectorstore(data): def is_azure_configured(): - return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME + return ( + settings.OPENAI_API_BASE + and settings.OPENAI_API_VERSION + and settings.AZURE_DEPLOYMENT_NAME + ) + def save_conversation(conversation_id, question, response, source_log_docs, llm): if conversation_id is not None and conversation_id != "None": conversations_collection.update_one( {"_id": ObjectId(conversation_id)}, - {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}}, + { + "$push": { + "queries": { + "prompt": question, + "response": response, + "sources": source_log_docs, + } + } + }, ) else: # create new conversation # generate summary - messages_summary = [{"role": "assistant", "content": "Summarise following conversation in no more than 3 " - "words, respond ONLY with the summary, use the same " - "language as the system \n\nUser: " + question + "\n\n" + - "AI: " + - response}, - {"role": "user", "content": "Summarise following conversation in no 
more than 3 words, " - "respond ONLY with the summary, use the same language as the " - "system"}] - - completion = llm.gen(model=gpt_model, - messages=messages_summary, max_tokens=30) + messages_summary = [ + { + "role": "assistant", + "content": "Summarise following conversation in no more than 3 " + "words, respond ONLY with the summary, use the same " + "language as the system \n\nUser: " + + question + + "\n\n" + + "AI: " + + response, + }, + { + "role": "user", + "content": "Summarise following conversation in no more than 3 words, " + "respond ONLY with the summary, use the same language as the " + "system", + }, + ] + + completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30) conversation_id = conversations_collection.insert_one( - {"user": "local", - "date": datetime.datetime.utcnow(), - "name": completion, - "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]} + { + "user": "local", + "date": datetime.datetime.utcnow(), + "name": completion, + "queries": [ + { + "prompt": question, + "response": response, + "sources": source_log_docs, + } + ], + } ).inserted_id return conversation_id + def get_prompt(prompt_id): - if prompt_id == 'default': + if prompt_id == "default": prompt = chat_combine_template - elif prompt_id == 'creative': + elif prompt_id == "creative": prompt = chat_combine_creative - elif prompt_id == 'strict': + elif prompt_id == "strict": prompt = chat_combine_strict else: prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"] return prompt -def complete_stream(question, retriever, conversation_id): - - +def complete_stream(question, retriever, conversation_id, user_api_key): + response_full = "" source_log_docs = [] answer = retriever.gen() @@ -155,9 +184,10 @@ def complete_stream(question, retriever, conversation_id): elif "source" in line: source_log_docs.append(line["source"]) - - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key) + conversation_id = save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) # send data.type = "end" to indicate that the stream has ended as json data = json.dumps({"type": "id", "id": str(conversation_id)}) @@ -180,17 +210,17 @@ def stream(): conversation_id = None else: conversation_id = data["conversation_id"] - if 'prompt_id' in data: + if "prompt_id" in data: prompt_id = data["prompt_id"] else: - prompt_id = 'default' - if 'selectedDocs' in data and data['selectedDocs'] is None: + prompt_id = "default" + if "selectedDocs" in data and data["selectedDocs"] is None: chunks = 0 - elif 'chunks' in data: + elif "chunks" in data: chunks = int(data["chunks"]) else: chunks = 2 - + prompt = get_prompt(prompt_id) # check if active_docs is set @@ -198,23 +228,42 @@ def stream(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] elif "active_docs" in data: source = {"active_docs": data["active_docs"]} + user_api_key = None else: source = {} + user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + if ( + source["active_docs"].split("/")[0] == "default" + or source["active_docs"].split("/")[0] == "local" + ): retriever_name = "classic" else: - retriever_name = 
source['active_docs'] - - retriever = RetrieverCreator.create_retriever(retriever_name, question=question, - source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model - ) + retriever_name = source["active_docs"] + + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + gpt_model=gpt_model, + api_key=user_api_key, + ) return Response( - complete_stream(question=question, retriever=retriever, - conversation_id=conversation_id), mimetype="text/event-stream") + complete_stream( + question=question, + retriever=retriever, + conversation_id=conversation_id, + user_api_key=user_api_key, + ), + mimetype="text/event-stream", + ) @answer.route("/api/answer", methods=["POST"]) @@ -230,15 +279,15 @@ def api_answer(): else: conversation_id = data["conversation_id"] print("-" * 5) - if 'prompt_id' in data: + if "prompt_id" in data: prompt_id = data["prompt_id"] else: - prompt_id = 'default' - if 'chunks' in data: + prompt_id = "default" + if "chunks" in data: chunks = int(data["chunks"]) else: chunks = 2 - + prompt = get_prompt(prompt_id) # use try and except to check for exception @@ -247,17 +296,29 @@ def api_answer(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] else: source = {data} + user_api_key = None - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + if ( + source["active_docs"].split("/")[0] == "default" + or source["active_docs"].split("/")[0] == "local" + ): retriever_name = "classic" else: - retriever_name = source['active_docs'] - - retriever = RetrieverCreator.create_retriever(retriever_name, question=question, - source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model - ) + retriever_name = source["active_docs"] + + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + gpt_model=gpt_model, + api_key=user_api_key, + ) source_log_docs = [] response_full = "" for line in retriever.gen(): @@ -265,12 +326,13 @@ def api_answer(): source_log_docs.append(line["source"]) elif "answer" in line: response_full += line["answer"] - - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) - + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key) result = {"answer": response_full, "sources": source_log_docs} - result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm) + result["conversation_id"] = save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) return result except Exception as e: @@ -289,23 +351,35 @@ def api_search(): if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] elif "active_docs" in data: source = {"active_docs": data["active_docs"]} + user_api_key = None else: source = {} - if 'chunks' in data: + user_api_key = None + if "chunks" in data: chunks = int(data["chunks"]) else: chunks = 2 - if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local": + if ( + source["active_docs"].split("/")[0] == "default" + or source["active_docs"].split("/")[0] == "local" + ): retriever_name = "classic" else: - 
retriever_name = source['active_docs'] - - retriever = RetrieverCreator.create_retriever(retriever_name, question=question, - source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model - ) + retriever_name = source["active_docs"] + + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=[], + prompt="default", + chunks=chunks, + gpt_model=gpt_model, + api_key=user_api_key, + ) docs = retriever.search() return docs - diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index b3fde3d..31e4563 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -15,7 +15,9 @@ class AnthropicLLM(BaseLLM): self.HUMAN_PROMPT = HUMAN_PROMPT self.AI_PROMPT = AI_PROMPT - def _raw_gen(self, model, messages, max_tokens=300, stream=False, **kwargs): + def _raw_gen( + self, baseself, model, messages, max_tokens=300, stream=False, **kwargs + ): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}" @@ -30,7 +32,7 @@ class AnthropicLLM(BaseLLM): ) return completion.completion - def _raw_gen_stream(self, model, messages, max_tokens=300, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, max_tokens=300, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}" diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py index ffe1e31..6dc1d6c 100644 --- a/application/llm/docsgpt_provider.py +++ b/application/llm/docsgpt_provider.py @@ -5,7 +5,7 @@ import requests class DocsGPTAPILLM(BaseLLM): - def __init__(self, api_key, *args, **kwargs): + def __init__(self, api_key=None, *args, **kwargs): super().__init__(*args, **kwargs) self.api_key = api_key self.endpoint = "https://llm.docsgpt.co.uk" diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py index b1118ed..0baaceb 100644 --- a/application/llm/huggingface.py +++ b/application/llm/huggingface.py @@ -3,7 +3,9 @@ from application.llm.base import BaseLLM class HuggingFaceLLM(BaseLLM): - def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs): + def __init__( + self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs + ): global hf from langchain.llms import HuggingFacePipeline @@ -45,7 +47,7 @@ class HuggingFaceLLM(BaseLLM): ) hf = HuggingFacePipeline(pipeline=pipe) - def _raw_gen(self, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -54,6 +56,6 @@ class HuggingFaceLLM(BaseLLM): return result.content - def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.") diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index 896e66f..ebaefca 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -4,7 +4,7 @@ from application.core.settings import settings class LlamaCpp(BaseLLM): - def __init__(self, api_key, llm_name=settings.MODEL_PATH, *args, **kwargs): + def __init__(self, api_key=None, llm_name=settings.MODEL_PATH, 
*args, **kwargs): global llama try: from llama_cpp import Llama @@ -17,7 +17,7 @@ class LlamaCpp(BaseLLM): self.api_key = api_key llama = Llama(model_path=llm_name, n_ctx=2048) - def _raw_gen(self, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -29,7 +29,7 @@ class LlamaCpp(BaseLLM): return result["choices"][0]["text"].split("### Answer \n")[-1] - def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/llm/openai.py b/application/llm/openai.py index c741404..b0fd4c5 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -4,7 +4,7 @@ from application.core.settings import settings class OpenAILLM(BaseLLM): - def __init__(self, api_key, *args, **kwargs): + def __init__(self, api_key=None, *args, **kwargs): global openai from openai import OpenAI @@ -22,6 +22,7 @@ class OpenAILLM(BaseLLM): def _raw_gen( self, + baseself, model, messages, stream=False, @@ -36,6 +37,7 @@ class OpenAILLM(BaseLLM): def _raw_gen_stream( self, + baseself, model, messages, stream=True, diff --git a/application/llm/premai.py b/application/llm/premai.py index 203ff4d..cdb8063 100644 --- a/application/llm/premai.py +++ b/application/llm/premai.py @@ -4,7 +4,7 @@ from application.core.settings import settings class PremAILLM(BaseLLM): - def __init__(self, api_key, *args, **kwargs): + def __init__(self, api_key=None, *args, **kwargs): from premai import Prem super().__init__(*args, **kwargs) @@ -12,7 +12,7 @@ class PremAILLM(BaseLLM): self.api_key = api_key self.project_id = settings.PREMAI_PROJECT_ID - def _raw_gen(self, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): response = self.client.chat.completions.create( model=model, project_id=self.project_id, @@ -23,7 +23,7 @@ class PremAILLM(BaseLLM): return response.choices[0].message["content"] - def _raw_gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): response = self.client.chat.completions.create( model=model, project_id=self.project_id, diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 807bfa2..579eec6 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -60,7 +60,7 @@ class LineIterator: class SagemakerAPILLM(BaseLLM): - def __init__(self, api_key, *args, **kwargs): + def __init__(self, api_key=None, *args, **kwargs): import boto3 runtime = boto3.client( @@ -75,7 +75,7 @@ class SagemakerAPILLM(BaseLLM): self.endpoint = settings.SAGEMAKER_ENDPOINT self.runtime = runtime - def _raw_gen(self, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -104,7 +104,7 @@ class SagemakerAPILLM(BaseLLM): print(result[0]["generated_text"], file=sys.stderr) return result[0]["generated_text"][len(prompt) :] - def 
_raw_gen_stream(self, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py index 0cc7bd4..d0c81b6 100644 --- a/application/retriever/brave_search.py +++ b/application/retriever/brave_search.py @@ -6,43 +6,54 @@ from application.utils import count_tokens from langchain_community.tools import BraveSearch - class BraveRetSearch(BaseRetriever): - def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + def __init__( + self, + question, + source, + chat_history, + prompt, + chunks=2, + gpt_model="docsgpt", + api_key=None, + ): self.question = question self.source = source self.chat_history = chat_history self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model - + self.api_key = api_key + def _get_data(self): if self.chunks == 0: docs = [] else: - search = BraveSearch.from_api_key(api_key=settings.BRAVE_SEARCH_API_KEY, - search_kwargs={"count": int(self.chunks)}) + search = BraveSearch.from_api_key( + api_key=settings.BRAVE_SEARCH_API_KEY, + search_kwargs={"count": int(self.chunks)}, + ) results = search.run(self.question) results = json.loads(results) - + docs = [] for i in results: try: - title = i['title'] - link = i['link'] - snippet = i['snippet'] + title = i["title"] + link = i["link"] + snippet = i["snippet"] docs.append({"text": snippet, "title": title, "link": link}) except IndexError: pass if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] - + return docs - + def gen(self): docs = self._get_data() - + # join all page_content together with a newline docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) @@ -56,20 +67,27 @@ class BraveRetSearch(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_batch = count_tokens(i["prompt"]) + count_tokens( + i["response"] + ) + if ( + tokens_current_history + tokens_batch + < settings.TOKENS_MAX_HISTORY + ): tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key) - completion = llm.gen_stream(model=self.gpt_model, - messages=messages_combine) + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) for line in completion: yield {"answer": str(line)} - + def search(self): return self._get_data() - diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index b5f1eb9..138971f 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -7,21 +7,30 @@ from application.llm.llm_creator import LLMCreator from application.utils import 
count_tokens - class ClassicRAG(BaseRetriever): - def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + def __init__( + self, + question, + source, + chat_history, + prompt, + chunks=2, + gpt_model="docsgpt", + api_key=None, + ): self.question = question self.vectorstore = self._get_vectorstore(source=source) self.chat_history = chat_history self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model + self.api_key = api_key def _get_vectorstore(self, source): if "active_docs" in source: if source["active_docs"].split("/")[0] == "default": - vectorstore = "" + vectorstore = "" elif source["active_docs"].split("/")[0] == "local": vectorstore = "indexes/" + source["active_docs"] else: @@ -33,32 +42,33 @@ class ClassicRAG(BaseRetriever): vectorstore = os.path.join("application", vectorstore) return vectorstore - def _get_data(self): if self.chunks == 0: docs = [] else: docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, - self.vectorstore, - settings.EMBEDDINGS_KEY + settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY ) docs_temp = docsearch.search(self.question, k=self.chunks) docs = [ { - "title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, - "text": i.page_content - } + "title": ( + i.metadata["title"].split("/")[-1] + if i.metadata + else i.page_content + ), + "text": i.page_content, + } for i in docs_temp ] if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] - + return docs - + def gen(self): docs = self._get_data() - + # join all page_content together with a newline docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) @@ -72,20 +82,27 @@ class ClassicRAG(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_batch = count_tokens(i["prompt"]) + count_tokens( + i["response"] + ) + if ( + tokens_current_history + tokens_batch + < settings.TOKENS_MAX_HISTORY + ): tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key) - completion = llm.gen_stream(model=self.gpt_model, - messages=messages_combine) + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) for line in completion: yield {"answer": str(line)} - + def search(self): return self._get_data() - diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py index d662bb0..f61c082 100644 --- a/application/retriever/duckduck_search.py +++ b/application/retriever/duckduck_search.py @@ -6,16 +6,25 @@ from langchain_community.tools import DuckDuckGoSearchResults from langchain_community.utilities import DuckDuckGoSearchAPIWrapper - class DuckDuckSearch(BaseRetriever): - def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + def __init__( + self, + question, + source, + chat_history, + 
prompt, + chunks=2, + gpt_model="docsgpt", + api_key=None, + ): self.question = question self.source = source self.chat_history = chat_history self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model + self.api_key = api_key def _parse_lang_string(self, input_string): result = [] @@ -30,12 +39,12 @@ class DuckDuckSearch(BaseRetriever): current_item = "" elif inside_brackets: current_item += char - + if inside_brackets: result.append(current_item) - + return result - + def _get_data(self): if self.chunks == 0: docs = [] @@ -44,7 +53,7 @@ class DuckDuckSearch(BaseRetriever): search = DuckDuckGoSearchResults(api_wrapper=wrapper) results = search.run(self.question) results = self._parse_lang_string(results) - + docs = [] for i in results: try: @@ -56,12 +65,12 @@ class DuckDuckSearch(BaseRetriever): pass if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] - + return docs - + def gen(self): docs = self._get_data() - + # join all page_content together with a newline docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) @@ -75,20 +84,27 @@ class DuckDuckSearch(BaseRetriever): self.chat_history.reverse() for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) - if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_batch = count_tokens(i["prompt"]) + count_tokens( + i["response"] + ) + if ( + tokens_current_history + tokens_batch + < settings.TOKENS_MAX_HISTORY + ): tokens_current_history += tokens_batch - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key) - completion = llm.gen_stream(model=self.gpt_model, - messages=messages_combine) + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) for line in completion: yield {"answer": str(line)} - + def search(self): return self._get_data() - From 1b61337b7576fe650fe231e2814508852a0a743b Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Tue, 16 Apr 2024 01:08:39 +0530 Subject: [PATCH 07/10] fix: skip logging to db during tests --- application/usage.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/application/usage.py b/application/usage.py index 95cd02f..131de66 100644 --- a/application/usage.py +++ b/application/usage.py @@ -1,3 +1,4 @@ +import sys from pymongo import MongoClient from datetime import datetime from application.core.settings import settings @@ -9,6 +10,8 @@ usage_collection = db["token_usage"] def update_token_usage(api_key, token_usage): + if "pytest" in sys.modules: + return usage_data = { "api_key": api_key, "prompt_tokens": token_usage["prompt_tokens"], From 333b6e60e173d2294d204e2c9f4f26bb82ba0919 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Tue, 16 Apr 2024 10:02:04 +0530 Subject: [PATCH 08/10] fix: anthropic llm positional arguments --- application/llm/anthropic.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index 31e4563..a078fc8 100644 --- a/application/llm/anthropic.py 
From 333b6e60e173d2294d204e2c9f4f26bb82ba0919 Mon Sep 17 00:00:00 2001
From: Siddhant Rai
Date: Tue, 16 Apr 2024 10:02:04 +0530
Subject: [PATCH 08/10] fix: anthropic llm positional arguments

---
 application/llm/anthropic.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py
index 31e4563..a078fc8 100644
--- a/application/llm/anthropic.py
+++ b/application/llm/anthropic.py
@@ -16,13 +16,13 @@ class AnthropicLLM(BaseLLM):
         self.AI_PROMPT = AI_PROMPT
 
     def _raw_gen(
-        self, baseself, model, messages, max_tokens=300, stream=False, **kwargs
+        self, baseself, model, messages, stream=False, max_tokens=300, **kwargs
     ):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Context \n {context} \n ### Question \n {user_question}"
         if stream:
-            return self.gen_stream(model, prompt, max_tokens, **kwargs)
+            return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
 
         completion = self.anthropic.completions.create(
             model=model,
@@ -32,7 +32,9 @@ class AnthropicLLM(BaseLLM):
         )
         return completion.completion
 
-    def _raw_gen_stream(self, baseself, model, messages, max_tokens=300, **kwargs):
+    def _raw_gen_stream(
+        self, baseself, model, messages, stream=True, max_tokens=300, **kwargs
+    ):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Context \n {context} \n ### Question \n {user_question}"
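The reordering above matters because the usage decorators forward arguments positionally (func(self, model, messages, stream, **kwargs)), so every provider's _raw_gen signature has to agree on what the fourth positional argument is. A small sketch of the failure mode, with made-up function names:

    def wrapped(model, messages, stream=False, max_tokens=300):
        return stream, max_tokens

    def wrapper(model, messages, stream=False, **kwargs):
        # Forwards `stream` positionally. If `wrapped` listed max_tokens before
        # stream, the boolean would be bound to max_tokens instead.
        return wrapped(model, messages, stream, **kwargs)

    print(wrapper("model", [], stream=True))  # -> (True, 300) once the orders agree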
From af5e73c8cb1c981a0557001a5136cacc4b90a3a1 Mon Sep 17 00:00:00 2001
From: Siddhant Rai
Date: Tue, 16 Apr 2024 15:31:11 +0530
Subject: [PATCH 09/10] fix: user_api_key capturing

---
 application/api/answer/routes.py | 14 +++++++++-----
 application/llm/anthropic.py | 3 ++-
 application/llm/docsgpt_provider.py | 3 ++-
 application/llm/huggingface.py | 9 ++++++++-
 application/llm/llama_cpp.py | 10 +++++++++-
 application/llm/llm_creator.py | 4 ++--
 application/llm/openai.py | 3 ++-
 application/llm/premai.py | 3 ++-
 application/llm/sagemaker.py | 3 ++-
 application/retriever/brave_search.py | 8 +++++---
 application/retriever/classic_rag.py | 8 +++++---
 application/retriever/duckduck_search.py | 8 +++++---
 application/usage.py | 8 ++++----
 results.txt | 12 ++++++++++++
 14 files changed, 69 insertions(+), 27 deletions(-)
 create mode 100644 results.txt

diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index b95dd14..af3c5f0 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -184,7 +184,9 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
         elif "source" in line:
             source_log_docs.append(line["source"])
 
-    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
+    llm = LLMCreator.create_llm(
+        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+    )
     conversation_id = save_conversation(
         conversation_id, question, response_full, source_log_docs, llm
     )
@@ -252,7 +254,7 @@ def stream():
         prompt=prompt,
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
 
     return Response(
@@ -317,7 +319,7 @@ def api_answer():
         prompt=prompt,
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
     source_log_docs = []
     response_full = ""
@@ -327,7 +329,9 @@ def api_answer():
         elif "answer" in line:
             response_full += line["answer"]
 
-    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
+    llm = LLMCreator.create_llm(
+        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+    )
 
     result = {"answer": response_full, "sources": source_log_docs}
     result["conversation_id"] = save_conversation(
@@ -379,7 +383,7 @@ def api_search():
         prompt="default",
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
     docs = retriever.search()
     return docs
diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py
index a078fc8..4081bcd 100644
--- a/application/llm/anthropic.py
+++ b/application/llm/anthropic.py
@@ -4,13 +4,14 @@ from application.core.settings import settings
 
 
 class AnthropicLLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
 
         super().__init__(*args, **kwargs)
         self.api_key = (
             api_key or settings.ANTHROPIC_API_KEY
         )  # If not provided, use a default from settings
+        self.user_api_key = user_api_key
         self.anthropic = Anthropic(api_key=self.api_key)
         self.HUMAN_PROMPT = HUMAN_PROMPT
         self.AI_PROMPT = AI_PROMPT
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index 6dc1d6c..bca3972 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -5,9 +5,10 @@ import requests
 
 
 class DocsGPTAPILLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.endpoint = "https://llm.docsgpt.co.uk"
 
     def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py
index 0baaceb..2fb4a92 100644
--- a/application/llm/huggingface.py
+++ b/application/llm/huggingface.py
@@ -4,7 +4,13 @@ from application.llm.base import BaseLLM
 
 class HuggingFaceLLM(BaseLLM):
     def __init__(
-        self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs
+        self,
+        api_key=None,
+        user_api_key=None,
+        llm_name="Arc53/DocsGPT-7B",
+        q=False,
+        *args,
+        **kwargs,
     ):
         global hf
 
@@ -37,6 +43,7 @@ class HuggingFaceLLM(BaseLLM):
 
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         pipe = pipeline(
             "text-generation",
             model=model,
diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py
index ebaefca..25a2f0c 100644
--- a/application/llm/llama_cpp.py
+++ b/application/llm/llama_cpp.py
@@ -4,7 +4,14 @@ from application.core.settings import settings
 
 
 class LlamaCpp(BaseLLM):
-    def __init__(self, api_key=None, llm_name=settings.MODEL_PATH, *args, **kwargs):
+    def __init__(
+        self,
+        api_key=None,
+        user_api_key=None,
+        llm_name=settings.MODEL_PATH,
+        *args,
+        **kwargs,
+    ):
         global llama
         try:
             from llama_cpp import Llama
@@ -15,6 +22,7 @@ class LlamaCpp(BaseLLM):
 
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         llama = Llama(model_path=llm_name, n_ctx=2048)
 
     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index 532f3fc..7960778 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -20,8 +20,8 @@ class LLMCreator:
     }
 
     @classmethod
-    def create_llm(cls, type, api_key, *args, **kwargs):
+    def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
         llm_class = cls.llms.get(type.lower())
         if not llm_class:
             raise ValueError(f"No LLM class found for type {type}")
-        return llm_class(api_key, *args, **kwargs)
+        return llm_class(api_key, user_api_key, *args, **kwargs)
diff --git a/application/llm/openai.py b/application/llm/openai.py
index b0fd4c5..b1574dd 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -4,7 +4,7 @@ from application.core.settings import settings
 
 
 class OpenAILLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         global openai
         from openai import OpenAI
 
@@ -13,6 +13,7 @@ class OpenAILLM(BaseLLM):
             api_key=api_key,
         )
         self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _get_openai(self):
         # Import openai when needed
diff --git a/application/llm/premai.py b/application/llm/premai.py
index cdb8063..c250c65 100644
--- a/application/llm/premai.py
+++ b/application/llm/premai.py
@@ -4,12 +4,13 @@ from application.core.settings import settings
 
 
 class PremAILLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         from premai import Prem
 
         super().__init__(*args, **kwargs)
         self.client = Prem(api_key=api_key)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.project_id = settings.PREMAI_PROJECT_ID
 
     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index 579eec6..6394743 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -60,7 +60,7 @@ class LineIterator:
 
 
 class SagemakerAPILLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         import boto3
 
         runtime = boto3.client(
@@ -72,6 +72,7 @@ class SagemakerAPILLM(BaseLLM):
 
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.endpoint = settings.SAGEMAKER_ENDPOINT
         self.runtime = runtime
 
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
index d0c81b6..47ca0e7 100644
--- a/application/retriever/brave_search.py
+++ b/application/retriever/brave_search.py
@@ -16,7 +16,7 @@ class BraveRetSearch(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.source = source
@@ -24,7 +24,7 @@ class BraveRetSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _get_data(self):
         if self.chunks == 0:
@@ -83,7 +83,9 @@ class BraveRetSearch(BaseRetriever):
             )
         messages_combine.append({"role": "user", "content": self.question})
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )
 
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
         for line in completion:
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 138971f..1bce6f8 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -17,7 +17,7 @@ class ClassicRAG(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.vectorstore = self._get_vectorstore(source=source)
@@ -25,7 +25,7 @@ class ClassicRAG(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _get_vectorstore(self, source):
         if "active_docs" in source:
@@ -98,7 +98,9 @@ class ClassicRAG(BaseRetriever):
             )
         messages_combine.append({"role": "user", "content": self.question})
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )
 
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
         for line in completion:
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index f61c082..9189298 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -16,7 +16,7 @@ class DuckDuckSearch(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.source = source
@@ -24,7 +24,7 @@ class DuckDuckSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _parse_lang_string(self, input_string):
         result = []
@@ -100,7 +100,9 @@ class DuckDuckSearch(BaseRetriever):
             )
         messages_combine.append({"role": "user", "content": self.question})
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )
 
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
         for line in completion:
diff --git a/application/usage.py b/application/usage.py
index 131de66..1b26e9d 100644
--- a/application/usage.py
+++ b/application/usage.py
@@ -9,11 +9,11 @@ db = mongo["docsgpt"]
 usage_collection = db["token_usage"]
 
 
-def update_token_usage(api_key, token_usage):
+def update_token_usage(user_api_key, token_usage):
     if "pytest" in sys.modules:
         return
     usage_data = {
-        "api_key": api_key,
+        "api_key": user_api_key,
         "prompt_tokens": token_usage["prompt_tokens"],
         "generated_tokens": token_usage["generated_tokens"],
         "timestamp": datetime.now(),
@@ -27,7 +27,7 @@ def gen_token_usage(func):
         self.token_usage["prompt_tokens"] += count_tokens(message["content"])
         result = func(self, model, messages, stream, **kwargs)
         self.token_usage["generated_tokens"] += count_tokens(result)
-        update_token_usage(self.api_key, self.token_usage)
+        update_token_usage(self.user_api_key, self.token_usage)
         return result
 
     return wrapper
@@ -44,6 +44,6 @@ def stream_token_usage(func):
             yield r
         for line in batch:
             self.token_usage["generated_tokens"] += count_tokens(line)
-        update_token_usage(self.api_key, self.token_usage)
+        update_token_usage(self.user_api_key, self.token_usage)
 
     return wrapper
diff --git a/results.txt b/results.txt
new file mode 100644
index 0000000..beff805
--- /dev/null
+++ b/results.txt
@@ -0,0 +1,12 @@
+Base URL:http://petstore.swagger.io,https://api.example.com
+Path1: /pets
+description: None
+parameters: []
+methods:
+get=A paged array of pets
+post=Null response
+Path2: /pets/{petId}
+description: None
+parameters: []
+methods:
+get=Expected response to a valid request
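With this patch the two keys travel separately: api_key stays the provider credential taken from settings, while user_api_key is only stored on the LLM instance so update_token_usage() can attribute token counts to the caller. A reduced sketch of that split, using placeholder classes rather than the DocsGPT ones:

    class FakeLLM:
        def __init__(self, api_key=None, user_api_key=None):
            self.api_key = api_key            # provider credential (settings.API_KEY)
            self.user_api_key = user_api_key  # caller's key, kept only for usage logging

    def create_llm(llm_class, api_key, user_api_key):
        # Mirrors LLMCreator.create_llm passing both keys through positionally.
        return llm_class(api_key, user_api_key)

    llm = create_llm(FakeLLM, api_key="provider-secret", user_api_key="user-123")
    print(llm.api_key, llm.user_api_key)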
From ab43c20b8f7d31ee3322a754a11fd28f65540907 Mon Sep 17 00:00:00 2001
From: Alex
Date: Mon, 22 Apr 2024 12:08:11 +0100
Subject: [PATCH 10/10] delete test output

---
 results.txt | 12 ------------
 1 file changed, 12 deletions(-)
 delete mode 100644 results.txt

diff --git a/results.txt b/results.txt
deleted file mode 100644
index beff805..0000000
--- a/results.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-Base URL:http://petstore.swagger.io,https://api.example.com
-Path1: /pets
-description: None
-parameters: []
-methods:
-get=A paged array of pets
-post=Null response
-Path2: /pets/{petId}
-description: None
-parameters: []
-methods:
-get=Expected response to a valid request
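Once usage is logged under the user's key, per-key totals can be read back from the token_usage collection. A sketch of such a query (the connection string is an assumption; the aggregation only uses the fields written by update_token_usage):

    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017")  # assumed connection string
    usage = client["docsgpt"]["token_usage"]

    # Sum prompt and generated tokens per user API key.
    pipeline = [
        {
            "$group": {
                "_id": "$api_key",
                "prompt_tokens": {"$sum": "$prompt_tokens"},
                "generated_tokens": {"$sum": "$generated_tokens"},
            }
        }
    ]
    for row in usage.aggregate(pipeline):
        print(row["_id"], row["prompt_tokens"], row["generated_tokens"])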