From 4f5e363452862406034b61a2cf80f6d5b5f0966a Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 6 Oct 2023 01:52:29 +0100 Subject: [PATCH 1/2] sagemaker streaming --- application/core/settings.py | 6 ++ application/llm/sagemaker.py | 134 ++++++++++++++++++++++++++++++++--- 2 files changed, 130 insertions(+), 10 deletions(-) diff --git a/application/core/settings.py b/application/core/settings.py index a05fd00..116735a 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -32,6 +32,12 @@ class Settings(BaseSettings): ELASTIC_URL: str = None # url for elasticsearch ELASTIC_INDEX: str = "docsgpt" # index name for elasticsearch + # SageMaker config + SAGEMAKER_ENDPOINT: str = None # SageMaker endpoint name + SAGEMAKER_REGION: str = None # SageMaker region name + SAGEMAKER_ACCESS_KEY: str = None # SageMaker access key + SAGEMAKER_SECRET_KEY: str = None # SageMaker secret key + path = Path(__file__).parent.parent.absolute() settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8") diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 9ef5d0a..55617cc 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -2,26 +2,140 @@ from application.llm.base import BaseLLM from application.core.settings import settings import requests import json +import io +import json + + + +class LineIterator: + """ + A helper class for parsing the byte stream input. + + The output of the model will be in the following format: + ``` + b'{"outputs": [" a"]}\n' + b'{"outputs": [" challenging"]}\n' + b'{"outputs": [" problem"]}\n' + ... + ``` + + While usually each PayloadPart event from the event stream will contain a byte array + with a full json, this is not guaranteed and some of the json objects may be split across + PayloadPart events. 
For example: + ``` + {'PayloadPart': {'Bytes': b'{"outputs": '}} + {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} + ``` + + This class accounts for this by concatenating bytes written via the 'write' function + and then exposing a method which will return lines (ending with a '\n' character) within + the buffer via the 'scan_lines' function. It maintains the position of the last read + position to ensure that previous bytes are not exposed again. + """ + + def __init__(self, stream): + self.byte_iterator = iter(stream) + self.buffer = io.BytesIO() + self.read_pos = 0 + + def __iter__(self): + return self + + def __next__(self): + while True: + self.buffer.seek(self.read_pos) + line = self.buffer.readline() + if line and line[-1] == ord('\n'): + self.read_pos += len(line) + return line[:-1] + try: + chunk = next(self.byte_iterator) + except StopIteration: + if self.read_pos < self.buffer.getbuffer().nbytes: + continue + raise + if 'PayloadPart' not in chunk: + print('Unknown event type:' + chunk) + continue + self.buffer.seek(0, io.SEEK_END) + self.buffer.write(chunk['PayloadPart']['Bytes']) class SagemakerAPILLM(BaseLLM): def __init__(self, *args, **kwargs): - self.url = settings.SAGEMAKER_API_URL + import boto3 + runtime = boto3.client( + 'runtime.sagemaker', + aws_access_key_id=settings.SAGEMAKER_ACCESS_KEY, + aws_secret_access_key=settings.SAGEMAKER_SECRET_KEY, + region_name=settings.SAGEMAKER_REGION + ) + + + self.endpoint = settings.SAGEMAKER_ENDPOINT + self.runtime = runtime + def gen(self, model, engine, messages, stream=False, **kwargs): context = messages[0]['content'] user_question = messages[-1]['content'] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" + - response = requests.post( - url=self.url, - headers={ - "Content-Type": "application/json; charset=utf-8", - }, - data=json.dumps({"input": prompt}) - ) + # Construct payload for endpoint + payload = { + "inputs": prompt, + "stream": False, + "parameters": { 
+ "do_sample": True, + "temperature": 0.1, + "max_new_tokens": 30, + "repetition_penalty": 1.03, + "stop": ["</s>", "###"] + } + } + body_bytes = json.dumps(payload).encode('utf-8') - return response.json()['answer'] + # Invoke the endpoint + response = self.runtime.invoke_endpoint(EndpointName=self.endpoint, + ContentType='application/json', + Body=body_bytes) + result = json.loads(response['Body'].read().decode()) + import sys + print(result[0]['generated_text'], file=sys.stderr) + return result[0]['generated_text'][len(prompt):] def gen_stream(self, model, engine, messages, stream=True, **kwargs): - raise NotImplementedError("Sagemaker does not support streaming") \ No newline at end of file + context = messages[0]['content'] + user_question = messages[-1]['content'] + prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" + + + # Construct payload for endpoint + payload = { + "inputs": prompt, + "stream": True, + "parameters": { + "do_sample": True, + "temperature": 0.1, + "max_new_tokens": 512, + "repetition_penalty": 1.03, + "stop": ["</s>", "###"] + } + } + body_bytes = json.dumps(payload).encode('utf-8') + + # Invoke the endpoint + response = self.runtime.invoke_endpoint_with_response_stream(EndpointName=self.endpoint, + ContentType='application/json', + Body=body_bytes) + #result = json.loads(response['Body'].read().decode()) + event_stream = response['Body'] + start_json = b'{' + for line in LineIterator(event_stream): + if line != b'' and start_json in line: + #print(line) + data = json.loads(line[line.find(start_json):].decode('utf-8')) + if data['token']['text'] not in ["</s>", "###"]: + print(data['token']['text'],end='') + yield data['token']['text'] \ No newline at end of file From 495728593f963f33b9f6f1378216faa9c49b74fc Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 6 Oct 2023 13:22:51 +0100 Subject: [PATCH 2/2] sagemaker fixes + test --- application/llm/sagemaker.py | 2 - tests/llm/test_sagemaker.py | 96 
++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 tests/llm/test_sagemaker.py diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 55617cc..ed5fc67 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -1,9 +1,7 @@ from application.llm.base import BaseLLM from application.core.settings import settings -import requests import json import io -import json diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py new file mode 100644 index 0000000..f8d02d8 --- /dev/null +++ b/tests/llm/test_sagemaker.py @@ -0,0 +1,96 @@ +# FILEPATH: /path/to/test_sagemaker.py + +import json +import unittest +from unittest.mock import MagicMock, patch +from application.llm.sagemaker import SagemakerAPILLM, LineIterator + +class TestSagemakerAPILLM(unittest.TestCase): + + def setUp(self): + self.sagemaker = SagemakerAPILLM() + self.context = "This is the context" + self.user_question = "What is the answer?" 
+ self.messages = [ + {"content": self.context}, + {"content": "Some other message"}, + {"content": self.user_question} + ] + self.prompt = f"### Instruction \n {self.user_question} \n ### Context \n {self.context} \n ### Answer \n" + self.payload = { + "inputs": self.prompt, + "stream": False, + "parameters": { + "do_sample": True, + "temperature": 0.1, + "max_new_tokens": 30, + "repetition_penalty": 1.03, + "stop": ["</s>", "###"] + } + } + self.payload_stream = { + "inputs": self.prompt, + "stream": True, + "parameters": { + "do_sample": True, + "temperature": 0.1, + "max_new_tokens": 512, + "repetition_penalty": 1.03, + "stop": ["</s>", "###"] + } + } + self.body_bytes = json.dumps(self.payload).encode('utf-8') + self.body_bytes_stream = json.dumps(self.payload_stream).encode('utf-8') + self.response = { + "Body": MagicMock() + } + self.result = [ + { + "generated_text": "This is the generated text" + } + ] + self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result) + + def test_gen(self): + with patch.object(self.sagemaker.runtime, 'invoke_endpoint', + return_value=self.response) as mock_invoke_endpoint: + output = self.sagemaker.gen(None, None, self.messages) + mock_invoke_endpoint.assert_called_once_with( + EndpointName=self.sagemaker.endpoint, + ContentType='application/json', + Body=self.body_bytes + ) + self.assertEqual(output, + self.result[0]['generated_text'][len(self.prompt):]) + + def test_gen_stream(self): + with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', + return_value=self.response) as mock_invoke_endpoint: + output = list(self.sagemaker.gen_stream(None, None, self.messages)) + mock_invoke_endpoint.assert_called_once_with( + EndpointName=self.sagemaker.endpoint, + ContentType='application/json', + Body=self.body_bytes_stream + ) + self.assertEqual(output, []) + +class TestLineIterator(unittest.TestCase): + + def setUp(self): + self.stream = [ + {'PayloadPart': {'Bytes': b'{"outputs": [" 
a"]}\n'}}, + {'PayloadPart': {'Bytes': b'{"outputs": [" challenging"]}\n'}}, + {'PayloadPart': {'Bytes': b'{"outputs": [" problem"]}\n'}} + ] + self.line_iterator = LineIterator(self.stream) + + def test_iter(self): + self.assertEqual(iter(self.line_iterator), self.line_iterator) + + def test_next(self): + self.assertEqual(next(self.line_iterator), b'{"outputs": [" a"]}') + self.assertEqual(next(self.line_iterator), b'{"outputs": [" challenging"]}') + self.assertEqual(next(self.line_iterator), b'{"outputs": [" problem"]}') + +if __name__ == '__main__': + unittest.main() \ No newline at end of file