From 60a670ce29080ac672baa1654ace57852a9fed98 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 15 Apr 2024 19:47:24 +0530 Subject: [PATCH] fix: changes to llm classes according to base --- application/llm/anthropic.py | 3 ++- application/llm/huggingface.py | 4 +++- application/llm/llama_cpp.py | 4 +++- application/llm/openai.py | 3 ++- application/llm/premai.py | 3 ++- application/llm/sagemaker.py | 4 +++- 6 files changed, 15 insertions(+), 6 deletions(-) diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index 70495f0..b3fde3d 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -4,9 +4,10 @@ from application.core.settings import settings class AnthropicLLM(BaseLLM): - def __init__(self, api_key=None): + def __init__(self, api_key=None, *args, **kwargs): from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT + super().__init__(*args, **kwargs) self.api_key = ( api_key or settings.ANTHROPIC_API_KEY ) # If not provided, use a default from settings diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py index c9e500e..b1118ed 100644 --- a/application/llm/huggingface.py +++ b/application/llm/huggingface.py @@ -3,7 +3,7 @@ from application.llm.base import BaseLLM class HuggingFaceLLM(BaseLLM): - def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False): + def __init__(self, api_key, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs): global hf from langchain.llms import HuggingFacePipeline @@ -33,6 +33,8 @@ class HuggingFaceLLM(BaseLLM): tokenizer = AutoTokenizer.from_pretrained(llm_name) model = AutoModelForCausalLM.from_pretrained(llm_name) + super().__init__(*args, **kwargs) + self.api_key = api_key pipe = pipeline( "text-generation", model=model, diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index 1512cd7..896e66f 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -4,7 +4,7 @@ from application.core.settings import settings class LlamaCpp(BaseLLM): - def __init__(self, api_key, llm_name=settings.MODEL_PATH, **kwargs): + def __init__(self, api_key, llm_name=settings.MODEL_PATH, *args, **kwargs): global llama try: from llama_cpp import Llama @@ -13,6 +13,8 @@ class LlamaCpp(BaseLLM): "Please install llama_cpp using pip install llama-cpp-python" ) + super().__init__(*args, **kwargs) + self.api_key = api_key llama = Llama(model_path=llm_name, n_ctx=2048) def _raw_gen(self, model, messages, stream=False, **kwargs): diff --git a/application/llm/openai.py b/application/llm/openai.py index de29246..c741404 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -4,10 +4,11 @@ from application.core.settings import settings class OpenAILLM(BaseLLM): - def __init__(self, api_key): + def __init__(self, api_key, *args, **kwargs): global openai from openai import OpenAI + super().__init__(*args, **kwargs) self.client = OpenAI( api_key=api_key, ) diff --git a/application/llm/premai.py b/application/llm/premai.py index c0552ea..203ff4d 100644 --- a/application/llm/premai.py +++ b/application/llm/premai.py @@ -4,9 +4,10 @@ from application.core.settings import settings class PremAILLM(BaseLLM): - def __init__(self, api_key): + def __init__(self, api_key, *args, **kwargs): from premai import Prem + super().__init__(*args, **kwargs) self.client = Prem(api_key=api_key) self.api_key = api_key self.project_id = settings.PREMAI_PROJECT_ID diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index b531020..807bfa2 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -60,7 +60,7 @@ class LineIterator: class SagemakerAPILLM(BaseLLM): - def __init__(self, *args, **kwargs): + def __init__(self, api_key, *args, **kwargs): import boto3 runtime = boto3.client( @@ -70,6 +70,8 @@ class SagemakerAPILLM(BaseLLM): region_name="us-west-2", ) + super().__init__(*args, **kwargs) + self.api_key = api_key self.endpoint = settings.SAGEMAKER_ENDPOINT self.runtime = runtime