test: adds tests for stablestudio (#415)

This commit is contained in:
jaydrennan 2023-12-16 13:00:03 -07:00 committed by GitHub
parent e1e6f8037c
commit 41a9d7007b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 156 additions and 15 deletions

View File

@ -1,8 +1,10 @@
black
coverage
httpx
mypy
ruff
pytest
pytest-asyncio
pytest-randomly
pytest-sugar
responses

View File

@ -5,7 +5,9 @@
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
#
aiohttp==3.9.1
# via fsspec
# via
# black
# fsspec
aiosignal==1.3.1
# via aiohttp
annotated-types==0.6.0
@ -15,15 +17,19 @@ antlr4-python3-runtime==4.9.3
anyio==3.7.1
# via
# fastapi
# httpx
# starlette
async-timeout==4.0.3
# via aiohttp
attrs==23.1.0
# via aiohttp
black==23.11.0
black==23.12.0
# via -r requirements-dev.in
certifi==2023.11.17
# via requests
# via
# httpcore
# httpx
# requests
charset-normalizer==3.3.2
# via requests
click==8.1.7
@ -39,7 +45,7 @@ click-shell==2.1
# via imaginAIry (setup.py)
contourpy==1.2.0
# via matplotlib
coverage==7.3.2
coverage==7.3.3
# via -r requirements-dev.in
cycler==0.12.1
# via matplotlib
@ -55,7 +61,7 @@ facexlib==0.3.0
# via imaginAIry (setup.py)
fairscale==0.4.13
# via imaginAIry (setup.py)
fastapi==0.104.1
fastapi==0.105.0
# via imaginAIry (setup.py)
filelock==3.13.1
# via
@ -67,11 +73,11 @@ filterpy==1.4.5
# via facexlib
fonttools==4.46.0
# via matplotlib
frozenlist==1.4.0
frozenlist==1.4.1
# via
# aiohttp
# aiosignal
fsspec[http]==2023.12.1
fsspec[http]==2023.12.2
# via
# huggingface-hub
# pytorch-lightning
@ -81,7 +87,13 @@ ftfy==6.1.3
# imaginAIry (setup.py)
# open-clip-torch
h11==0.14.0
# via uvicorn
# via
# httpcore
# uvicorn
httpcore==1.0.2
# via httpx
httpx==0.25.2
# via -r requirements-dev.in
huggingface-hub==0.19.4
# via
# diffusers
@ -92,15 +104,16 @@ huggingface-hub==0.19.4
idna==3.6
# via
# anyio
# httpx
# requests
# yarl
imageio==2.33.0
imageio==2.33.1
# via imaginAIry (setup.py)
importlib-metadata==7.0.0
# via diffusers
iniconfig==2.0.0
# via pytest
jaxtyping==0.2.24
jaxtyping==0.2.25
# via refiners
jinja2==3.1.2
# via torch
@ -208,8 +221,11 @@ pyparsing==3.1.1
pytest==7.4.3
# via
# -r requirements-dev.in
# pytest-asyncio
# pytest-randomly
# pytest-sugar
pytest-asyncio==0.23.2
# via -r requirements-dev.in
pytest-randomly==3.15.0
# via -r requirements-dev.in
pytest-sugar==0.9.7
@ -244,7 +260,7 @@ requests==2.31.0
# transformers
responses==0.24.1
# via -r requirements-dev.in
ruff==0.1.7
ruff==0.1.8
# via -r requirements-dev.in
safetensors==0.3.3
# via
@ -264,7 +280,9 @@ sentencepiece==0.1.99
six==1.16.0
# via python-dateutil
sniffio==1.3.0
# via anyio
# via
# anyio
# httpx
starlette==0.27.0
# via fastapi
sympy==1.12
@ -282,7 +300,7 @@ tomli==2.0.1
# black
# mypy
# pytest
torch==2.1.1
torch==2.1.2
# via
# facexlib
# fairscale
@ -301,7 +319,7 @@ torchmetrics==1.2.1
# via
# imaginAIry (setup.py)
# pytorch-lightning
torchvision==0.16.1
torchvision==0.16.2
# via
# facexlib
# imaginAIry (setup.py)
@ -315,7 +333,7 @@ tqdm==4.66.1
# open-clip-torch
# pytorch-lightning
# transformers
transformers==4.35.2
transformers==4.36.1
# via imaginAIry (setup.py)
typeguard==2.13.3
# via jaxtyping

View File

View File

@ -0,0 +1,60 @@
from unittest import mock
import pytest
from fastapi.testclient import TestClient
from imaginairy.http_app.app import app
client = TestClient(app)
@pytest.fixture()
def mock_generate_image(monkeypatch):
fake_generate = mock.MagicMock(return_value=iter("a fake image"))
monkeypatch.setattr("imaginairy.http_app.app.generate_image", fake_generate)
@pytest.mark.asyncio()
def test_imagine_endpoint(mock_generate_image):
test_input = {"prompt": "test prompt"}
response = client.post("/api/imagine", json=test_input)
assert response.status_code == 200
assert response.content == b"a fake image"
@pytest.mark.asyncio()
async def test_get_imagine_endpoint(mock_generate_image):
test_input = {"text": "a dog"}
response = client.get("/api/imagine", params=test_input)
assert response.status_code == 200
assert response.content == b"a fake image"
@pytest.mark.asyncio()
async def test_get_imagine_endpoint_mp(mock_generate_image):
test_input = {"text": "a dog"}
response = client.get("/api/imagine", params=test_input)
assert response.status_code == 200
assert response.content == b"a fake image"
def test_edit_redir():
response = client.get("/edit")
assert response.status_code == 200
assert response.headers["content-type"] == "text/html; charset=utf-8"
assert response.content[:15] == b"<!DOCTYPE html>"
def test_generate_redir():
response = client.get("/generate")
assert response.status_code == 200
assert response.headers["content-type"] == "text/html; charset=utf-8"
assert response.content[:15] == b"<!DOCTYPE html>"

View File

@ -0,0 +1,61 @@
from unittest import mock
import pytest
from fastapi.testclient import TestClient
from imaginairy.http_app.app import app
client = TestClient(app)
@pytest.fixture(name="red_b64")
def _red_b64():
return b"iVBORw0KGgoAAAANSUhEUgAAAgAAAAIAAQMAAADOtka5AAAABlBMVEX/AAD///9BHTQRAAAANklEQVR4nO3BAQEAAACCIP+vbkhAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8G4IAAAHSeInwAAAAAElFTkSuQmCC"
@pytest.fixture()
def mock_generate_image_b64(monkeypatch, red_b64):
fake_generate = mock.MagicMock(return_value=red_b64)
monkeypatch.setattr(
"imaginairy.http_app.stablestudio.routes.generate_image_b64", fake_generate
)
@pytest.mark.asyncio()
async def test_generate_endpoint(mock_generate_image_b64, red_b64):
test_input = {
"input": {
"prompts": [{"text": "A dog"}],
"sampler": {"id": "ddim"},
"height": 512,
"width": 512,
},
}
response = client.post("/api/stablestudio/generate", json=test_input)
assert response.status_code == 200
data = response.json()
assert "images" in data
for image in data["images"]:
assert image["blob"] == red_b64.decode("utf-8")
@pytest.mark.asyncio()
async def test_list_samplers():
response = client.get("/api/stablestudio/samplers")
assert response.status_code == 200
assert response.json() == [
{"id": "ddim", "name": "ddim"},
{"id": "dpmpp", "name": "dpmpp"},
]
@pytest.mark.asyncio()
async def test_list_models():
response = client.get("/api/stablestudio/models")
assert response.status_code == 200
expected_model_ids = {"sd15", "openjourney-v1", "openjourney-v2", "openjourney-v4"}
model_ids = {m["id"] for m in response.json()}
assert model_ids == expected_model_ids