feat: logging token usage to database

pull/926/head
Siddhant Rai 1 month ago
parent 00b6639155
commit ba796b6be1

@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
class BaseLLM(ABC):
def __init__(self):
pass
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@abstractmethod
def gen(self, *args, **kwargs):

@ -1,49 +1,46 @@
from application.llm.base import BaseLLM
import json
import requests
from application.usage import gen_token_usage, stream_token_usage
class DocsGPTAPILLM(BaseLLM):
def __init__(self, *args, **kwargs):
self.endpoint = "https://llm.docsgpt.co.uk"
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key
self.endpoint = "https://llm.docsgpt.co.uk"
@gen_token_usage
def gen(self, model, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
response = requests.post(
f"{self.endpoint}/answer",
json={
"prompt": prompt,
"max_new_tokens": 30
}
f"{self.endpoint}/answer", json={"prompt": prompt, "max_new_tokens": 30}
)
response_clean = response.json()['a'].replace("###", "")
response_clean = response.json()["a"].replace("###", "")
return response_clean
@stream_token_usage
def gen_stream(self, model, messages, stream=True, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
# send prompt to endpoint /stream
response = requests.post(
f"{self.endpoint}/stream",
json={
"prompt": prompt,
"max_new_tokens": 256
},
stream=True
json={"prompt": prompt, "max_new_tokens": 256},
stream=True,
)
for line in response.iter_lines():
if line:
#data = json.loads(line)
data_str = line.decode('utf-8')
# data = json.loads(line)
data_str = line.decode("utf-8")
if data_str.startswith("data: "):
data = json.loads(data_str[6:])
yield data['a']
yield data["a"]

@ -7,22 +7,21 @@ from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM
class LLMCreator:
llms = {
'openai': OpenAILLM,
'azure_openai': AzureOpenAILLM,
'sagemaker': SagemakerAPILLM,
'huggingface': HuggingFaceLLM,
'llama.cpp': LlamaCpp,
'anthropic': AnthropicLLM,
'docsgpt': DocsGPTAPILLM,
'premai': PremAILLM,
"openai": OpenAILLM,
"azure_openai": AzureOpenAILLM,
"sagemaker": SagemakerAPILLM,
"huggingface": HuggingFaceLLM,
"llama.cpp": LlamaCpp,
"anthropic": AnthropicLLM,
"docsgpt": DocsGPTAPILLM,
"premai": PremAILLM,
}
@classmethod
def create_llm(cls, type, *args, **kwargs):
def create_llm(cls, type, api_key, *args, **kwargs):
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
return llm_class(*args, **kwargs)
return llm_class(api_key, *args, **kwargs)

@ -0,0 +1,51 @@
from pymongo import MongoClient
from bson.son import SON
from datetime import datetime
from application.core.settings import settings
from application.utils import count_tokens
mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
usage_collection = db["token_usage"]
def update_token_usage(api_key, token_usage):
usage_data = {
"api_key": api_key,
"prompt_tokens": token_usage["prompt_tokens"],
"generated_tokens": token_usage["generated_tokens"],
"timestamp": datetime.now(),
}
usage_collection.insert_one(usage_data)
def gen_token_usage(func):
def wrapper(self, model, messages, *args, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
self.token_usage["prompt_tokens"] += count_tokens(prompt)
result = func(self, model, messages, *args, **kwargs)
self.token_usage["generated_tokens"] += count_tokens(result)
update_token_usage(self.api_key, self.token_usage)
return result
return wrapper
def stream_token_usage(func):
def wrapper(self, model, messages, *args, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
self.token_usage["prompt_tokens"] += count_tokens(prompt)
batch = []
result = func(self, model, messages, *args, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
self.token_usage["generated_tokens"] += count_tokens(line)
update_token_usage(self.api_key, self.token_usage)
return wrapper
Loading…
Cancel
Save