mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
test: adds tests for stablestudio (#415)
This commit is contained in:
parent
e1e6f8037c
commit
41a9d7007b
@ -1,8 +1,10 @@
|
||||
black
|
||||
coverage
|
||||
httpx
|
||||
mypy
|
||||
ruff
|
||||
pytest
|
||||
pytest-asyncio
|
||||
pytest-randomly
|
||||
pytest-sugar
|
||||
responses
|
||||
|
@ -5,7 +5,9 @@
|
||||
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
|
||||
#
|
||||
aiohttp==3.9.1
|
||||
# via fsspec
|
||||
# via
|
||||
# black
|
||||
# fsspec
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.6.0
|
||||
@ -15,15 +17,19 @@ antlr4-python3-runtime==4.9.3
|
||||
anyio==3.7.1
|
||||
# via
|
||||
# fastapi
|
||||
# httpx
|
||||
# starlette
|
||||
async-timeout==4.0.3
|
||||
# via aiohttp
|
||||
attrs==23.1.0
|
||||
# via aiohttp
|
||||
black==23.11.0
|
||||
black==23.12.0
|
||||
# via -r requirements-dev.in
|
||||
certifi==2023.11.17
|
||||
# via requests
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
charset-normalizer==3.3.2
|
||||
# via requests
|
||||
click==8.1.7
|
||||
@ -39,7 +45,7 @@ click-shell==2.1
|
||||
# via imaginAIry (setup.py)
|
||||
contourpy==1.2.0
|
||||
# via matplotlib
|
||||
coverage==7.3.2
|
||||
coverage==7.3.3
|
||||
# via -r requirements-dev.in
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
@ -55,7 +61,7 @@ facexlib==0.3.0
|
||||
# via imaginAIry (setup.py)
|
||||
fairscale==0.4.13
|
||||
# via imaginAIry (setup.py)
|
||||
fastapi==0.104.1
|
||||
fastapi==0.105.0
|
||||
# via imaginAIry (setup.py)
|
||||
filelock==3.13.1
|
||||
# via
|
||||
@ -67,11 +73,11 @@ filterpy==1.4.5
|
||||
# via facexlib
|
||||
fonttools==4.46.0
|
||||
# via matplotlib
|
||||
frozenlist==1.4.0
|
||||
frozenlist==1.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2023.12.1
|
||||
fsspec[http]==2023.12.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# pytorch-lightning
|
||||
@ -81,7 +87,13 @@ ftfy==6.1.3
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
httpcore==1.0.2
|
||||
# via httpx
|
||||
httpx==0.25.2
|
||||
# via -r requirements-dev.in
|
||||
huggingface-hub==0.19.4
|
||||
# via
|
||||
# diffusers
|
||||
@ -92,15 +104,16 @@ huggingface-hub==0.19.4
|
||||
idna==3.6
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
imageio==2.33.0
|
||||
imageio==2.33.1
|
||||
# via imaginAIry (setup.py)
|
||||
importlib-metadata==7.0.0
|
||||
# via diffusers
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
jaxtyping==0.2.24
|
||||
jaxtyping==0.2.25
|
||||
# via refiners
|
||||
jinja2==3.1.2
|
||||
# via torch
|
||||
@ -208,8 +221,11 @@ pyparsing==3.1.1
|
||||
pytest==7.4.3
|
||||
# via
|
||||
# -r requirements-dev.in
|
||||
# pytest-asyncio
|
||||
# pytest-randomly
|
||||
# pytest-sugar
|
||||
pytest-asyncio==0.23.2
|
||||
# via -r requirements-dev.in
|
||||
pytest-randomly==3.15.0
|
||||
# via -r requirements-dev.in
|
||||
pytest-sugar==0.9.7
|
||||
@ -244,7 +260,7 @@ requests==2.31.0
|
||||
# transformers
|
||||
responses==0.24.1
|
||||
# via -r requirements-dev.in
|
||||
ruff==0.1.7
|
||||
ruff==0.1.8
|
||||
# via -r requirements-dev.in
|
||||
safetensors==0.3.3
|
||||
# via
|
||||
@ -264,7 +280,9 @@ sentencepiece==0.1.99
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.0
|
||||
# via anyio
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
starlette==0.27.0
|
||||
# via fastapi
|
||||
sympy==1.12
|
||||
@ -282,7 +300,7 @@ tomli==2.0.1
|
||||
# black
|
||||
# mypy
|
||||
# pytest
|
||||
torch==2.1.1
|
||||
torch==2.1.2
|
||||
# via
|
||||
# facexlib
|
||||
# fairscale
|
||||
@ -301,7 +319,7 @@ torchmetrics==1.2.1
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# pytorch-lightning
|
||||
torchvision==0.16.1
|
||||
torchvision==0.16.2
|
||||
# via
|
||||
# facexlib
|
||||
# imaginAIry (setup.py)
|
||||
@ -315,7 +333,7 @@ tqdm==4.66.1
|
||||
# open-clip-torch
|
||||
# pytorch-lightning
|
||||
# transformers
|
||||
transformers==4.35.2
|
||||
transformers==4.36.1
|
||||
# via imaginAIry (setup.py)
|
||||
typeguard==2.13.3
|
||||
# via jaxtyping
|
||||
|
0
tests/test_http_app/__init__.py
Normal file
0
tests/test_http_app/__init__.py
Normal file
60
tests/test_http_app/test_app.py
Normal file
60
tests/test_http_app/test_app.py
Normal file
@ -0,0 +1,60 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from imaginairy.http_app.app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_generate_image(monkeypatch):
|
||||
fake_generate = mock.MagicMock(return_value=iter("a fake image"))
|
||||
monkeypatch.setattr("imaginairy.http_app.app.generate_image", fake_generate)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
def test_imagine_endpoint(mock_generate_image):
|
||||
test_input = {"prompt": "test prompt"}
|
||||
|
||||
response = client.post("/api/imagine", json=test_input)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"a fake image"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_get_imagine_endpoint(mock_generate_image):
|
||||
test_input = {"text": "a dog"}
|
||||
|
||||
response = client.get("/api/imagine", params=test_input)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"a fake image"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_get_imagine_endpoint_mp(mock_generate_image):
|
||||
test_input = {"text": "a dog"}
|
||||
|
||||
response = client.get("/api/imagine", params=test_input)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"a fake image"
|
||||
|
||||
|
||||
def test_edit_redir():
|
||||
response = client.get("/edit")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/html; charset=utf-8"
|
||||
assert response.content[:15] == b"<!DOCTYPE html>"
|
||||
|
||||
|
||||
def test_generate_redir():
|
||||
response = client.get("/generate")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/html; charset=utf-8"
|
||||
assert response.content[:15] == b"<!DOCTYPE html>"
|
61
tests/test_http_app/test_routes.py
Normal file
61
tests/test_http_app/test_routes.py
Normal file
@ -0,0 +1,61 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from imaginairy.http_app.app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(name="red_b64")
|
||||
def _red_b64():
|
||||
return b"iVBORw0KGgoAAAANSUhEUgAAAgAAAAIAAQMAAADOtka5AAAABlBMVEX/AAD///9BHTQRAAAANklEQVR4nO3BAQEAAACCIP+vbkhAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8G4IAAAHSeInwAAAAAElFTkSuQmCC"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_generate_image_b64(monkeypatch, red_b64):
|
||||
fake_generate = mock.MagicMock(return_value=red_b64)
|
||||
monkeypatch.setattr(
|
||||
"imaginairy.http_app.stablestudio.routes.generate_image_b64", fake_generate
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_generate_endpoint(mock_generate_image_b64, red_b64):
|
||||
test_input = {
|
||||
"input": {
|
||||
"prompts": [{"text": "A dog"}],
|
||||
"sampler": {"id": "ddim"},
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
},
|
||||
}
|
||||
|
||||
response = client.post("/api/stablestudio/generate", json=test_input)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "images" in data
|
||||
for image in data["images"]:
|
||||
assert image["blob"] == red_b64.decode("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_list_samplers():
|
||||
response = client.get("/api/stablestudio/samplers")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [
|
||||
{"id": "ddim", "name": "ddim"},
|
||||
{"id": "dpmpp", "name": "dpmpp"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_list_models():
|
||||
response = client.get("/api/stablestudio/models")
|
||||
assert response.status_code == 200
|
||||
|
||||
expected_model_ids = {"sd15", "openjourney-v1", "openjourney-v2", "openjourney-v4"}
|
||||
model_ids = {m["id"] for m in response.json()}
|
||||
assert model_ids == expected_model_ids
|
Loading…
Reference in New Issue
Block a user