From 590aa8b43f0f6fab07e4d95c430d4ced9f9388a0 Mon Sep 17 00:00:00 2001
From: Siddhant Rai
Date: Mon, 15 Apr 2024 18:57:28 +0530
Subject: [PATCH] update: apply decorator to abstract classes

---
 application/llm/base.py             | 18 ++++++++++++++++--
 application/llm/docsgpt_provider.py |  7 ++-----
 application/usage.py                | 20 ++++++++------------
 3 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/application/llm/base.py b/application/llm/base.py
index 65cb8b1..475b793 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -1,14 +1,28 @@
 from abc import ABC, abstractmethod
+from application.usage import gen_token_usage, stream_token_usage
 
 
 class BaseLLM(ABC):
     def __init__(self):
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
 
+    def _apply_decorator(self, method, decorator, *args, **kwargs):
+        return decorator(method, *args, **kwargs)
+
     @abstractmethod
-    def gen(self, *args, **kwargs):
+    def _raw_gen(self, model, messages, stream, *args, **kwargs):
         pass
 
+    def gen(self, model, messages, stream=False, *args, **kwargs):
+        return self._apply_decorator(self._raw_gen, gen_token_usage)(
+            self, model=model, messages=messages, stream=stream, *args, **kwargs
+        )
+
     @abstractmethod
-    def gen_stream(self, *args, **kwargs):
+    def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass
+
+    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
+        return self._apply_decorator(self._raw_gen_stream, stream_token_usage)(
+            self, model=model, messages=messages, stream=stream, *args, **kwargs
+        )
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index a46abaa..ffe1e31 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -1,7 +1,6 @@
 from application.llm.base import BaseLLM
 import json
 import requests
-from application.usage import gen_token_usage, stream_token_usage
 
 
 class DocsGPTAPILLM(BaseLLM):
@@ -11,8 +10,7 @@ class DocsGPTAPILLM(BaseLLM):
         self.api_key = api_key
         self.endpoint = "https://llm.docsgpt.co.uk"
 
-    @gen_token_usage
-    def gen(self, model, messages, stream=False, **kwargs):
+    def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -24,8 +22,7 @@ class DocsGPTAPILLM(BaseLLM):
 
         return response_clean
 
-    @stream_token_usage
-    def gen_stream(self, model, messages, stream=True, **kwargs):
+    def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
diff --git a/application/usage.py b/application/usage.py
index 2fc307a..95cd02f 100644
--- a/application/usage.py
+++ b/application/usage.py
@@ -19,12 +19,10 @@ def update_token_usage(api_key, token_usage):
 
 
 def gen_token_usage(func):
-    def wrapper(self, model, messages, *args, **kwargs):
-        context = messages[0]["content"]
-        user_question = messages[-1]["content"]
-        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        self.token_usage["prompt_tokens"] += count_tokens(prompt)
-        result = func(self, model, messages, *args, **kwargs)
+    def wrapper(self, model, messages, stream, **kwargs):
+        for message in messages:
+            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
+        result = func(self, model, messages, stream, **kwargs)
         self.token_usage["generated_tokens"] += count_tokens(result)
         update_token_usage(self.api_key, self.token_usage)
         return result
@@ -33,13 +31,11 @@ def gen_token_usage(func):
 
 
 def stream_token_usage(func):
-    def wrapper(self, model, messages, *args, **kwargs):
-        context = messages[0]["content"]
-        user_question = messages[-1]["content"]
-        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        self.token_usage["prompt_tokens"] += count_tokens(prompt)
+    def wrapper(self, model, messages, stream, **kwargs):
+        for message in messages:
+            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
         batch = []
-        result = func(self, model, messages, *args, **kwargs)
+        result = func(self, model, messages, stream, **kwargs)
         for r in result:
             batch.append(r)
             yield r
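
Note (not part of the patch): below is a minimal, self-contained sketch of the decorator-dispatch pattern the diff introduces — callers use the public gen(), providers implement only _raw_gen, and token accounting happens in the wrapper. EchoLLM and the word-count stand-in for count_tokens are hypothetical; the real decorator in application/usage.py also persists totals via update_token_usage(self.api_key, ...).

# Hypothetical sketch of the pattern from application/llm/base.py and
# application/usage.py; names below are stand-ins, not the real modules.
from abc import ABC, abstractmethod


def count_tokens(text):
    return len(text.split())  # stand-in for the real tokenizer-based counter


def gen_token_usage(func):
    # Mirrors the patched decorator: count every message's tokens,
    # call through to the raw generator, then count the result.
    def wrapper(self, model, messages, stream, **kwargs):
        for message in messages:
            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
        result = func(self, model, messages, stream, **kwargs)
        self.token_usage["generated_tokens"] += count_tokens(result)
        return result

    return wrapper


class BaseLLM(ABC):
    def __init__(self):
        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}

    def _apply_decorator(self, method, decorator, *args, **kwargs):
        return decorator(method, *args, **kwargs)

    @abstractmethod
    def _raw_gen(self, baseself, model, messages, stream, *args, **kwargs):
        pass

    def gen(self, model, messages, stream=False, **kwargs):
        # Public entry point. self._raw_gen is already bound, so the explicit
        # `self` passed here lands in the provider's `baseself` parameter.
        return self._apply_decorator(self._raw_gen, gen_token_usage)(
            self, model=model, messages=messages, stream=stream, **kwargs
        )


class EchoLLM(BaseLLM):
    """Hypothetical provider used only for this sketch."""

    def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
        return messages[-1]["content"]


llm = EchoLLM()
print(llm.gen("echo-1", [{"role": "user", "content": "hello there"}]))
print(llm.token_usage)  # {'prompt_tokens': 2, 'generated_tokens': 2}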