test: adds tests for stablestudio (#415)
parent
e1e6f8037c
commit
41a9d7007b
@ -0,0 +1,60 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from imaginairy.http_app.app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_generate_image(monkeypatch):
|
||||
fake_generate = mock.MagicMock(return_value=iter("a fake image"))
|
||||
monkeypatch.setattr("imaginairy.http_app.app.generate_image", fake_generate)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
def test_imagine_endpoint(mock_generate_image):
|
||||
test_input = {"prompt": "test prompt"}
|
||||
|
||||
response = client.post("/api/imagine", json=test_input)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"a fake image"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_get_imagine_endpoint(mock_generate_image):
|
||||
test_input = {"text": "a dog"}
|
||||
|
||||
response = client.get("/api/imagine", params=test_input)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"a fake image"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_get_imagine_endpoint_mp(mock_generate_image):
|
||||
test_input = {"text": "a dog"}
|
||||
|
||||
response = client.get("/api/imagine", params=test_input)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"a fake image"
|
||||
|
||||
|
||||
def test_edit_redir():
|
||||
response = client.get("/edit")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/html; charset=utf-8"
|
||||
assert response.content[:15] == b"<!DOCTYPE html>"
|
||||
|
||||
|
||||
def test_generate_redir():
|
||||
response = client.get("/generate")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/html; charset=utf-8"
|
||||
assert response.content[:15] == b"<!DOCTYPE html>"
|
@ -0,0 +1,61 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from imaginairy.http_app.app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(name="red_b64")
|
||||
def _red_b64():
|
||||
return b"iVBORw0KGgoAAAANSUhEUgAAAgAAAAIAAQMAAADOtka5AAAABlBMVEX/AAD///9BHTQRAAAANklEQVR4nO3BAQEAAACCIP+vbkhAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8G4IAAAHSeInwAAAAAElFTkSuQmCC"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_generate_image_b64(monkeypatch, red_b64):
|
||||
fake_generate = mock.MagicMock(return_value=red_b64)
|
||||
monkeypatch.setattr(
|
||||
"imaginairy.http_app.stablestudio.routes.generate_image_b64", fake_generate
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_generate_endpoint(mock_generate_image_b64, red_b64):
|
||||
test_input = {
|
||||
"input": {
|
||||
"prompts": [{"text": "A dog"}],
|
||||
"sampler": {"id": "ddim"},
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
},
|
||||
}
|
||||
|
||||
response = client.post("/api/stablestudio/generate", json=test_input)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "images" in data
|
||||
for image in data["images"]:
|
||||
assert image["blob"] == red_b64.decode("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_list_samplers():
|
||||
response = client.get("/api/stablestudio/samplers")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [
|
||||
{"id": "ddim", "name": "ddim"},
|
||||
{"id": "dpmpp", "name": "dpmpp"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_list_models():
|
||||
response = client.get("/api/stablestudio/models")
|
||||
assert response.status_code == 200
|
||||
|
||||
expected_model_ids = {"sd15", "openjourney-v1", "openjourney-v2", "openjourney-v4"}
|
||||
model_ids = {m["id"] for m in response.json()}
|
||||
assert model_ids == expected_model_ids
|
Loading…
Reference in New Issue