mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-17 09:25:47 +00:00
62 lines
1.7 KiB
Python
from unittest import mock
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from imaginairy.http_app.app import app
|
|
|
|
# One synchronous test client wrapping ``app``; shared by every test in this module.
client = TestClient(app=app)
|
|
|
|
|
|
@pytest.fixture(name="red_b64")
def _red_b64():
    """Provide a canned base64-encoded PNG image as ``bytes``.

    The payload's IHDR chunk declares a 512x512, 1-bit palette image whose
    PLTE chunk holds red and white. Tests request it by the name ``red_b64``,
    e.g. as the canned output of a mocked image generator.
    """
    png_base64 = b"iVBORw0KGgoAAAANSUhEUgAAAgAAAAIAAQMAAADOtka5AAAABlBMVEX/AAD///9BHTQRAAAANklEQVR4nO3BAQEAAACCIP+vbkhAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8G4IAAAHSeInwAAAAAElFTkSuQmCC"
    return png_base64
|
|
|
|
|
|
@pytest.fixture()
def mock_generate_image_b64(monkeypatch, red_b64):
    """Replace the StableStudio route's image generator with a canned stub.

    Patches ``imaginairy.http_app.stablestudio.routes.generate_image_b64``
    with a ``MagicMock`` that always returns the ``red_b64`` fixture value.
    ``monkeypatch`` undoes the patch when the test finishes.

    Returns:
        The ``MagicMock`` installed in place of ``generate_image_b64``. Tests
        can use it to assert that the route actually called the generator
        (e.g. ``.called`` / ``.call_args``).
    """
    fake_generate = mock.MagicMock(return_value=red_b64)
    monkeypatch.setattr(
        "imaginairy.http_app.stablestudio.routes.generate_image_b64", fake_generate
    )
    # Hand the mock back so tests can inspect calls; previously this was None.
    return fake_generate
|
|
|
|
|
|
def test_generate_endpoint(mock_generate_image_b64, red_b64):
    """POST a minimal generation request and check the stubbed image comes back.

    The ``mock_generate_image_b64`` fixture patches the route's
    ``generate_image_b64`` to return ``red_b64``. Every returned image blob
    must therefore equal that payload decoded to text.

    This is a plain ``def`` test. It awaits nothing and drives the app through
    the synchronous ``TestClient``, which is meant for non-async tests.
    """
    test_input = {
        "input": {
            "prompts": [{"text": "A dog"}],
            "sampler": {"id": "ddim"},
            "height": 512,
            "width": 512,
        },
    }

    response = client.post("/api/stablestudio/generate", json=test_input)

    # Include the response body so a 4xx/5xx failure is diagnosable.
    assert response.status_code == 200, response.text
    data = response.json()
    assert "images" in data
    # Without this, an empty list would make the loop below pass vacuously.
    assert data["images"], "expected at least one generated image"

    expected_blob = red_b64.decode("utf-8")
    for image in data["images"]:
        assert image["blob"] == expected_blob
|
|
|
|
|
|
def test_list_samplers():
    """GET the sampler list and check it is exactly ``ddim`` and ``dpmpp``, in that order.

    This is a plain ``def`` test. It awaits nothing and uses the synchronous
    ``TestClient``, which is meant for non-async tests.
    """
    response = client.get("/api/stablestudio/samplers")
    # Include the response body so a 4xx/5xx failure is diagnosable.
    assert response.status_code == 200, response.text
    assert response.json() == [
        {"id": "ddim", "name": "ddim"},
        {"id": "dpmpp", "name": "dpmpp"},
    ]
|
|
|
|
|
|
def test_list_models():
    """GET the model list and check the set of returned model ids.

    The ids are compared as a set, so ordering in the response does not matter.
    This is a plain ``def`` test. It awaits nothing and uses the synchronous
    ``TestClient``, which is meant for non-async tests.
    """
    response = client.get("/api/stablestudio/models")
    # Include the response body so a 4xx/5xx failure is diagnosable.
    assert response.status_code == 200, response.text

    expected_model_ids = {"sd15", "openjourney-v1", "openjourney-v2", "openjourney-v4"}
    model_ids = {m["id"] for m in response.json()}
    assert model_ids == expected_model_ids
|