fix: user_api_key capturing

pull/926/head
Siddhant Rai 3 weeks ago
parent 333b6e60e1
commit af5e73c8cb

@@ -184,7 +184,9 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
         elif "source" in line:
             source_log_docs.append(line["source"])
 
-    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
+    llm = LLMCreator.create_llm(
+        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+    )
     conversation_id = save_conversation(
         conversation_id, question, response_full, source_log_docs, llm
     )
@@ -252,7 +254,7 @@ def stream():
         prompt=prompt,
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
 
     return Response(
@@ -317,7 +319,7 @@ def api_answer():
         prompt=prompt,
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
     source_log_docs = []
     response_full = ""
@@ -327,7 +329,9 @@
         elif "answer" in line:
             response_full += line["answer"]
 
-    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
+    llm = LLMCreator.create_llm(
+        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+    )
 
     result = {"answer": response_full, "sources": source_log_docs}
     result["conversation_id"] = save_conversation(
@@ -379,7 +383,7 @@ def api_search():
         prompt="default",
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
     docs = retriever.search()
     return docs
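
The hunks above are the substance of the fix: api_key and user_api_key were previously conflated, so the caller's DocsGPT key was handed to the LLM wrapper as if it were the provider credential. After the change, api_key always carries the server's own key (settings.API_KEY) and user_api_key only identifies the caller for token accounting (see the token-usage hunks at the end of this diff). A minimal sketch of the corrected call; Settings and create_llm here are stand-ins for the real settings module and factory, not repo code:

from dataclasses import dataclass

@dataclass
class Settings:
    LLM_NAME: str = "openai"
    API_KEY: str = "sk-provider-credential"  # the server's own LLM key

settings = Settings()

def create_llm(type, api_key, user_api_key, **kwargs):
    # Stand-in for LLMCreator.create_llm as changed in this commit.
    return {"type": type, "api_key": api_key, "user_api_key": user_api_key}

user_api_key = "docsgpt-user-key"
# Before: create_llm(settings.LLM_NAME, api_key=user_api_key) sent the
# caller's key to the provider and lost it for usage accounting.
llm = create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
print(llm)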

@@ -4,13 +4,14 @@ from application.core.settings import settings
 
 
 class AnthropicLLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
 
         super().__init__(*args, **kwargs)
         self.api_key = (
             api_key or settings.ANTHROPIC_API_KEY
         )  # If not provided, use a default from settings
+        self.user_api_key = user_api_key
         self.anthropic = Anthropic(api_key=self.api_key)
         self.HUMAN_PROMPT = HUMAN_PROMPT
         self.AI_PROMPT = AI_PROMPT
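
Every LLM wrapper in this diff gets the same two-line treatment: accept user_api_key next to api_key and store it on the instance so the usage decorators can read it later. The classes below (DocsGPT API, Hugging Face, llama.cpp, OpenAI, PremAI, SageMaker) repeat the pattern. A hedged sketch, with BaseLLM stubbed and the token_usage counter assumed to live on the base class:

class BaseLLM:  # stub; assumes the real base initializes token_usage
    def __init__(self, *args, **kwargs):
        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}

class SomeProviderLLM(BaseLLM):  # hypothetical wrapper, not from the repo
    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.api_key = api_key            # credential sent to the provider
        self.user_api_key = user_api_key  # kept only for usage attribution

llm = SomeProviderLLM(api_key="sk-provider", user_api_key="user-123")
assert llm.user_api_key == "user-123"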

@@ -5,9 +5,10 @@ import requests
 
 
 class DocsGPTAPILLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.endpoint = "https://llm.docsgpt.co.uk"
 
     def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):

@@ -4,7 +4,13 @@ from application.llm.base import BaseLLM
 
 
 class HuggingFaceLLM(BaseLLM):
     def __init__(
-        self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs
+        self,
+        api_key=None,
+        user_api_key=None,
+        llm_name="Arc53/DocsGPT-7B",
+        q=False,
+        *args,
+        **kwargs,
     ):
         global hf
@@ -37,6 +43,7 @@ class HuggingFaceLLM(BaseLLM):
 
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         pipe = pipeline(
             "text-generation",
             model=model,

@@ -4,7 +4,14 @@ from application.core.settings import settings
 
 
 class LlamaCpp(BaseLLM):
-    def __init__(self, api_key=None, llm_name=settings.MODEL_PATH, *args, **kwargs):
+    def __init__(
+        self,
+        api_key=None,
+        user_api_key=None,
+        llm_name=settings.MODEL_PATH,
+        *args,
+        **kwargs,
+    ):
         global llama
         try:
             from llama_cpp import Llama
@@ -15,6 +22,7 @@ class LlamaCpp(BaseLLM):
 
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         llama = Llama(model_path=llm_name, n_ctx=2048)
 
     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):

@@ -20,8 +20,8 @@ class LLMCreator:
     }
 
     @classmethod
-    def create_llm(cls, type, api_key, *args, **kwargs):
+    def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
         llm_class = cls.llms.get(type.lower())
         if not llm_class:
             raise ValueError(f"No LLM class found for type {type}")
-        return llm_class(api_key, *args, **kwargs)
+        return llm_class(api_key, user_api_key, *args, **kwargs)
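
user_api_key becomes a required positional parameter of the factory, so every create_llm call site breaks until it passes the new argument, which is why the route and retriever hunks in this diff change in lockstep. A toy version of the new signature; only "openai" is registered here, and OpenAILLMStub is hypothetical:

class OpenAILLMStub:
    def __init__(self, api_key, user_api_key, *args, **kwargs):
        self.api_key, self.user_api_key = api_key, user_api_key

class LLMCreator:
    llms = {"openai": OpenAILLMStub}  # the real map lists every provider

    @classmethod
    def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
        llm_class = cls.llms.get(type.lower())
        if not llm_class:
            raise ValueError(f"No LLM class found for type {type}")
        return llm_class(api_key, user_api_key, *args, **kwargs)

llm = LLMCreator.create_llm("openai", "sk-provider", "user-123")
assert llm.user_api_key == "user-123"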

@@ -4,7 +4,7 @@ from application.core.settings import settings
 
 
 class OpenAILLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         global openai
         from openai import OpenAI
 
@@ -13,6 +13,7 @@ class OpenAILLM(BaseLLM):
             api_key=api_key,
         )
         self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _get_openai(self):
         # Import openai when needed

@@ -4,12 +4,13 @@ from application.core.settings import settings
 
 
 class PremAILLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         from premai import Prem
 
         super().__init__(*args, **kwargs)
         self.client = Prem(api_key=api_key)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.project_id = settings.PREMAI_PROJECT_ID
 
     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):

@@ -60,7 +60,7 @@ class LineIterator:
 
 
 class SagemakerAPILLM(BaseLLM):
-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         import boto3
 
         runtime = boto3.client(
@@ -72,6 +72,7 @@ class SagemakerAPILLM(BaseLLM):
 
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.endpoint = settings.SAGEMAKER_ENDPOINT
         self.runtime = runtime
 

@@ -16,7 +16,7 @@ class BraveRetSearch(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.source = source
@@ -24,7 +24,7 @@ class BraveRetSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _get_data(self):
         if self.chunks == 0:
@@ -83,7 +83,9 @@
             )
         messages_combine.append({"role": "user", "content": self.question})
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )
 
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
         for line in completion:
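
The retriever edits are mechanical renames, repeated in the two retrievers below, but they are breaking at call sites because the constructor keyword changes from api_key to user_api_key. A sketch of that consequence; RetrieverStub abbreviates the real signatures, which also take question and source among others:

class RetrieverStub:
    def __init__(self, prompt, chunks=2, gpt_model="docsgpt", user_api_key=None):
        self.prompt, self.chunks = prompt, chunks
        self.gpt_model, self.user_api_key = gpt_model, user_api_key

RetrieverStub("p", user_api_key="user-123")  # new keyword: accepted
try:
    RetrieverStub("p", api_key="user-123")   # old keyword: now rejected
except TypeError as exc:
    print(exc)  # unexpected keyword argument 'api_key'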

@@ -17,7 +17,7 @@ class ClassicRAG(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.vectorstore = self._get_vectorstore(source=source)
@@ -25,7 +25,7 @@ class ClassicRAG(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _get_vectorstore(self, source):
         if "active_docs" in source:
@@ -98,7 +98,9 @@ class ClassicRAG(BaseRetriever):
             )
         messages_combine.append({"role": "user", "content": self.question})
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )
 
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
         for line in completion:

@@ -16,7 +16,7 @@ class DuckDuckSearch(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.source = source
@@ -24,7 +24,7 @@ class DuckDuckSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key
 
     def _parse_lang_string(self, input_string):
         result = []
@@ -100,7 +100,9 @@
             )
         messages_combine.append({"role": "user", "content": self.question})
 
-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )
 
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
         for line in completion:

@@ -9,11 +9,11 @@ db = mongo["docsgpt"]
 usage_collection = db["token_usage"]
 
 
-def update_token_usage(api_key, token_usage):
+def update_token_usage(user_api_key, token_usage):
     if "pytest" in sys.modules:
         return
     usage_data = {
-        "api_key": api_key,
+        "api_key": user_api_key,
         "prompt_tokens": token_usage["prompt_tokens"],
         "generated_tokens": token_usage["generated_tokens"],
         "timestamp": datetime.now(),
@@ -27,7 +27,7 @@ def gen_token_usage(func):
             self.token_usage["prompt_tokens"] += count_tokens(message["content"])
         result = func(self, model, messages, stream, **kwargs)
         self.token_usage["generated_tokens"] += count_tokens(result)
-        update_token_usage(self.api_key, self.token_usage)
+        update_token_usage(self.user_api_key, self.token_usage)
         return result
 
     return wrapper
@@ -44,6 +44,6 @@ def stream_token_usage(func):
             yield r
         for line in batch:
             self.token_usage["generated_tokens"] += count_tokens(line)
-        update_token_usage(self.api_key, self.token_usage)
+        update_token_usage(self.user_api_key, self.token_usage)
 
     return wrapper
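
These decorators are the reason every wrapper must expose self.user_api_key: after each generation they write the token counts to Mongo keyed by the user's key, even though the provider call was authenticated with the server's key. A minimal, self-contained sketch of the non-streaming path, with count_tokens and the Mongo write stubbed:

from datetime import datetime
from functools import wraps

RECORDS = []  # stand-in for the mongo token_usage collection

def count_tokens(text):
    return len(text.split())  # crude stand-in for the real tokenizer

def update_token_usage(user_api_key, token_usage):
    RECORDS.append(
        {"api_key": user_api_key, **token_usage, "timestamp": datetime.now()}
    )

def gen_token_usage(func):
    @wraps(func)
    def wrapper(self, model, messages, stream=False, **kwargs):
        for message in messages:
            self.token_usage["prompt_tokens"] += count_tokens(message["content"])
        result = func(self, model, messages, stream, **kwargs)
        self.token_usage["generated_tokens"] += count_tokens(result)
        update_token_usage(self.user_api_key, self.token_usage)  # was self.api_key
        return result

    return wrapper

class EchoLLM:  # hypothetical wrapper for demonstration
    def __init__(self):
        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
        self.user_api_key = "user-123"

    @gen_token_usage
    def _raw_gen(self, model, messages, stream=False, **kwargs):
        return "hello from the model"

EchoLLM()._raw_gen("docsgpt", [{"content": "hi there"}])
print(RECORDS[0]["api_key"])  # -> user-123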

@@ -0,0 +1,12 @@
+Base URL:http://petstore.swagger.io,https://api.example.com
+Path1: /pets
+description: None
+parameters: []
+methods:
+get=A paged array of pets
+post=Null response
+Path2: /pets/{petId}
+description: None
+parameters: []
+methods:
+get=Expected response to a valid request