From ba796b6be1dde4a405626f06b1db05eaf9492381 Mon Sep 17 00:00:00 2001
From: Siddhant Rai
Date: Mon, 15 Apr 2024 15:03:00 +0530
Subject: [PATCH] feat: logging token usage to database

---
 application/llm/base.py             |  2 +-
 application/llm/docsgpt_provider.py | 43 +++++++++++-------------
 application/llm/llm_creator.py      | 21 ++++++------
 application/usage.py                | 51 +++++++++++++++++++++++++++++
 4 files changed, 82 insertions(+), 35 deletions(-)
 create mode 100644 application/usage.py

diff --git a/application/llm/base.py b/application/llm/base.py
index e08a3b0..65cb8b1 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 class BaseLLM(ABC):
 
     def __init__(self):
-        pass
+        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
 
     @abstractmethod
     def gen(self, *args, **kwargs):
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index d540a91..a46abaa 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -1,49 +1,46 @@
 from application.llm.base import BaseLLM
 import json
 import requests
+from application.usage import gen_token_usage, stream_token_usage
 
-class DocsGPTAPILLM(BaseLLM):
 
-    def __init__(self, *args, **kwargs):
-        self.endpoint = "https://llm.docsgpt.co.uk"
+class DocsGPTAPILLM(BaseLLM):
+    def __init__(self, api_key, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.api_key = api_key
+        self.endpoint = "https://llm.docsgpt.co.uk"
 
+    @gen_token_usage
     def gen(self, model, messages, stream=False, **kwargs):
-        context = messages[0]['content']
-        user_question = messages[-1]['content']
+        context = messages[0]["content"]
+        user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
 
         response = requests.post(
-            f"{self.endpoint}/answer",
-            json={
-                "prompt": prompt,
-                "max_new_tokens": 30
-            }
+            f"{self.endpoint}/answer", json={"prompt": prompt, "max_new_tokens": 30}
         )
-        response_clean = response.json()['a'].replace("###", "")
+        response_clean = response.json()["a"].replace("###", "")
 
         return response_clean
 
+    @stream_token_usage
     def gen_stream(self, model, messages, stream=True, **kwargs):
-        context = messages[0]['content']
-        user_question = messages[-1]['content']
+        context = messages[0]["content"]
+        user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
 
         # send prompt to endpoint /stream
         response = requests.post(
             f"{self.endpoint}/stream",
-            json={
-                "prompt": prompt,
-                "max_new_tokens": 256
-            },
-            stream=True
+            json={"prompt": prompt, "max_new_tokens": 256},
+            stream=True,
         )
-
+
         for line in response.iter_lines():
             if line:
-                #data = json.loads(line)
-                data_str = line.decode('utf-8')
+                # data = json.loads(line)
+                data_str = line.decode("utf-8")
                 if data_str.startswith("data: "):
                     data = json.loads(data_str[6:])
-                    yield data['a']
-                    
\ No newline at end of file
+                    yield data["a"]
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index b4fdaeb..532f3fc 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -7,22 +7,21 @@ from application.llm.docsgpt_provider import DocsGPTAPILLM
 from application.llm.premai import PremAILLM
 
 
-
 class LLMCreator:
     llms = {
-        'openai': OpenAILLM,
-        'azure_openai': AzureOpenAILLM,
-        'sagemaker': SagemakerAPILLM,
-        'huggingface': HuggingFaceLLM,
-        'llama.cpp': LlamaCpp,
-        'anthropic': AnthropicLLM,
-        'docsgpt': DocsGPTAPILLM,
-        'premai': PremAILLM,
+        "openai": OpenAILLM,
+        "azure_openai": AzureOpenAILLM,
+        "sagemaker": SagemakerAPILLM,
+        "huggingface": HuggingFaceLLM,
+        "llama.cpp": LlamaCpp,
+        "anthropic": AnthropicLLM,
+        "docsgpt": DocsGPTAPILLM,
+        "premai": PremAILLM,
     }
 
     @classmethod
-    def create_llm(cls, type, *args, **kwargs):
+    def create_llm(cls, type, api_key, *args, **kwargs):
         llm_class = cls.llms.get(type.lower())
         if not llm_class:
             raise ValueError(f"No LLM class found for type {type}")
-        return llm_class(*args, **kwargs)
\ No newline at end of file
+        return llm_class(api_key, *args, **kwargs)
diff --git a/application/usage.py b/application/usage.py
new file mode 100644
index 0000000..de854a4
--- /dev/null
+++ b/application/usage.py
@@ -0,0 +1,51 @@
+from pymongo import MongoClient
+from bson.son import SON
+from datetime import datetime
+from application.core.settings import settings
+from application.utils import count_tokens
+
+mongo = MongoClient(settings.MONGO_URI)
+db = mongo["docsgpt"]
+usage_collection = db["token_usage"]
+
+
+def update_token_usage(api_key, token_usage):
+    usage_data = {
+        "api_key": api_key,
+        "prompt_tokens": token_usage["prompt_tokens"],
+        "generated_tokens": token_usage["generated_tokens"],
+        "timestamp": datetime.now(),
+    }
+    usage_collection.insert_one(usage_data)
+
+
+def gen_token_usage(func):
+    def wrapper(self, model, messages, *args, **kwargs):
+        context = messages[0]["content"]
+        user_question = messages[-1]["content"]
+        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+        self.token_usage["prompt_tokens"] += count_tokens(prompt)
+        result = func(self, model, messages, *args, **kwargs)
+        self.token_usage["generated_tokens"] += count_tokens(result)
+        update_token_usage(self.api_key, self.token_usage)
+        return result
+
+    return wrapper
+
+
+def stream_token_usage(func):
+    def wrapper(self, model, messages, *args, **kwargs):
+        context = messages[0]["content"]
+        user_question = messages[-1]["content"]
+        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+        self.token_usage["prompt_tokens"] += count_tokens(prompt)
+        batch = []
+        result = func(self, model, messages, *args, **kwargs)
+        for r in result:
+            batch.append(r)
+            yield r
+        for line in batch:
+            self.token_usage["generated_tokens"] += count_tokens(line)
+        update_token_usage(self.api_key, self.token_usage)
+
+    return wrapper
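
Reviewer note with a sketch (not part of the patch): the decorators in
application/usage.py assume the wrapped instance exposes api_key and the
token_usage dict that BaseLLM.__init__ now initializes; that is why
DocsGPTAPILLM gained an api_key parameter and the super().__init__() call.
The snippet below is a minimal, self-contained way to exercise the same
accounting logic without MongoDB or a live endpoint. The two decorators are
copied from the patch with update_token_usage redirected to an in-memory
list; FakeLLM, the whitespace-based count_tokens, and RECORDS are
hypothetical stand-ins, not code from this patch.

    from datetime import datetime

    RECORDS = []  # stand-in for the Mongo "token_usage" collection


    def count_tokens(text):
        # crude stand-in for application.utils.count_tokens
        return len(text.split())


    def update_token_usage(api_key, token_usage):
        # same document shape the patch inserts into MongoDB
        RECORDS.append(
            {
                "api_key": api_key,
                "prompt_tokens": token_usage["prompt_tokens"],
                "generated_tokens": token_usage["generated_tokens"],
                "timestamp": datetime.now(),
            }
        )


    def gen_token_usage(func):
        # as in the patch: count prompt tokens before the call, generated
        # tokens after, then persist one usage record
        def wrapper(self, model, messages, *args, **kwargs):
            context = messages[0]["content"]
            user_question = messages[-1]["content"]
            prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
            self.token_usage["prompt_tokens"] += count_tokens(prompt)
            result = func(self, model, messages, *args, **kwargs)
            self.token_usage["generated_tokens"] += count_tokens(result)
            update_token_usage(self.api_key, self.token_usage)
            return result

        return wrapper


    def stream_token_usage(func):
        # generator variant: re-yield each chunk, then count the buffered
        # chunks once the stream is exhausted, as the patch does
        def wrapper(self, model, messages, *args, **kwargs):
            context = messages[0]["content"]
            user_question = messages[-1]["content"]
            prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
            self.token_usage["prompt_tokens"] += count_tokens(prompt)
            batch = []
            for chunk in func(self, model, messages, *args, **kwargs):
                batch.append(chunk)
                yield chunk
            for line in batch:
                self.token_usage["generated_tokens"] += count_tokens(line)
            update_token_usage(self.api_key, self.token_usage)

        return wrapper


    class FakeLLM:
        # mimics the attributes BaseLLM and DocsGPTAPILLM provide
        def __init__(self, api_key):
            self.api_key = api_key
            self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}

        @gen_token_usage
        def gen(self, model, messages, stream=False, **kwargs):
            return "a canned answer"

        @stream_token_usage
        def gen_stream(self, model, messages, stream=True, **kwargs):
            yield "a "
            yield "streamed "
            yield "answer"


    llm = FakeLLM(api_key="test-key")
    messages = [
        {"role": "system", "content": "some context"},
        {"role": "user", "content": "a question"},
    ]
    llm.gen("docsgpt", messages)
    print("".join(llm.gen_stream("docsgpt", messages)))
    print(RECORDS)

Two small observations worth a follow-up: from bson.son import SON in
application/usage.py is imported but never used, and because token_usage
accumulates on the instance, the second record printed above carries
cumulative totals rather than per-call deltas.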