From 4f5e363452862406034b61a2cf80f6d5b5f0966a Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 01:52:29 +0100
Subject: [PATCH 1/7] sagemaker streaming

---
 application/core/settings.py |   6 ++
 application/llm/sagemaker.py | 134 ++++++++++++++++++++++++++++++++---
 2 files changed, 130 insertions(+), 10 deletions(-)

diff --git a/application/core/settings.py b/application/core/settings.py
index a05fd00..116735a 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -32,6 +32,12 @@ class Settings(BaseSettings):
     ELASTIC_URL: str = None  # url for elasticsearch
     ELASTIC_INDEX: str = "docsgpt"  # index name for elasticsearch
 
+    # SageMaker config
+    SAGEMAKER_ENDPOINT: str = None  # SageMaker endpoint name
+    SAGEMAKER_REGION: str = None  # SageMaker region name
+    SAGEMAKER_ACCESS_KEY: str = None  # SageMaker access key
+    SAGEMAKER_SECRET_KEY: str = None  # SageMaker secret key
+
 path = Path(__file__).parent.parent.absolute()
 
 settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index 9ef5d0a..55617cc 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -2,26 +2,140 @@ from application.llm.base import BaseLLM
 from application.core.settings import settings
 import requests
 import json
+import io
+import json
+
+
+
+class LineIterator:
+    """
+    A helper class for parsing the byte stream input.
+
+    The output of the model will be in the following format:
+    ```
+    b'{"outputs": [" a"]}\n'
+    b'{"outputs": [" challenging"]}\n'
+    b'{"outputs": [" problem"]}\n'
+    ...
+    ```
+
+    While usually each PayloadPart event from the event stream will contain a byte array
+    with a full json, this is not guaranteed and some of the json objects may be split
+    across PayloadPart events. For example:
+    ```
+    {'PayloadPart': {'Bytes': b'{"outputs": '}}
+    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+    ```
+
+    This class accounts for this by buffering incoming bytes and exposing an iterator
+    that yields only complete lines (ending with a '\n' character). It maintains the
+    position of the last read so that previous bytes are not exposed again.
+    """
+
+    def __init__(self, stream):
+        self.byte_iterator = iter(stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord('\n'):
+                self.read_pos += len(line)
+                return line[:-1]
+            try:
+                chunk = next(self.byte_iterator)
+            except StopIteration:
+                if self.read_pos < self.buffer.getbuffer().nbytes:
+                    continue
+                raise
+            if 'PayloadPart' not in chunk:
+                print('Unknown event type: ' + str(chunk))
+                continue
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk['PayloadPart']['Bytes'])
 
 
 class SagemakerAPILLM(BaseLLM):
 
     def __init__(self, *args, **kwargs):
-        self.url = settings.SAGEMAKER_API_URL
+        import boto3
+        runtime = boto3.client(
+            'runtime.sagemaker',
+            aws_access_key_id=settings.SAGEMAKER_ACCESS_KEY,
+            aws_secret_access_key=settings.SAGEMAKER_SECRET_KEY,
+            region_name=settings.SAGEMAKER_REGION
+        )
+
+        self.endpoint = settings.SAGEMAKER_ENDPOINT
+        self.runtime = runtime
 
     def gen(self, model, engine, messages, stream=False, **kwargs):
         context = messages[0]['content']
         user_question = messages[-1]['content']
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
 
-        response = requests.post(
-            url=self.url,
-            headers={
-                "Content-Type": "application/json; charset=utf-8",
-            },
-            data=json.dumps({"input": prompt})
-        )
+        # Construct payload for endpoint
+        payload = {
+            "inputs": prompt,
+            "stream": False,
+            "parameters": {
+                "do_sample": True,
+                "temperature": 0.1,
+                "max_new_tokens": 30,
+                "repetition_penalty": 1.03,
+                "stop": ["</s>", "###"]
+            }
+        }
+        body_bytes = json.dumps(payload).encode('utf-8')
 
-        return response.json()['answer']
+        # Invoke the endpoint
+        response = self.runtime.invoke_endpoint(EndpointName=self.endpoint,
+                                                ContentType='application/json',
+                                                Body=body_bytes)
+        result = json.loads(response['Body'].read().decode())
+        import sys
+        print(result[0]['generated_text'], file=sys.stderr)
+        return result[0]['generated_text'][len(prompt):]
 
     def gen_stream(self, model, engine, messages, stream=True, **kwargs):
-        raise NotImplementedError("Sagemaker does not support streaming")
\ No newline at end of file
+        context = messages[0]['content']
+        user_question = messages[-1]['content']
+        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
+        # Construct payload for endpoint
+        payload = {
+            "inputs": prompt,
+            "stream": True,
+            "parameters": {
+                "do_sample": True,
+                "temperature": 0.1,
+                "max_new_tokens": 512,
+                "repetition_penalty": 1.03,
+                "stop": ["</s>", "###"]
+            }
+        }
+        body_bytes = json.dumps(payload).encode('utf-8')
+
+        # Invoke the endpoint with a response stream
+        response = self.runtime.invoke_endpoint_with_response_stream(EndpointName=self.endpoint,
+                                                                     ContentType='application/json',
+                                                                     Body=body_bytes)
+        event_stream = response['Body']
+        start_json = b'{'
+        for line in LineIterator(event_stream):
+            if line != b'' and start_json in line:
+                data = json.loads(line[line.find(start_json):].decode('utf-8'))
+                if data['token']['text'] not in ["</s>", "###"]:
+                    print(data['token']['text'], end='')
+                    yield data['token']['text']
\ No newline at end of file
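Before the tests arrive in the next commit, a quick way to sanity-check the `LineIterator` added above is to drive it with hand-built PayloadPart events, including one JSON object deliberately split across two events, and confirm that complete lines come back. This is a sketch, not part of the patch; it assumes the `application` package is importable from the repo root.

```python
from application.llm.sagemaker import LineIterator

# Fake event stream shaped like SageMaker's PayloadPart events; the second
# JSON object is deliberately split across two events.
fake_stream = [
    {'PayloadPart': {'Bytes': b'{"outputs": [" a"]}\n'}},
    {'PayloadPart': {'Bytes': b'{"outputs": '}},
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}},
]

for line in LineIterator(fake_stream):
    # Prints b'{"outputs": [" a"]}', then the reassembled b'{"outputs": [" problem"]}'
    print(line)
```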
From 495728593f963f33b9f6f1378216faa9c49b74fc Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 13:22:51 +0100
Subject: [PATCH 2/7] sagemaker fixes + test

---
 application/llm/sagemaker.py |  2 --
 tests/llm/test_sagemaker.py  | 96 ++++++++++++++++++++++++++++++++++++
 2 files changed, 96 insertions(+), 2 deletions(-)
 create mode 100644 tests/llm/test_sagemaker.py

diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index 55617cc..ed5fc67 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -1,9 +1,7 @@
 from application.llm.base import BaseLLM
 from application.core.settings import settings
-import requests
 import json
 import io
-import json
 
 
 
diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py
new file mode 100644
index 0000000..f8d02d8
--- /dev/null
+++ b/tests/llm/test_sagemaker.py
@@ -0,0 +1,96 @@
+# FILEPATH: /path/to/test_sagemaker.py
+
+import json
+import unittest
+from unittest.mock import MagicMock, patch
+from application.llm.sagemaker import SagemakerAPILLM, LineIterator
+
+
+class TestSagemakerAPILLM(unittest.TestCase):
+
+    def setUp(self):
+        self.sagemaker = SagemakerAPILLM()
+        self.context = "This is the context"
+        self.user_question = "What is the answer?"
+        self.messages = [
+            {"content": self.context},
+            {"content": "Some other message"},
+            {"content": self.user_question}
+        ]
+        self.prompt = f"### Instruction \n {self.user_question} \n ### Context \n {self.context} \n ### Answer \n"
+        self.payload = {
+            "inputs": self.prompt,
+            "stream": False,
+            "parameters": {
+                "do_sample": True,
+                "temperature": 0.1,
+                "max_new_tokens": 30,
+                "repetition_penalty": 1.03,
+                "stop": ["</s>", "###"]
+            }
+        }
+        self.payload_stream = {
+            "inputs": self.prompt,
+            "stream": True,
+            "parameters": {
+                "do_sample": True,
+                "temperature": 0.1,
+                "max_new_tokens": 512,
+                "repetition_penalty": 1.03,
+                "stop": ["</s>", "###"]
+            }
+        }
+        self.body_bytes = json.dumps(self.payload).encode('utf-8')
+        self.body_bytes_stream = json.dumps(self.payload_stream).encode('utf-8')
+        self.response = {
+            "Body": MagicMock()
+        }
+        self.result = [
+            {
+                "generated_text": "This is the generated text"
+            }
+        ]
+        self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result)
+
+    def test_gen(self):
+        with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
+                          return_value=self.response) as mock_invoke_endpoint:
+            output = self.sagemaker.gen(None, None, self.messages)
+            mock_invoke_endpoint.assert_called_once_with(
+                EndpointName=self.sagemaker.endpoint,
+                ContentType='application/json',
+                Body=self.body_bytes
+            )
+            self.assertEqual(output,
+                             self.result[0]['generated_text'][len(self.prompt):])
+
+    def test_gen_stream(self):
+        with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
+                          return_value=self.response) as mock_invoke_endpoint:
+            output = list(self.sagemaker.gen_stream(None, None, self.messages))
+            mock_invoke_endpoint.assert_called_once_with(
+                EndpointName=self.sagemaker.endpoint,
+                ContentType='application/json',
+                Body=self.body_bytes_stream
+            )
+            self.assertEqual(output, [])
+
+
+class TestLineIterator(unittest.TestCase):
+
+    def setUp(self):
+        self.stream = [
+            {'PayloadPart': {'Bytes': b'{"outputs": [" a"]}\n'}},
+            {'PayloadPart': {'Bytes': b'{"outputs": [" challenging"]}\n'}},
+            {'PayloadPart': {'Bytes': b'{"outputs": [" problem"]}\n'}}
+        ]
+        self.line_iterator = LineIterator(self.stream)
+
+    def test_iter(self):
+        self.assertEqual(iter(self.line_iterator), self.line_iterator)
+
+    def test_next(self):
+        self.assertEqual(next(self.line_iterator), b'{"outputs": [" a"]}')
+        self.assertEqual(next(self.line_iterator), b'{"outputs": [" challenging"]}')
+        self.assertEqual(next(self.line_iterator), b'{"outputs": [" problem"]}')
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
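The new tests cover the happy path where every PayloadPart carries a complete line; the suite runs with `python -m unittest tests.llm.test_sagemaker`. One case worth adding to `TestLineIterator` — a sketch, not in the patch — is the split-payload scenario the docstring describes, so the buffering logic itself is exercised:

```python
# Possible addition inside TestLineIterator: one JSON object split across two
# PayloadPart events must still come back as a single complete line.
def test_next_with_split_payload(self):
    stream = [
        {'PayloadPart': {'Bytes': b'{"outputs": '}},
        {'PayloadPart': {'Bytes': b'[" problem"]}\n'}},
    ]
    self.assertEqual(next(LineIterator(stream)), b'{"outputs": [" problem"]}')
```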
From b3a0368b952c58f0cb77302dbd97809ad2398aa9 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 13:54:03 +0100
Subject: [PATCH 3/7] Update app.py

---
 application/app.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/application/app.py b/application/app.py
index e97febc..ae61997 100644
--- a/application/app.py
+++ b/application/app.py
@@ -27,7 +27,10 @@ celery.config_from_object("application.celeryconfig")
 
 @app.route("/")
 def home():
-    return redirect('http://localhost:5173') if request.remote_addr in ('0.0.0.0', '127.0.0.1', 'localhost', '172.18.0.1') else 'Welcome to DocsGPT Backend!'
+    if request.remote_addr in ('0.0.0.0', '127.0.0.1', 'localhost', '172.18.0.1'):
+        return redirect('http://localhost:5173')
+    else:
+        return 'Welcome to DocsGPT Backend!'
 
 @app.after_request
 def after_request(response):

From 43a22f84d9c934ee1aa11481acf55eb70e2b1bd4 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 14:43:05 +0100
Subject: [PATCH 4/7] Update sagemaker.py

---
 application/llm/sagemaker.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index ed5fc67..84ae09a 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -64,9 +64,9 @@ class SagemakerAPILLM(BaseLLM):
         import boto3
         runtime = boto3.client(
             'runtime.sagemaker',
-            aws_access_key_id=settings.SAGEMAKER_ACCESS_KEY,
-            aws_secret_access_key=settings.SAGEMAKER_SECRET_KEY,
-            region_name=settings.SAGEMAKER_REGION
+            aws_access_key_id='xxx',
+            aws_secret_access_key='xxx',
+            region_name='us-west-2'
         )
From 316c276545acfb85978788c0a4b913c8b90cc4e4 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 15:29:17 +0100
Subject: [PATCH 5/7] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 98958c8..edeb7b6 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ Say goodbye to time-consuming manual searches, and let DocsGPT
 
-### Production Support/ Help for companies:
+### Production Support / Help for companies:
 We're eager to provide personalized assistance when deploying your DocsGPT to a live environment.
 - [Schedule Demo 👋](https://cal.com/arc53/docsgpt-demo-b2b?date=2023-10-04&month=2023-10)
From bbd0325c104204b5daa2dd5432835594a80704be Mon Sep 17 00:00:00 2001
From: John Bampton
Date: Sat, 7 Oct 2023 00:40:27 +1000
Subject: [PATCH 6/7] Add pull request labeler

---
 .github/labeler.yml           | 23 +++++++++++++++++++++++
 .github/workflows/labeler.yml | 15 +++++++++++++++
 2 files changed, 38 insertions(+)
 create mode 100644 .github/labeler.yml
 create mode 100644 .github/workflows/labeler.yml

diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 0000000..0c9b183
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,23 @@
+repo:
+  - '*'
+
+github:
+  - .github/**/*
+
+application:
+  - application/**/*
+
+docs:
+  - docs/**/*
+
+extensions:
+  - extensions/**/*
+
+frontend:
+  - frontend/**/*
+
+scripts:
+  - scripts/**/*
+
+tests:
+  - tests/**/*
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
new file mode 100644
index 0000000..f85abb1
--- /dev/null
+++ b/.github/workflows/labeler.yml
@@ -0,0 +1,15 @@
+# https://github.com/actions/labeler
+name: Pull Request Labeler
+on:
+  - pull_request_target
+jobs:
+  triage:
+    permissions:
+      contents: read
+      pull-requests: write
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/labeler@v4
+        with:
+          repo-token: "${{ secrets.GITHUB_TOKEN }}"
+          sync-labels: true

From 17edaa0e1fccf2c1cda0fe3da086fc1426e402b3 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 16:05:10 +0100
Subject: [PATCH 7/7] Update faiss.py

---
 application/vectorstore/faiss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/application/vectorstore/faiss.py b/application/vectorstore/faiss.py
index 5c5cee7..217b045 100644
--- a/application/vectorstore/faiss.py
+++ b/application/vectorstore/faiss.py
@@ -1,5 +1,5 @@
 from application.vectorstore.base import BaseVectorStore
-from langchain import FAISS
+from langchain.vectorstores import FAISS
 from application.core.settings import settings
 
 class FaissStore(BaseVectorStore):
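A minimal smoke test for the corrected import path in the last patch, assuming a faiss build (e.g. `faiss-cpu`) is installed; `FakeEmbeddings` is langchain's stand-in used here instead of the real embedding model that `FaissStore` wires in:

```python
from langchain.vectorstores import FAISS
from langchain.embeddings import FakeEmbeddings

# Build a tiny in-memory index and query it. FakeEmbeddings produces random
# vectors, so this only verifies that the import and wiring are sound, not
# retrieval quality.
embeddings = FakeEmbeddings(size=128)
store = FAISS.from_texts(["DocsGPT answers questions about your docs."], embeddings)
print(store.similarity_search("docs", k=1)[0].page_content)
```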