From 4f5e363452862406034b61a2cf80f6d5b5f0966a Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 01:52:29 +0100
Subject: [PATCH 1/2] sagemaker streaming
---
application/core/settings.py | 6 ++
application/llm/sagemaker.py | 134 ++++++++++++++++++++++++++++++++---
2 files changed, 130 insertions(+), 10 deletions(-)
diff --git a/application/core/settings.py b/application/core/settings.py
index a05fd00..116735a 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -32,6 +32,12 @@ class Settings(BaseSettings):
ELASTIC_URL: str = None # url for elasticsearch
ELASTIC_INDEX: str = "docsgpt" # index name for elasticsearch
+ # SageMaker config
+ SAGEMAKER_ENDPOINT: str = None # SageMaker endpoint name
+ SAGEMAKER_REGION: str = None # SageMaker region name
+ SAGEMAKER_ACCESS_KEY: str = None # SageMaker access key
+ SAGEMAKER_SECRET_KEY: str = None # SageMaker secret key
+
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index 9ef5d0a..55617cc 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -2,26 +2,140 @@ from application.llm.base import BaseLLM
from application.core.settings import settings
import requests
import json
+import io
+import json
+
+
+
+class LineIterator:
+ """
+ A helper class for parsing the byte stream input.
+
+ The output of the model will be in the following format:
+ ```
+ b'{"outputs": [" a"]}\n'
+ b'{"outputs": [" challenging"]}\n'
+ b'{"outputs": [" problem"]}\n'
+ ...
+ ```
+
+ While usually each PayloadPart event from the event stream will contain a byte array
+ with a full json, this is not guaranteed and some of the json objects may be split across
+ PayloadPart events. For example:
+ ```
+ {'PayloadPart': {'Bytes': b'{"outputs": '}}
+ {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+ ```
+
+ This class accounts for this by concatenating bytes written via the 'write' function
+ and then exposing a method which will return lines (ending with a '\n' character) within
+ the buffer via the 'scan_lines' function. It maintains the position of the last read
+ position to ensure that previous bytes are not exposed again.
+ """
+
+ def __init__(self, stream):
+ self.byte_iterator = iter(stream)
+ self.buffer = io.BytesIO()
+ self.read_pos = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ while True:
+ self.buffer.seek(self.read_pos)
+ line = self.buffer.readline()
+ if line and line[-1] == ord('\n'):
+ self.read_pos += len(line)
+ return line[:-1]
+ try:
+ chunk = next(self.byte_iterator)
+ except StopIteration:
+ if self.read_pos < self.buffer.getbuffer().nbytes:
+ continue
+ raise
+ if 'PayloadPart' not in chunk:
+ print('Unknown event type:' + chunk)
+ continue
+ self.buffer.seek(0, io.SEEK_END)
+ self.buffer.write(chunk['PayloadPart']['Bytes'])
class SagemakerAPILLM(BaseLLM):
def __init__(self, *args, **kwargs):
- self.url = settings.SAGEMAKER_API_URL
+ import boto3
+ runtime = boto3.client(
+ 'runtime.sagemaker',
+ aws_access_key_id=settings.SAGEMAKER_ACCESS_KEY,
+ aws_secret_access_key=settings.SAGEMAKER_SECRET_KEY,
+ region_name=settings.SAGEMAKER_REGION
+ )
+
+
+ self.endpoint = settings.SAGEMAKER_ENDPOINT
+ self.runtime = runtime
+
def gen(self, model, engine, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
- response = requests.post(
- url=self.url,
- headers={
- "Content-Type": "application/json; charset=utf-8",
- },
- data=json.dumps({"input": prompt})
- )
+ # Construct payload for endpoint
+ payload = {
+ "inputs": prompt,
+ "stream": False,
+ "parameters": {
+ "do_sample": True,
+ "temperature": 0.1,
+ "max_new_tokens": 30,
+ "repetition_penalty": 1.03,
+ "stop": ["", "###"]
+ }
+ }
+ body_bytes = json.dumps(payload).encode('utf-8')
- return response.json()['answer']
+ # Invoke the endpoint
+ response = self.runtime.invoke_endpoint(EndpointName=self.endpoint,
+ ContentType='application/json',
+ Body=body_bytes)
+ result = json.loads(response['Body'].read().decode())
+ import sys
+ print(result[0]['generated_text'], file=sys.stderr)
+ return result[0]['generated_text'][len(prompt):]
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
- raise NotImplementedError("Sagemaker does not support streaming")
\ No newline at end of file
+ context = messages[0]['content']
+ user_question = messages[-1]['content']
+ prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
+
+ # Construct payload for endpoint
+ payload = {
+ "inputs": prompt,
+ "stream": True,
+ "parameters": {
+ "do_sample": True,
+ "temperature": 0.1,
+ "max_new_tokens": 512,
+ "repetition_penalty": 1.03,
+ "stop": ["", "###"]
+ }
+ }
+ body_bytes = json.dumps(payload).encode('utf-8')
+
+ # Invoke the endpoint
+ response = self.runtime.invoke_endpoint_with_response_stream(EndpointName=self.endpoint,
+ ContentType='application/json',
+ Body=body_bytes)
+ #result = json.loads(response['Body'].read().decode())
+ event_stream = response['Body']
+ start_json = b'{'
+ for line in LineIterator(event_stream):
+ if line != b'' and start_json in line:
+ #print(line)
+ data = json.loads(line[line.find(start_json):].decode('utf-8'))
+ if data['token']['text'] not in ["", "###"]:
+ print(data['token']['text'],end='')
+ yield data['token']['text']
\ No newline at end of file
From 495728593f963f33b9f6f1378216faa9c49b74fc Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 13:22:51 +0100
Subject: [PATCH 2/2] sagemaker fixes + test
---
application/llm/sagemaker.py | 2 -
tests/llm/test_sagemaker.py | 96 ++++++++++++++++++++++++++++++++++++
2 files changed, 96 insertions(+), 2 deletions(-)
create mode 100644 tests/llm/test_sagemaker.py
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index 55617cc..ed5fc67 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -1,9 +1,7 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
-import requests
import json
import io
-import json
diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py
new file mode 100644
index 0000000..f8d02d8
--- /dev/null
+++ b/tests/llm/test_sagemaker.py
@@ -0,0 +1,96 @@
+# FILEPATH: /path/to/test_sagemaker.py
+
+import json
+import unittest
+from unittest.mock import MagicMock, patch
+from application.llm.sagemaker import SagemakerAPILLM, LineIterator
+
+class TestSagemakerAPILLM(unittest.TestCase):
+
+ def setUp(self):
+ self.sagemaker = SagemakerAPILLM()
+ self.context = "This is the context"
+ self.user_question = "What is the answer?"
+ self.messages = [
+ {"content": self.context},
+ {"content": "Some other message"},
+ {"content": self.user_question}
+ ]
+ self.prompt = f"### Instruction \n {self.user_question} \n ### Context \n {self.context} \n ### Answer \n"
+ self.payload = {
+ "inputs": self.prompt,
+ "stream": False,
+ "parameters": {
+ "do_sample": True,
+ "temperature": 0.1,
+ "max_new_tokens": 30,
+ "repetition_penalty": 1.03,
+ "stop": ["", "###"]
+ }
+ }
+ self.payload_stream = {
+ "inputs": self.prompt,
+ "stream": True,
+ "parameters": {
+ "do_sample": True,
+ "temperature": 0.1,
+ "max_new_tokens": 512,
+ "repetition_penalty": 1.03,
+ "stop": ["", "###"]
+ }
+ }
+ self.body_bytes = json.dumps(self.payload).encode('utf-8')
+ self.body_bytes_stream = json.dumps(self.payload_stream).encode('utf-8')
+ self.response = {
+ "Body": MagicMock()
+ }
+ self.result = [
+ {
+ "generated_text": "This is the generated text"
+ }
+ ]
+ self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result)
+
+ def test_gen(self):
+ with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
+ return_value=self.response) as mock_invoke_endpoint:
+ output = self.sagemaker.gen(None, None, self.messages)
+ mock_invoke_endpoint.assert_called_once_with(
+ EndpointName=self.sagemaker.endpoint,
+ ContentType='application/json',
+ Body=self.body_bytes
+ )
+ self.assertEqual(output,
+ self.result[0]['generated_text'][len(self.prompt):])
+
+ def test_gen_stream(self):
+ with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
+ return_value=self.response) as mock_invoke_endpoint:
+ output = list(self.sagemaker.gen_stream(None, None, self.messages))
+ mock_invoke_endpoint.assert_called_once_with(
+ EndpointName=self.sagemaker.endpoint,
+ ContentType='application/json',
+ Body=self.body_bytes_stream
+ )
+ self.assertEqual(output, [])
+
+class TestLineIterator(unittest.TestCase):
+
+ def setUp(self):
+ self.stream = [
+ {'PayloadPart': {'Bytes': b'{"outputs": [" a"]}\n'}},
+ {'PayloadPart': {'Bytes': b'{"outputs": [" challenging"]}\n'}},
+ {'PayloadPart': {'Bytes': b'{"outputs": [" problem"]}\n'}}
+ ]
+ self.line_iterator = LineIterator(self.stream)
+
+ def test_iter(self):
+ self.assertEqual(iter(self.line_iterator), self.line_iterator)
+
+ def test_next(self):
+ self.assertEqual(next(self.line_iterator), b'{"outputs": [" a"]}')
+ self.assertEqual(next(self.line_iterator), b'{"outputs": [" challenging"]}')
+ self.assertEqual(next(self.line_iterator), b'{"outputs": [" problem"]}')
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file