update: apply decorator to abstract classes

pull/926/head
Siddhant Rai 2 months ago
parent 262d160314
commit 590aa8b43f

@@ -1,14 +1,28 @@
 from abc import ABC, abstractmethod
+from application.usage import gen_token_usage, stream_token_usage
 
 
 class BaseLLM(ABC):
     def __init__(self):
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
 
+    def _apply_decorator(self, method, decorator, *args, **kwargs):
+        return decorator(method, *args, **kwargs)
+
     @abstractmethod
-    def gen(self, *args, **kwargs):
+    def _raw_gen(self, model, messages, stream, *args, **kwargs):
         pass
 
+    def gen(self, model, messages, stream=False, *args, **kwargs):
+        return self._apply_decorator(self._raw_gen, gen_token_usage)(
+            self, model=model, messages=messages, stream=stream, *args, **kwargs
+        )
+
     @abstractmethod
-    def gen_stream(self, *args, **kwargs):
+    def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass
+
+    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
+        return self._apply_decorator(self._raw_gen_stream, stream_token_usage)(
+            self, model=model, messages=messages, stream=stream, *args, **kwargs
+        )
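Review note: BaseLLM (application/llm/base.py, per the import in the next file) now follows a template-method pattern. The public gen/gen_stream are concrete and wrap the abstract _raw_gen/_raw_gen_stream with the usage decorators at call time, so every provider gets token accounting without decorating anything itself. A minimal sketch of a provider under this branch (EchoLLM and its bodies are hypothetical, for illustration; whether update_token_usage persists anything depends on application settings not shown here):

from application.llm.base import BaseLLM

class EchoLLM(BaseLLM):
    # Hypothetical provider: only the _raw_* hooks are implemented;
    # BaseLLM.gen / gen_stream add the token accounting around them.
    def __init__(self, api_key=None):
        super().__init__()
        self.api_key = api_key  # read by the usage decorators

    def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
        return messages[-1]["content"]

    def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
        yield from messages[-1]["content"].split()

llm = EchoLLM()
llm.gen(model="m", messages=[{"content": "hello"}])
print(llm.token_usage)  # prompt/generated counts filled in by gen_token_usage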

@@ -1,7 +1,6 @@
 from application.llm.base import BaseLLM
 import json
 import requests
-from application.usage import gen_token_usage, stream_token_usage
 
 
 class DocsGPTAPILLM(BaseLLM):
@@ -11,8 +10,7 @@ class DocsGPTAPILLM(BaseLLM):
         self.api_key = api_key
         self.endpoint = "https://llm.docsgpt.co.uk"
 
-    @gen_token_usage
-    def gen(self, model, messages, stream=False, **kwargs):
+    def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -24,8 +22,7 @@ class DocsGPTAPILLM(BaseLLM):
         return response_clean
 
-    @stream_token_usage
-    def gen_stream(self, model, messages, stream=True, **kwargs):
+    def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
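Review note: the doubled self is why _raw_gen takes a baseself parameter here. BaseLLM.gen passes self explicitly into a wrapper built around the already-bound self._raw_gen, so the bound method's own self fills the first slot and the explicit one lands in baseself (both refer to the same instance). A standalone sketch of that binding behavior, with illustrative names not taken from the PR:

class Demo:
    def hook(self, baseself, x):
        # self comes from the bound method; baseself is the explicit
        # self that the base-class wrapper passes through.
        assert self is baseself
        return x * 2

def wrap(func):
    def wrapper(self, x):
        return func(self, x)  # func is already bound, so self fills baseself
    return wrapper

d = Demo()
print(wrap(d.hook)(d, 21))  # 42

Since the two arguments are always the same object, dropping the explicit self in BaseLLM.gen (and the baseself parameter with it) could be a reasonable follow-up.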

@@ -19,12 +19,10 @@ def update_token_usage(api_key, token_usage):
 def gen_token_usage(func):
-    def wrapper(self, model, messages, *args, **kwargs):
-        context = messages[0]["content"]
-        user_question = messages[-1]["content"]
-        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        self.token_usage["prompt_tokens"] += count_tokens(prompt)
-        result = func(self, model, messages, *args, **kwargs)
+    def wrapper(self, model, messages, stream, **kwargs):
+        for message in messages:
+            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
+        result = func(self, model, messages, stream, **kwargs)
         self.token_usage["generated_tokens"] += count_tokens(result)
         update_token_usage(self.api_key, self.token_usage)
         return result
@@ -33,13 +31,11 @@ def gen_token_usage(func):
 def stream_token_usage(func):
-    def wrapper(self, model, messages, *args, **kwargs):
-        context = messages[0]["content"]
-        user_question = messages[-1]["content"]
-        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        self.token_usage["prompt_tokens"] += count_tokens(prompt)
+    def wrapper(self, model, messages, stream, **kwargs):
+        for message in messages:
+            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
         batch = []
-        result = func(self, model, messages, *args, **kwargs)
+        result = func(self, model, messages, stream, **kwargs)
         for r in result:
             batch.append(r)
             yield r
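Review note: both wrappers in application/usage.py now charge prompt tokens per message instead of rebuilding the DocsGPT-specific prompt string, which makes them provider-agnostic. The tail of stream_token_usage is cut off in the hunk above; below is a self-contained sketch of the whole streaming wrapper, with stand-ins for count_tokens and update_token_usage and an assumed post-loop accounting step:

def count_tokens(text):
    return len(text.split())  # stand-in; the real count_tokens lives in application/usage.py

def update_token_usage(api_key, token_usage):
    pass  # stand-in; the real version persists usage for the api_key

def stream_token_usage(func):
    def wrapper(self, model, messages, stream, **kwargs):
        for message in messages:
            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
        batch = []
        for r in func(self, model, messages, stream, **kwargs):
            batch.append(r)
            yield r
        # Assumed (not visible in the hunk): once the stream is exhausted,
        # count the buffered chunks and persist the totals.
        self.token_usage["generated_tokens"] += count_tokens("".join(batch))
        update_token_usage(self.api_key, self.token_usage)
    return wrapper

class _Stub:
    def __init__(self):
        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
        self.api_key = None

    @stream_token_usage
    def stream(self, model, messages, stream=True, **kwargs):
        yield from ("hello ", "world")

s = _Stub()
print(list(s.stream("m", [{"content": "hi"}], True)), s.token_usage)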
