Mirror of https://github.com/arc53/DocsGPT (synced 2024-11-19 21:25:39 +00:00)

Merge pull request #352 from arc53/feature/aws-sagemaker-inference

sagemaker + llm creator class

Commit 833e1836e1
@@ -13,7 +13,7 @@ from transformers import GPT2TokenizerFast
 from application.core.settings import settings
-from application.llm.openai import OpenAILLM, AzureOpenAILLM
+from application.llm.llm_creator import LLMCreator
 from application.vectorstore.faiss import FaissStore
 from application.error import bad_request
@@ -128,16 +128,8 @@ def is_azure_configured():
 def complete_stream(question, docsearch, chat_history, api_key, conversation_id):
-    if is_azure_configured():
-        llm = AzureOpenAILLM(
-            openai_api_key=api_key,
-            openai_api_base=settings.OPENAI_API_BASE,
-            openai_api_version=settings.OPENAI_API_VERSION,
-            deployment_name=settings.AZURE_DEPLOYMENT_NAME,
-        )
-    else:
-        logger.debug("plain OpenAI")
-        llm = OpenAILLM(api_key=api_key)
+    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
 
     docs = docsearch.search(question, k=2)
     # join all page_content together with a newline
@@ -270,16 +262,8 @@ def api_answer():
     # Note if you have used other embeddings than OpenAI, you need to change the embeddings
     docsearch = FaissStore(vectorstore, embeddings_key)
 
-    if is_azure_configured():
-        llm = AzureOpenAILLM(
-            openai_api_key=api_key,
-            openai_api_base=settings.OPENAI_API_BASE,
-            openai_api_version=settings.OPENAI_API_VERSION,
-            deployment_name=settings.AZURE_DEPLOYMENT_NAME,
-        )
-    else:
-        logger.debug("plain OpenAI")
-        llm = OpenAILLM(api_key=api_key)
+    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
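Both call sites above previously duplicated the Azure-versus-plain-OpenAI branch; after this change each builds its LLM through the new factory, and the backend is selected purely by settings.LLM_NAME. A minimal usage sketch under that assumption (the key is a placeholder):

# Backend selection is now data-driven instead of an if/else per call site.
from application.llm.llm_creator import LLMCreator
from application.core.settings import settings

# LLM_NAME == "openai" resolves to OpenAILLM; "sagemaker" to SagemakerAPILLM.
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key="sk-placeholder")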
application/core/settings.py

@@ -4,7 +4,7 @@ from pydantic import BaseSettings
 
 
 class Settings(BaseSettings):
-    LLM_NAME: str = "openai_chat"
+    LLM_NAME: str = "openai"
     EMBEDDINGS_NAME: str = "openai_text-embedding-ada-002"
     CELERY_BROKER_URL: str = "redis://localhost:6379/0"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
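Since Settings subclasses pydantic's BaseSettings, the new "openai" default can be overridden per deployment without code changes: each field can be set through an environment variable of the same name. A small sketch of that override path (the value is an example):

import os

# BaseSettings reads fields from the environment when the Settings object
# is instantiated, so export LLM_NAME before the module is first imported.
os.environ["LLM_NAME"] = "sagemaker"

from application.core.settings import settings
assert settings.LLM_NAME == "sagemaker"  # overrides the "openai" default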
application/llm/llm_creator.py (new file, 20 lines)

@@ -0,0 +1,20 @@
+from application.llm.openai import OpenAILLM, AzureOpenAILLM
+from application.llm.sagemaker import SagemakerAPILLM
+from application.llm.huggingface import HuggingFaceLLM
+
+
+class LLMCreator:
+    llms = {
+        'openai': OpenAILLM,
+        'azure_openai': AzureOpenAILLM,
+        'sagemaker': SagemakerAPILLM,
+        'huggingface': HuggingFaceLLM
+    }
+
+    @classmethod
+    def create_llm(cls, type, *args, **kwargs):
+        llm_class = cls.llms.get(type.lower())
+        if not llm_class:
+            raise ValueError(f"No LLM class found for type {type}")
+        return llm_class(*args, **kwargs)
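LLMCreator is a plain registry factory: the llms dict maps a lowercase name to a class, and create_llm forwards any remaining arguments to that class's constructor. A usage sketch (the key is a placeholder):

from application.llm.llm_creator import LLMCreator

llm = LLMCreator.create_llm("openai", api_key="sk-placeholder")  # -> OpenAILLM
llm = LLMCreator.create_llm("SageMaker")  # lookup lowercases the name, so this matches 'sagemaker'

# Unregistered names fail fast:
LLMCreator.create_llm("unknown")  # ValueError: No LLM class found for type unknown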
application/llm/openai.py

@@ -1,4 +1,5 @@
 from application.llm.base import BaseLLM
+from application.core.settings import settings
 
 
 class OpenAILLM(BaseLLM):

@@ -44,9 +45,9 @@ class AzureOpenAILLM(OpenAILLM):
 
     def __init__(self, openai_api_key, openai_api_base, openai_api_version, deployment_name):
         super().__init__(openai_api_key)
-        self.api_base = openai_api_base
-        self.api_version = openai_api_version
-        self.deployment_name = deployment_name
+        self.api_base = settings.OPENAI_API_BASE
+        self.api_version = settings.OPENAI_API_VERSION
+        self.deployment_name = settings.AZURE_DEPLOYMENT_NAME
 
     def _get_openai(self):
         openai = super()._get_openai()
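After this hunk, the Azure endpoint details always come from settings; the extra constructor arguments remain in the signature but no longer influence the connection. A sketch of the resulting behavior (all values are placeholders):

from application.llm.openai import AzureOpenAILLM

# Only the key still matters; the other arguments are accepted for
# signature compatibility, but the settings values take precedence.
llm = AzureOpenAILLM(
    openai_api_key="sk-placeholder",
    openai_api_base="ignored",      # settings.OPENAI_API_BASE is used instead
    openai_api_version="ignored",   # settings.OPENAI_API_VERSION is used instead
    deployment_name="ignored",      # settings.AZURE_DEPLOYMENT_NAME is used instead
)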
application/llm/sagemaker.py (new file, 27 lines)

@@ -0,0 +1,27 @@
+from application.llm.base import BaseLLM
+from application.core.settings import settings
+import requests
+import json
+
+
+class SagemakerAPILLM(BaseLLM):
+
+    def __init__(self, *args, **kwargs):
+        self.url = settings.SAGEMAKER_API_URL
+
+    def gen(self, model, engine, messages, stream=False, **kwargs):
+        context = messages[0]['content']
+        user_question = messages[-1]['content']
+        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
+        response = requests.post(
+            url=self.url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+            },
+            data=json.dumps({"input": prompt})
+        )
+
+        return response.json()['answer']
+
+    def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+        raise NotImplementedError("Sagemaker does not support streaming")
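SagemakerAPILLM flattens the chat-style messages list into a single instruction/context prompt and posts it as JSON to the configured endpoint. A usage sketch, assuming settings.SAGEMAKER_API_URL points at an endpoint implementing that contract (the messages are examples):

from application.llm.sagemaker import SagemakerAPILLM

messages = [
    {"role": "system", "content": "Retrieved documentation goes here."},  # becomes ### Context
    {"role": "user", "content": "How do I deploy the model?"},            # becomes ### Instruction
]

llm = SagemakerAPILLM()  # reads settings.SAGEMAKER_API_URL
print(llm.gen(model=None, engine=None, messages=messages))
# The endpoint receives {"input": "### Instruction \n ... ### Answer \n"}
# and must answer with JSON of the form {"answer": "<generated text>"}.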