From 033bcf80d0e02bc6a0994899d62c4c5f5f5319f4 Mon Sep 17 00:00:00 2001
From: Alex
Date: Mon, 8 Jan 2024 23:35:37 +0000
Subject: [PATCH] Add DocsGPT LLM provider

---
 application/core/settings.py        |  2 +-
 application/llm/docsgpt_provider.py | 50 ++++++++++++++++++++++++++++++++++
 application/llm/llm_creator.py      |  4 ++-
 3 files changed, 54 insertions(+), 2 deletions(-)
 create mode 100644 application/llm/docsgpt_provider.py

diff --git a/application/core/settings.py b/application/core/settings.py
index da6de24..2d4169e 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -7,7 +7,7 @@ current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 
 class Settings(BaseSettings):
-    LLM_NAME: str = "openai"
+    LLM_NAME: str = "docsgpt"
     EMBEDDINGS_NAME: str = "openai_text-embedding-ada-002"
     CELERY_BROKER_URL: str = "redis://localhost:6379/0"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
new file mode 100644
index 0000000..dd1f00d
--- /dev/null
+++ b/application/llm/docsgpt_provider.py
@@ -0,0 +1,50 @@
+from application.llm.base import BaseLLM
+import json
+import requests
+
+
+class DocsGPTAPILLM(BaseLLM):
+
+    def __init__(self, *args, **kwargs):
+        self.endpoint = "https://llm.docsgpt.co.uk"
+
+    def gen(self, model, engine, messages, stream=False, **kwargs):
+        # The first message carries the retrieved context, the last the user question.
+        context = messages[0]['content']
+        user_question = messages[-1]['content']
+        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
+        response = requests.post(
+            f"{self.endpoint}/answer",
+            json={
+                "prompt": prompt,
+                "max_new_tokens": 30  # the /answer request caps replies at 30 new tokens
+            }
+        )
+        # Keep only the text before the next "###" section marker.
+        response_clean = response.json()['a'].split("###")[0]
+
+        return response_clean
+
+    def gen_stream(self, model, engine, messages, stream=True, **kwargs):
+        context = messages[0]['content']
+        user_question = messages[-1]['content']
+        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
+        # Send the prompt to the /stream endpoint and relay server-sent events.
+        response = requests.post(
+            f"{self.endpoint}/stream",
+            json={
+                "prompt": prompt,
+                "max_new_tokens": 256
+            },
+            stream=True
+        )
+
+        for line in response.iter_lines():
+            if line:
+                data_str = line.decode('utf-8')
+                # SSE payload lines are prefixed with "data: ".
+                if data_str.startswith("data: "):
+                    data = json.loads(data_str[6:])
+                    yield data['a']
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index bbc2c79..d0d6ae3 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -3,6 +3,7 @@ from application.llm.sagemaker import SagemakerAPILLM
 from application.llm.huggingface import HuggingFaceLLM
 from application.llm.llama_cpp import LlamaCpp
 from application.llm.anthropic import AnthropicLLM
+from application.llm.docsgpt_provider import DocsGPTAPILLM
 
 
 
@@ -13,7 +14,8 @@ class LLMCreator:
         'sagemaker': SagemakerAPILLM,
         'huggingface': HuggingFaceLLM,
         'llama.cpp': LlamaCpp,
-        'anthropic': AnthropicLLM
+        'anthropic': AnthropicLLM,
+        'docsgpt': DocsGPTAPILLM
     }
 
     @classmethod
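
---
For reference, a quick way to smoke-test the new provider once the patch is
applied. This is a minimal sketch: the two-message shape and the gen/gen_stream
signatures come from the diff above, but the sample strings are invented, the
calls hit the live https://llm.docsgpt.co.uk endpoint, and the provider is
instantiated directly rather than through the factory.

    from application.llm.docsgpt_provider import DocsGPTAPILLM

    # messages[0] supplies the context and messages[-1] the user question,
    # matching how gen() and gen_stream() unpack the list.
    messages = [
        {"role": "system", "content": "DocsGPT is an open-source documentation assistant."},
        {"role": "user", "content": "What is DocsGPT?"},
    ]

    llm = DocsGPTAPILLM()

    # Blocking call: POSTs to /answer and returns the cleaned answer text
    # (limited to 30 new tokens by the request in gen()).
    print(llm.gen(model=None, engine=None, messages=messages))

    # Streaming call: POSTs to /stream and prints tokens as SSE events arrive.
    for chunk in llm.gen_stream(model=None, engine=None, messages=messages):
        print(chunk, end="", flush=True)

With settings.LLM_NAME now defaulting to "docsgpt", the same instance would
normally be produced by LLMCreator via the new 'docsgpt' entry in its llms dict.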