|
|
|
@ -54,7 +54,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
|
|
|
|
|
def test_gen(self):
|
|
|
|
|
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
|
|
|
|
|
return_value=self.response) as mock_invoke_endpoint:
|
|
|
|
|
output = self.sagemaker.gen(None, None, self.messages)
|
|
|
|
|
output = self.sagemaker.gen(None, self.messages)
|
|
|
|
|
mock_invoke_endpoint.assert_called_once_with(
|
|
|
|
|
EndpointName=self.sagemaker.endpoint,
|
|
|
|
|
ContentType='application/json',
|
|
|
|
@ -66,7 +66,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
|
|
|
|
|
def test_gen_stream(self):
|
|
|
|
|
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
|
|
|
|
|
return_value=self.response) as mock_invoke_endpoint:
|
|
|
|
|
output = list(self.sagemaker.gen_stream(None, None, self.messages))
|
|
|
|
|
output = list(self.sagemaker.gen_stream(None, self.messages))
|
|
|
|
|
mock_invoke_endpoint.assert_called_once_with(
|
|
|
|
|
EndpointName=self.sagemaker.endpoint,
|
|
|
|
|
ContentType='application/json',
|
|
|
|
|