mirror of
https://github.com/arc53/DocsGPT
synced 2024-11-09 19:10:53 +00:00
33 lines
1.2 KiB
Python
33 lines
1.2 KiB
Python
import unittest
|
|
from unittest.mock import patch
|
|
from application.llm.openai import OpenAILLM
|
|
|
|
class TestOpenAILLM(unittest.TestCase):
    """Unit tests for ``OpenAILLM``.

    The generation tests patch ``openai.ChatCompletion.create`` as seen from
    ``application.llm.openai`` and feed it a canned payload instead.
    """

    def setUp(self):
        """Create a fresh ``OpenAILLM`` from a dummy API key before each test."""
        self.api_key = "test_api_key"
        self.llm = OpenAILLM(self.api_key)

    def test_init(self):
        # The instance must expose the key it was constructed with.
        self.assertEqual(self.llm.api_key, self.api_key)

    @patch('application.llm.openai.openai.ChatCompletion.create')
    def test_gen(self, mock_create):
        # Canned non-streaming payload: a single choice carrying the reply text.
        canned_completion = {
            "choices": [{"message": {"content": "test_response"}}],
        }
        mock_create.return_value = canned_completion

        reply = self.llm.gen("test_model", "test_engine", ["test_message"])

        self.assertEqual(reply, "test_response")

    @patch('application.llm.openai.openai.ChatCompletion.create')
    def test_gen_stream(self, mock_create):
        # Canned streaming payload: a one-element list whose only chunk holds
        # the text under "delta".
        canned_chunks = [
            {"choices": [{"delta": {"content": "test_response"}}]},
        ]
        mock_create.return_value = canned_chunks

        stream = self.llm.gen_stream("test_model", "test_engine", ["test_message"])
        collected = list(stream)

        self.assertEqual(collected, ["test_response"])
|