mirror of
https://github.com/HazyResearch/manifest
synced 2024-10-31 15:20:26 +00:00
134 lines
3.5 KiB
Python
134 lines
3.5 KiB
Python
"""Setup for all tests."""
|
|
import os
|
|
import shutil
|
|
from pathlib import Path
|
|
from typing import Generator
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import redis
|
|
|
|
from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest
|
|
from manifest.response import ArrayModelChoice, LMModelChoice, ModelChoices
|
|
|
|
|
|
@pytest.fixture
def model_choice() -> ModelChoices:
    """Build a ``ModelChoices`` holding two dummy language-model choices.

    The first choice has text "hello" with two tokens and two log-probs;
    the second has text "bye" with a single token and log-prob.
    """
    greeting = LMModelChoice(
        text="hello", token_logprobs=[0.1, 0.2], tokens=["hel", "lo"]
    )
    farewell = LMModelChoice(text="bye", token_logprobs=[0.3], tokens=["bye"])
    return ModelChoices(choices=[greeting, farewell])
|
|
|
|
|
|
@pytest.fixture
def model_choice_single() -> ModelChoices:
    """Build a ``ModelChoices`` wrapping exactly one dummy LM choice.

    The single choice has text "helloo" with two tokens and two log-probs.
    """
    only_choice = LMModelChoice(
        text="helloo", token_logprobs=[0.1, 0.2], tokens=["hel", "loo"]
    )
    return ModelChoices(choices=[only_choice])
|
|
|
|
|
|
@pytest.fixture
def model_choice_arr() -> ModelChoices:
    """Build a ``ModelChoices`` holding two dummy float-array choices.

    NumPy's global RNG is re-seeded with 0 first, so the two 4x4
    ``randn`` arrays are the same on every invocation. Note that this
    resets the global NumPy random state as a side effect.
    """
    np.random.seed(0)
    # Draw order matters: the first array consumes the RNG before the second.
    first = ArrayModelChoice(array=np.random.randn(4, 4), token_logprobs=[0.1, 0.2])
    second = ArrayModelChoice(array=np.random.randn(4, 4), token_logprobs=[0.3])
    return ModelChoices(choices=[first, second])
|
|
|
|
|
|
@pytest.fixture
def model_choice_arr_int() -> ModelChoices:
    """Build a ``ModelChoices`` holding two dummy integer-array choices.

    NumPy's global RNG is re-seeded with 0 first; each choice then gets a
    4x4 array drawn with ``np.random.randint(20, ...)``. Note that this
    resets the global NumPy random state as a side effect.
    """
    np.random.seed(0)
    # The comprehension evaluates left to right, so the arrays are drawn in
    # the same order as the log-prob lists below.
    int_choices = [
        ArrayModelChoice(
            array=np.random.randint(20, size=(4, 4)), token_logprobs=logprobs
        )
        for logprobs in ([0.1, 0.2], [0.3])
    ]
    return ModelChoices(choices=int_choices)
|
|
|
|
|
|
@pytest.fixture
def request_lm() -> LMRequest:
    """Build a dummy ``LMRequest`` with a two-item list prompt."""
    return LMRequest(prompt=["what", "cat"])
|
|
|
|
|
|
@pytest.fixture
def request_lm_single() -> LMRequest:
    """Build a dummy ``LMRequest`` with a single string prompt.

    The request uses prompt "monkey" and engine "dummy".
    """
    return LMRequest(prompt="monkey", engine="dummy")
|
|
|
|
|
|
@pytest.fixture
def request_array() -> EmbeddingRequest:
    """Build a dummy ``EmbeddingRequest`` with prompt "hello"."""
    return EmbeddingRequest(prompt="hello")
|
|
|
|
|
|
@pytest.fixture
def request_diff() -> DiffusionRequest:
    """Build a dummy ``DiffusionRequest`` with prompt "hello"."""
    return DiffusionRequest(prompt="hello")
|
|
|
|
|
|
@pytest.fixture
def sqlite_cache(tmp_path: Path) -> Generator[str, None, None]:
    """Yield a path for a SQLite cache file and remove it on teardown.

    Args:
        tmp_path: Temporary directory in which the cache path is placed
            (pytest's ``tmp_path`` fixture).

    Yields:
        The string path ``<tmp_path>/sqlite_cache.sqlite``. The fixture does
        not create anything at that path itself.
    """
    cache_path = tmp_path / "sqlite_cache.sqlite"
    yield str(cache_path)
    # ``shutil.rmtree`` only removes directories; called on a regular file
    # with ``ignore_errors=True`` it silently does nothing, so a cache file
    # would never be deleted. Handle both shapes explicitly.
    if cache_path.is_dir():
        shutil.rmtree(cache_path, ignore_errors=True)
    else:
        try:
            cache_path.unlink()
        except OSError:
            # Cleanup stays best-effort, as before: the file may never have
            # been created, or may still be held open by the code under test.
            pass
|
|
|
|
|
|
@pytest.fixture
def redis_cache() -> Generator[str, None, None]:
    """Yield a ``host:port`` Redis address and flush the DB on teardown.

    The host comes from ``REDIS_HOST`` (default "localhost") and the port
    from ``REDIS_PORT`` (default 6379). After the test, ``flushdb()`` is
    called on a fresh client; a connection failure is ignored.
    """
    host = os.getenv("REDIS_HOST", "localhost")
    port = int(os.getenv("REDIS_PORT", "6379"))
    address = f"{host}:{port}"
    yield address
    # Teardown: wipe whatever the test wrote.
    try:
        client = redis.Redis(host=host, port=port)
        client.flushdb()
    except redis.exceptions.ConnectionError:
        # Keeps local runs usable when no Redis server has been started.
        pass
|
|
|
|
|
|
@pytest.fixture
def postgres_cache(monkeypatch: pytest.MonkeyPatch) -> Generator[str, None, None]:
    """Stand in for Postgres with an in-memory SQLite engine.

    ``sqlalchemy.create_engine`` is monkeypatched so that every call,
    whatever its arguments, hands back one shared in-memory SQLite engine.
    That engine is also the fixture's value.

    NOTE(review): the return annotation says ``Generator[str, None, None]``
    but the function returns the engine object directly (hence the
    ``type: ignore`` on the return).
    """
    import sqlalchemy  # type: ignore

    memory_url = sqlalchemy.engine.url.URL.create("sqlite", database=":memory:")
    sqlite_engine = sqlalchemy.create_engine(memory_url)

    def _always_sqlite(*args, **kwargs):  # type: ignore
        # Ignore the requested URL/options and reuse the single engine above.
        return sqlite_engine

    monkeypatch.setattr(sqlalchemy, "create_engine", _always_sqlite)
    return sqlite_engine  # type: ignore
|