diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index b95dd14..af3c5f0 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -184,7 +184,9 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
         elif "source" in line:
             source_log_docs.append(line["source"])
 
-    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
+    llm = LLMCreator.create_llm(
+        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+    )
     conversation_id = save_conversation(
         conversation_id, question, response_full, source_log_docs, llm
     )
@@ -252,7 +254,7 @@ def stream():
         prompt=prompt,
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
 
     return Response(
@@ -317,7 +319,7 @@ def api_answer():
         prompt=prompt,
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
     source_log_docs = []
     response_full = ""
@@ -327,7 +329,9 @@ def api_answer():
         elif "answer" in line:
             response_full += line["answer"]
 
-    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
+    llm = LLMCreator.create_llm(
+        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+    )
 
     result = {"answer": response_full, "sources": source_log_docs}
     result["conversation_id"] = save_conversation(
@@ -379,7 +383,7 @@ def api_search():
         prompt="default",
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
     docs = retriever.search()
     return docs
diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py
index a078fc8..4081bcd 100644
--- a/application/llm/anthropic.py
+++ b/application/llm/anthropic.py
@@ -4,13 +4,14 @@ from application.core.settings import settings
 
 class AnthropicLLM(BaseLLM):
 
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
 
         super().__init__(*args, **kwargs)
         self.api_key = (
             api_key or settings.ANTHROPIC_API_KEY
         )  # If not provided, use a default from settings
+        self.user_api_key = user_api_key
         self.anthropic = Anthropic(api_key=self.api_key)
         self.HUMAN_PROMPT = HUMAN_PROMPT
         self.AI_PROMPT = AI_PROMPT
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index 6dc1d6c..bca3972 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -5,9 +5,10 @@ import requests
 
 class DocsGPTAPILLM(BaseLLM):
 
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.endpoint = "https://llm.docsgpt.co.uk"
 
     def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py
index 0baaceb..2fb4a92 100644
--- a/application/llm/huggingface.py
+++ b/application/llm/huggingface.py
@@ -4,7 +4,13 @@ from application.llm.base import BaseLLM
 class HuggingFaceLLM(BaseLLM):
 
     def __init__(
-        self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs
+        self,
+        api_key=None,
+        user_api_key=None,
+        llm_name="Arc53/DocsGPT-7B",
+        q=False,
+        *args,
+        **kwargs,
     ):
         global hf
 
@@ -37,6 +43,7 @@ class HuggingFaceLLM(BaseLLM):
 
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         pipe = pipeline(
             "text-generation",
             model=model,
diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py
index ebaefca..25a2f0c 100644
--- a/application/llm/llama_cpp.py
+++ b/application/llm/llama_cpp.py
@@ -4,7 +4,14 @@ from application.core.settings import settings
 
 class LlamaCpp(BaseLLM):
 
-    def __init__(self, api_key=None, llm_name=settings.MODEL_PATH, *args, **kwargs):
+    def __init__(
+        self,
+        api_key=None,
+        user_api_key=None,
+        llm_name=settings.MODEL_PATH,
+        *args,
+        **kwargs,
+    ):
         global llama
         try:
             from llama_cpp import Llama
@@ -15,6 +22,7 @@ class LlamaCpp(BaseLLM):
 
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         llama = Llama(model_path=llm_name, n_ctx=2048)
 
     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index 532f3fc..7960778 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -20,8 +20,8 @@ class LLMCreator:
     }
 
     @classmethod
-    def create_llm(cls, type, api_key, *args, **kwargs):
+    def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
         llm_class = cls.llms.get(type.lower())
         if not llm_class:
             raise ValueError(f"No LLM class found for type {type}")
-        return llm_class(api_key, *args, **kwargs)
+        return llm_class(api_key, user_api_key, *args, **kwargs)
diff --git a/application/llm/openai.py b/application/llm/openai.py
index b0fd4c5..b1574dd 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -4,7 +4,7 @@ from application.core.settings import settings
 
 class OpenAILLM(BaseLLM):
 
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         global openai
         from openai import OpenAI
 
@@ -13,6 +13,7 @@ class OpenAILLM(BaseLLM):
             api_key=api_key,
         )
         self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _get_openai(self):
         # Import openai when needed
diff --git a/application/llm/premai.py b/application/llm/premai.py
index cdb8063..c250c65 100644
--- a/application/llm/premai.py
+++ b/application/llm/premai.py
@@ -4,12 +4,13 @@ from application.core.settings import settings
 
 class PremAILLM(BaseLLM):
 
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         from premai import Prem
 
         super().__init__(*args, **kwargs)
         self.client = Prem(api_key=api_key)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.project_id = settings.PREMAI_PROJECT_ID
 
     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index 579eec6..6394743 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -60,7 +60,7 @@ class LineIterator:
 
 class SagemakerAPILLM(BaseLLM):
 
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         import boto3
 
         runtime = boto3.client(
@@ -72,6 +72,7 @@ class SagemakerAPILLM(BaseLLM):
 
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.endpoint = settings.SAGEMAKER_ENDPOINT
         self.runtime = runtime
 
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
index d0c81b6..47ca0e7 100644
--- a/application/retriever/brave_search.py
+++ b/application/retriever/brave_search.py
@@ -16,7 +16,7 @@ class BraveRetSearch(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.source = source
@@ -24,7 +24,7 @@ class BraveRetSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _get_data(self):
         if self.chunks == 0:
@@ -83,7 +83,9 @@ class BraveRetSearch(BaseRetriever):
             )
         messages_combine.append({"role": "user", "content": self.question})
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
 
         for line in completion:
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 138971f..1bce6f8 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -17,7 +17,7 @@ class ClassicRAG(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.vectorstore = self._get_vectorstore(source=source)
@@ -25,7 +25,7 @@ class ClassicRAG(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _get_vectorstore(self, source):
         if "active_docs" in source:
@@ -98,7 +98,9 @@ class ClassicRAG(BaseRetriever):
             )
         messages_combine.append({"role": "user", "content": self.question})
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
 
         for line in completion:
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index f61c082..9189298 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -16,7 +16,7 @@ class DuckDuckSearch(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.source = source
@@ -24,7 +24,7 @@ class DuckDuckSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _parse_lang_string(self, input_string):
         result = []
@@ -100,7 +100,9 @@ class DuckDuckSearch(BaseRetriever):
             )
         messages_combine.append({"role": "user", "content": self.question})
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
 
         for line in completion:
diff --git a/application/usage.py b/application/usage.py
index 131de66..1b26e9d 100644
--- a/application/usage.py
+++ b/application/usage.py
@@ -9,11 +9,11 @@ db = mongo["docsgpt"]
 usage_collection = db["token_usage"]
 
 
-def update_token_usage(api_key, token_usage):
+def update_token_usage(user_api_key, token_usage):
     if "pytest" in sys.modules:
         return
     usage_data = {
-        "api_key": api_key,
+        "api_key": user_api_key,
         "prompt_tokens": token_usage["prompt_tokens"],
         "generated_tokens": token_usage["generated_tokens"],
         "timestamp": datetime.now(),
@@ -27,7 +27,7 @@ def gen_token_usage(func):
         self.token_usage["prompt_tokens"] += count_tokens(message["content"])
         result = func(self, model, messages, stream, **kwargs)
         self.token_usage["generated_tokens"] += count_tokens(result)
-        update_token_usage(self.api_key, self.token_usage)
+        update_token_usage(self.user_api_key, self.token_usage)
         return result
 
     return wrapper
@@ -44,6 +44,6 @@ def stream_token_usage(func):
             yield r
         for line in batch:
             self.token_usage["generated_tokens"] += count_tokens(line)
-        update_token_usage(self.api_key, self.token_usage)
+        update_token_usage(self.user_api_key, self.token_usage)
 
     return wrapper
diff --git a/results.txt b/results.txt
new file mode 100644
index 0000000..beff805
--- /dev/null
+++ b/results.txt
@@ -0,0 +1,12 @@
+Base URL:http://petstore.swagger.io,https://api.example.com
+Path1: /pets
+description: None
+parameters: []
+methods:
+get=A paged array of pets
+post=Null response
+Path2: /pets/{petId}
+description: None
+parameters: []
+methods:
+get=Expected response to a valid request
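Note on the intent of the signature change: `LLMCreator.create_llm` now takes the provider credential and the caller's key as separate arguments, so `settings.API_KEY` authenticates against the LLM provider while `user_api_key` is carried along only so the usage decorators can attribute token counts to the calling key. A minimal sketch of that flow (the `StubLLM` class and key strings below are illustrative stand-ins, not code from this diff):

# sketch only -- StubLLM is a hypothetical stand-in for any BaseLLM subclass above
class StubLLM:
    def __init__(self, api_key=None, user_api_key=None):
        self.api_key = api_key            # authenticates against the provider
        self.user_api_key = user_api_key  # only labels token_usage records

    def gen(self, prompt):
        result = f"echo: {prompt}"  # placeholder for a real provider call
        # Mirrors gen_token_usage/update_token_usage after this diff:
        # usage is recorded under the caller's key, not the provider secret.
        usage_record = {"api_key": self.user_api_key, "generated_tokens": len(result)}
        print(usage_record)
        return result

# Before this change, callers passed api_key=user_api_key, so provider auth and
# usage metering shared one value; now the two concerns stay distinct.
llm = StubLLM(api_key="provider-secret", user_api_key="user-123")
llm.gen("hello")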