diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py index f8d02d8..0602f59 100644 --- a/tests/llm/test_sagemaker.py +++ b/tests/llm/test_sagemaker.py @@ -54,7 +54,7 @@ class TestSagemakerAPILLM(unittest.TestCase): def test_gen(self): with patch.object(self.sagemaker.runtime, 'invoke_endpoint', return_value=self.response) as mock_invoke_endpoint: - output = self.sagemaker.gen(None, None, self.messages) + output = self.sagemaker.gen(None, self.messages) mock_invoke_endpoint.assert_called_once_with( EndpointName=self.sagemaker.endpoint, ContentType='application/json', @@ -66,7 +66,7 @@ class TestSagemakerAPILLM(unittest.TestCase): def test_gen_stream(self): with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', return_value=self.response) as mock_invoke_endpoint: - output = list(self.sagemaker.gen_stream(None, None, self.messages)) + output = list(self.sagemaker.gen_stream(None, self.messages)) mock_invoke_endpoint.assert_called_once_with( EndpointName=self.sagemaker.endpoint, ContentType='application/json',