You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
DocsGPT/tests/llm/test_sagemaker.py

96 lines
3.6 KiB
Python

# FILEPATH: /path/to/test_sagemaker.py
import json
import unittest
from unittest.mock import MagicMock, patch
from application.llm.sagemaker import SagemakerAPILLM, LineIterator
class TestSagemakerAPILLM(unittest.TestCase):
    """Tests for ``SagemakerAPILLM.gen`` and ``SagemakerAPILLM.gen_stream``.

    Each test patches the relevant invoke method on ``self.sagemaker.runtime``
    and checks both the request passed to it and the value produced from the
    mocked response.
    """

    @staticmethod
    def _build_payload(prompt, *, stream, max_new_tokens):
        """Return the request payload the tests expect for ``prompt``.

        Key order is significant: the payload is serialised with
        ``json.dumps`` and compared byte-for-byte with the ``Body`` argument
        passed to the endpoint.
        """
        return {
            "inputs": prompt,
            "stream": stream,
            "parameters": {
                "do_sample": True,
                "temperature": 0.1,
                "max_new_tokens": max_new_tokens,
                "repetition_penalty": 1.03,
                "stop": ["</s>", "###"],
            },
        }

    def setUp(self):
        """Create the object under test and the expected request/response fixtures."""
        self.sagemaker = SagemakerAPILLM()
        self.context = "This is the context"
        self.user_question = "What is the answer?"
        # The expected prompt below uses only the first message (context) and
        # the last message (question); the middle one is filler.
        self.messages = [
            {"content": self.context},
            {"content": "Some other message"},
            {"content": self.user_question},
        ]
        self.prompt = f"### Instruction \n {self.user_question} \n ### Context \n {self.context} \n ### Answer \n"

        # Non-streaming and streaming requests differ only in the "stream"
        # flag and the token budget.
        self.payload = self._build_payload(
            self.prompt, stream=False, max_new_tokens=30
        )
        self.payload_stream = self._build_payload(
            self.prompt, stream=True, max_new_tokens=512
        )
        self.body_bytes = json.dumps(self.payload).encode('utf-8')
        self.body_bytes_stream = json.dumps(self.payload_stream).encode('utf-8')

        self.response = {"Body": MagicMock()}
        # The mocked completion starts with the prompt so that the expected
        # value in test_gen (the text after the first len(prompt) characters)
        # is non-empty. A completion shorter than the prompt would make that
        # slice "" and the assertion vacuous.
        self.generated_answer = "This is the generated text"
        self.result = [
            {"generated_text": self.prompt + self.generated_answer}
        ]
        self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result)

    def test_gen(self):
        """``gen`` sends the non-streaming payload and returns the text after the prompt."""
        with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
                          return_value=self.response) as mock_invoke_endpoint:
            output = self.sagemaker.gen(None, self.messages)
            mock_invoke_endpoint.assert_called_once_with(
                EndpointName=self.sagemaker.endpoint,
                ContentType='application/json',
                Body=self.body_bytes
            )
            expected = self.result[0]['generated_text'][len(self.prompt):]
            # Guard against the fixture regressing to a vacuous (empty) expectation.
            self.assertEqual(expected, self.generated_answer)
            self.assertEqual(output, expected)

    def test_gen_stream(self):
        """``gen_stream`` sends the streaming payload; the mocked response yields no chunks."""
        with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
                          return_value=self.response) as mock_invoke_endpoint:
            output = list(self.sagemaker.gen_stream(None, self.messages))
            mock_invoke_endpoint.assert_called_once_with(
                EndpointName=self.sagemaker.endpoint,
                ContentType='application/json',
                Body=self.body_bytes_stream
            )
            # NOTE(review): presumably empty because the mocked "Body" is a bare
            # MagicMock, which iterates as an empty sequence - confirm against
            # gen_stream's handling of the response.
            self.assertEqual(output, [])
class TestLineIterator(unittest.TestCase):
    """Tests for ``LineIterator`` over a list of ``PayloadPart`` events."""

    def setUp(self):
        """Build a three-event stream, each event holding one newline-terminated JSON line."""
        self.stream = [
            {'PayloadPart': {'Bytes': b'{"outputs": [" a"]}\n'}},
            {'PayloadPart': {'Bytes': b'{"outputs": [" challenging"]}\n'}},
            {'PayloadPart': {'Bytes': b'{"outputs": [" problem"]}\n'}},
        ]
        self.line_iterator = LineIterator(self.stream)

    def test_iter(self):
        """``iter()`` on the iterator returns the very same object."""
        # The iterator protocol requires identity, not mere equality, so use
        # assertIs; assertEqual would defer to any custom __eq__.
        self.assertIs(iter(self.line_iterator), self.line_iterator)

    def test_next(self):
        """``next()`` yields each line in order with the trailing newline removed."""
        self.assertEqual(next(self.line_iterator), b'{"outputs": [" a"]}')
        self.assertEqual(next(self.line_iterator), b'{"outputs": [" challenging"]}')
        self.assertEqual(next(self.line_iterator), b'{"outputs": [" problem"]}')
# Run the test suite only when this file is executed directly, not on import.
if __name__ == '__main__':
    unittest.main()