From 4f5e363452862406034b61a2cf80f6d5b5f0966a Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 01:52:29 +0100
Subject: [PATCH 1/7] sagemaker streaming
---
application/core/settings.py | 6 ++
application/llm/sagemaker.py | 134 ++++++++++++++++++++++++++++++++---
2 files changed, 130 insertions(+), 10 deletions(-)
diff --git a/application/core/settings.py b/application/core/settings.py
index a05fd00..116735a 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -32,6 +32,12 @@ class Settings(BaseSettings):
ELASTIC_URL: str = None # url for elasticsearch
ELASTIC_INDEX: str = "docsgpt" # index name for elasticsearch
+ # SageMaker config
+ SAGEMAKER_ENDPOINT: str = None # SageMaker endpoint name
+ SAGEMAKER_REGION: str = None # SageMaker region name
+ SAGEMAKER_ACCESS_KEY: str = None # SageMaker access key
+ SAGEMAKER_SECRET_KEY: str = None # SageMaker secret key
+
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index 9ef5d0a..55617cc 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -2,26 +2,140 @@ from application.llm.base import BaseLLM
from application.core.settings import settings
import requests
import json
+import io
+import json
+
+
+
+class LineIterator:
+ """
+ A helper class for parsing the byte stream input.
+
+ The output of the model will be in the following format:
+ ```
+ b'{"outputs": [" a"]}\n'
+ b'{"outputs": [" challenging"]}\n'
+ b'{"outputs": [" problem"]}\n'
+ ...
+ ```
+
+    Usually each PayloadPart event from the event stream contains a byte array
+    with a complete JSON object, but this is not guaranteed and a JSON object
+    may be split across PayloadPart events. For example:
+ ```
+ {'PayloadPart': {'Bytes': b'{"outputs": '}}
+ {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+ ```
+
+    This class accounts for this by buffering the incoming bytes and exposing
+    complete lines (each ending with a '\n' character) one at a time via
+    iteration. It tracks the position of the last byte read to ensure that
+    previous bytes are not exposed again.
+ """
+
+ def __init__(self, stream):
+ self.byte_iterator = iter(stream)
+ self.buffer = io.BytesIO()
+ self.read_pos = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ while True:
+ self.buffer.seek(self.read_pos)
+ line = self.buffer.readline()
+ if line and line[-1] == ord('\n'):
+ self.read_pos += len(line)
+ return line[:-1]
+ try:
+ chunk = next(self.byte_iterator)
+ except StopIteration:
+ if self.read_pos < self.buffer.getbuffer().nbytes:
+ continue
+ raise
+ if 'PayloadPart' not in chunk:
+                print('Unknown event type: ' + str(chunk))
+ continue
+ self.buffer.seek(0, io.SEEK_END)
+ self.buffer.write(chunk['PayloadPart']['Bytes'])
class SagemakerAPILLM(BaseLLM):
def __init__(self, *args, **kwargs):
- self.url = settings.SAGEMAKER_API_URL
+ import boto3
+ runtime = boto3.client(
+ 'runtime.sagemaker',
+ aws_access_key_id=settings.SAGEMAKER_ACCESS_KEY,
+ aws_secret_access_key=settings.SAGEMAKER_SECRET_KEY,
+ region_name=settings.SAGEMAKER_REGION
+ )
+
+
+ self.endpoint = settings.SAGEMAKER_ENDPOINT
+ self.runtime = runtime
+
def gen(self, model, engine, messages, stream=False, **kwargs):
context = messages[0]['content']
user_question = messages[-1]['content']
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
- response = requests.post(
- url=self.url,
- headers={
- "Content-Type": "application/json; charset=utf-8",
- },
- data=json.dumps({"input": prompt})
- )
+ # Construct payload for endpoint
+ payload = {
+ "inputs": prompt,
+ "stream": False,
+ "parameters": {
+ "do_sample": True,
+ "temperature": 0.1,
+ "max_new_tokens": 30,
+ "repetition_penalty": 1.03,
+                "stop": ["</s>", "###"]
+ }
+ }
+ body_bytes = json.dumps(payload).encode('utf-8')
- return response.json()['answer']
+ # Invoke the endpoint
+ response = self.runtime.invoke_endpoint(EndpointName=self.endpoint,
+ ContentType='application/json',
+ Body=body_bytes)
+ result = json.loads(response['Body'].read().decode())
+ import sys
+ print(result[0]['generated_text'], file=sys.stderr)
+ return result[0]['generated_text'][len(prompt):]
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
- raise NotImplementedError("Sagemaker does not support streaming")
\ No newline at end of file
+ context = messages[0]['content']
+ user_question = messages[-1]['content']
+ prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
+
+
+ # Construct payload for endpoint
+ payload = {
+ "inputs": prompt,
+ "stream": True,
+ "parameters": {
+ "do_sample": True,
+ "temperature": 0.1,
+ "max_new_tokens": 512,
+ "repetition_penalty": 1.03,
+                "stop": ["</s>", "###"]
+ }
+ }
+ body_bytes = json.dumps(payload).encode('utf-8')
+
+ # Invoke the endpoint
+ response = self.runtime.invoke_endpoint_with_response_stream(EndpointName=self.endpoint,
+ ContentType='application/json',
+ Body=body_bytes)
+ #result = json.loads(response['Body'].read().decode())
+ event_stream = response['Body']
+ start_json = b'{'
+ for line in LineIterator(event_stream):
+ if line != b'' and start_json in line:
+ #print(line)
+ data = json.loads(line[line.find(start_json):].decode('utf-8'))
+                if data['token']['text'] not in ["</s>", "###"]:
+ print(data['token']['text'],end='')
+ yield data['token']['text']
\ No newline at end of file
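For context on the streaming change above, the sketch below (not part of the patch) shows how the new LineIterator reassembles JSON lines that arrive split across PayloadPart events, as its docstring describes. The `fake_event_stream` list is purely illustrative and stands in for the `Body` of an `invoke_endpoint_with_response_stream` response.

```python
import json

from application.llm.sagemaker import LineIterator

# Illustrative stand-in for the SageMaker event stream; the first JSON line
# is deliberately split across two PayloadPart events.
fake_event_stream = [
    {'PayloadPart': {'Bytes': b'{"outputs": '}},
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}},
    {'PayloadPart': {'Bytes': b'{"outputs": [" solved"]}\n'}},
]

for line in LineIterator(fake_event_stream):
    # Each yielded line is one complete JSON object, without the trailing '\n'.
    print(json.loads(line.decode('utf-8')))
# {'outputs': [' problem']}
# {'outputs': [' solved']}
```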
From 495728593f963f33b9f6f1378216faa9c49b74fc Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 13:22:51 +0100
Subject: [PATCH 2/7] sagemaker fixes + test
---
application/llm/sagemaker.py | 2 -
tests/llm/test_sagemaker.py | 96 ++++++++++++++++++++++++++++++++++++
2 files changed, 96 insertions(+), 2 deletions(-)
create mode 100644 tests/llm/test_sagemaker.py
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index 55617cc..ed5fc67 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -1,9 +1,7 @@
from application.llm.base import BaseLLM
from application.core.settings import settings
-import requests
import json
import io
-import json
diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py
new file mode 100644
index 0000000..f8d02d8
--- /dev/null
+++ b/tests/llm/test_sagemaker.py
@@ -0,0 +1,96 @@
+# FILEPATH: /path/to/test_sagemaker.py
+
+import json
+import unittest
+from unittest.mock import MagicMock, patch
+from application.llm.sagemaker import SagemakerAPILLM, LineIterator
+
+class TestSagemakerAPILLM(unittest.TestCase):
+
+ def setUp(self):
+ self.sagemaker = SagemakerAPILLM()
+ self.context = "This is the context"
+ self.user_question = "What is the answer?"
+ self.messages = [
+ {"content": self.context},
+ {"content": "Some other message"},
+ {"content": self.user_question}
+ ]
+ self.prompt = f"### Instruction \n {self.user_question} \n ### Context \n {self.context} \n ### Answer \n"
+ self.payload = {
+ "inputs": self.prompt,
+ "stream": False,
+ "parameters": {
+ "do_sample": True,
+ "temperature": 0.1,
+ "max_new_tokens": 30,
+ "repetition_penalty": 1.03,
+                "stop": ["</s>", "###"]
+ }
+ }
+ self.payload_stream = {
+ "inputs": self.prompt,
+ "stream": True,
+ "parameters": {
+ "do_sample": True,
+ "temperature": 0.1,
+ "max_new_tokens": 512,
+ "repetition_penalty": 1.03,
+                "stop": ["</s>", "###"]
+ }
+ }
+ self.body_bytes = json.dumps(self.payload).encode('utf-8')
+ self.body_bytes_stream = json.dumps(self.payload_stream).encode('utf-8')
+ self.response = {
+ "Body": MagicMock()
+ }
+ self.result = [
+ {
+ "generated_text": "This is the generated text"
+ }
+ ]
+ self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result)
+
+ def test_gen(self):
+ with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
+ return_value=self.response) as mock_invoke_endpoint:
+ output = self.sagemaker.gen(None, None, self.messages)
+ mock_invoke_endpoint.assert_called_once_with(
+ EndpointName=self.sagemaker.endpoint,
+ ContentType='application/json',
+ Body=self.body_bytes
+ )
+ self.assertEqual(output,
+ self.result[0]['generated_text'][len(self.prompt):])
+
+ def test_gen_stream(self):
+ with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
+ return_value=self.response) as mock_invoke_endpoint:
+ output = list(self.sagemaker.gen_stream(None, None, self.messages))
+ mock_invoke_endpoint.assert_called_once_with(
+ EndpointName=self.sagemaker.endpoint,
+ ContentType='application/json',
+ Body=self.body_bytes_stream
+ )
+ self.assertEqual(output, [])
+
+class TestLineIterator(unittest.TestCase):
+
+ def setUp(self):
+ self.stream = [
+ {'PayloadPart': {'Bytes': b'{"outputs": [" a"]}\n'}},
+ {'PayloadPart': {'Bytes': b'{"outputs": [" challenging"]}\n'}},
+ {'PayloadPart': {'Bytes': b'{"outputs": [" problem"]}\n'}}
+ ]
+ self.line_iterator = LineIterator(self.stream)
+
+ def test_iter(self):
+ self.assertEqual(iter(self.line_iterator), self.line_iterator)
+
+ def test_next(self):
+ self.assertEqual(next(self.line_iterator), b'{"outputs": [" a"]}')
+ self.assertEqual(next(self.line_iterator), b'{"outputs": [" challenging"]}')
+ self.assertEqual(next(self.line_iterator), b'{"outputs": [" problem"]}')
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
From b3a0368b952c58f0cb77302dbd97809ad2398aa9 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 13:54:03 +0100
Subject: [PATCH 3/7] Update app.py
---
application/app.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/application/app.py b/application/app.py
index e97febc..ae61997 100644
--- a/application/app.py
+++ b/application/app.py
@@ -27,7 +27,10 @@ celery.config_from_object("application.celeryconfig")
@app.route("/")
def home():
- return redirect('http://localhost:5173') if request.remote_addr in ('0.0.0.0', '127.0.0.1', 'localhost', '172.18.0.1') else 'Welcome to DocsGPT Backend!'
+ if request.remote_addr in ('0.0.0.0', '127.0.0.1', 'localhost', '172.18.0.1'):
+ return redirect('http://localhost:5173')
+ else:
+ return 'Welcome to DocsGPT Backend!'
@app.after_request
def after_request(response):
From 43a22f84d9c934ee1aa11481acf55eb70e2b1bd4 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 14:43:05 +0100
Subject: [PATCH 4/7] Update sagemaker.py
---
application/llm/sagemaker.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py
index ed5fc67..84ae09a 100644
--- a/application/llm/sagemaker.py
+++ b/application/llm/sagemaker.py
@@ -64,9 +64,9 @@ class SagemakerAPILLM(BaseLLM):
import boto3
runtime = boto3.client(
'runtime.sagemaker',
- aws_access_key_id=settings.SAGEMAKER_ACCESS_KEY,
- aws_secret_access_key=settings.SAGEMAKER_SECRET_KEY,
- region_name=settings.SAGEMAKER_REGION
+ aws_access_key_id='xxx',
+ aws_secret_access_key='xxx',
+ region_name='us-west-2'
)
From 316c276545acfb85978788c0a4b913c8b90cc4e4 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 15:29:17 +0100
Subject: [PATCH 5/7] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 98958c8..edeb7b6 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ Say goodbye to time-consuming manual searches, and let DocsGPT
-### Production Support/ Help for companies:
+### Production Support / Help for companies:
We're eager to provide personalized assistance when deploying your DocsGPT to a live environment.
- [Schedule Demo 👋](https://cal.com/arc53/docsgpt-demo-b2b?date=2023-10-04&month=2023-10)
From bbd0325c104204b5daa2dd5432835594a80704be Mon Sep 17 00:00:00 2001
From: John Bampton
Date: Sat, 7 Oct 2023 00:40:27 +1000
Subject: [PATCH 6/7] Add pull request labeler
---
.github/labeler.yml | 23 +++++++++++++++++++++++
.github/workflows/labeler.yml | 15 +++++++++++++++
2 files changed, 38 insertions(+)
create mode 100644 .github/labeler.yml
create mode 100644 .github/workflows/labeler.yml
diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 0000000..0c9b183
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,23 @@
+repo:
+ - '*'
+
+github:
+ - .github/**/*
+
+application:
+ - application/**/*
+
+docs:
+ - docs/**/*
+
+extensions:
+ - extensions/**/*
+
+frontend:
+ - frontend/**/*
+
+scripts:
+ - scripts/**/*
+
+tests:
+ - tests/**/*
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
new file mode 100644
index 0000000..f85abb1
--- /dev/null
+++ b/.github/workflows/labeler.yml
@@ -0,0 +1,15 @@
+# https://github.com/actions/labeler
+name: Pull Request Labeler
+on:
+ - pull_request_target
+jobs:
+ triage:
+ permissions:
+ contents: read
+ pull-requests: write
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/labeler@v4
+ with:
+ repo-token: "${{ secrets.GITHUB_TOKEN }}"
+ sync-labels: true
From 17edaa0e1fccf2c1cda0fe3da086fc1426e402b3 Mon Sep 17 00:00:00 2001
From: Alex
Date: Fri, 6 Oct 2023 16:05:10 +0100
Subject: [PATCH 7/7] Update faiss.py
---
application/vectorstore/faiss.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/application/vectorstore/faiss.py b/application/vectorstore/faiss.py
index 5c5cee7..217b045 100644
--- a/application/vectorstore/faiss.py
+++ b/application/vectorstore/faiss.py
@@ -1,5 +1,5 @@
from application.vectorstore.base import BaseVectorStore
-from langchain import FAISS
+from langchain.vectorstores import FAISS
from application.core.settings import settings
class FaissStore(BaseVectorStore):
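For reference, a minimal sketch of the updated import path in use; the index folder, embeddings class, and query below are illustrative assumptions, not taken from the patch.

```python
# Sketch only: the folder path, embeddings model, and query are assumptions
# for illustration, not values from the repository.
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings()
store = FAISS.load_local("path/to/index", embeddings)
docs = store.similarity_search("What is DocsGPT?", k=2)
print([d.page_content for d in docs])
```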