gpt4all/gpt4all-bindings/python/tests/test_pyllmodel.py
Yaroslav Halchenko 15f7f1fc6c [DATALAD RUNCMD] run codespell throughout
=== Do not change lines below ===
{
 "chain": [],
 "cmd": "codespell -w",
 "exit": 0,
 "extra_inputs": [],
 "inputs": [],
 "outputs": [],
 "pwd": "."
}
^^^ Do not change lines above ^^^
2023-05-16 11:33:59 -04:00

62 lines
1.5 KiB
Python

from io import StringIO
import sys
from gpt4all import pyllmodel
# TODO: Add an integration test for loadmodel and prompt.
# Right now that would be too slow because it requires a file download.
def test_create_gptj():
    """A freshly constructed GPTJModel reports the "gptj" model type."""
    assert pyllmodel.GPTJModel().model_type == "gptj"
def test_create_llama():
    """A freshly constructed LlamaModel reports the "llama" model type."""
    assert pyllmodel.LlamaModel().model_type == "llama"
def test_create_mpt():
    """A freshly constructed MPTModel reports the "mpt" model type."""
    assert pyllmodel.MPTModel().model_type == "mpt"
def prompt_unloaded_mpt():
    """Check that prompting an unloaded MPT model prints the expected error.

    ``sys.stdout`` is temporarily replaced with an in-memory buffer so that
    whatever ``prompt`` writes there can be compared against the expected
    message.

    NOTE(review): the name lacks the ``test_`` prefix, so pytest's default
    discovery will not collect this function -- confirm that is intentional.
    """
    mpt = pyllmodel.MPTModel()

    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response
    try:
        mpt.prompt("hello there")
    finally:
        # Always restore the real stdout, even if prompt() raises; otherwise
        # the rest of the process would keep writing into this buffer.
        sys.stdout = old_stdout

    response = collect_response.getvalue().strip()
    assert response == "MPT ERROR: prompt won't work with an unloaded model!"
def prompt_unloaded_gptj():
    """Check that prompting an unloaded GPT-J model prints the expected error.

    ``sys.stdout`` is temporarily replaced with an in-memory buffer so that
    whatever ``prompt`` writes there can be compared against the expected
    message.

    NOTE(review): the name lacks the ``test_`` prefix, so pytest's default
    discovery will not collect this function -- confirm that is intentional.
    """
    gptj = pyllmodel.GPTJModel()

    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response
    try:
        gptj.prompt("hello there")
    finally:
        # Always restore the real stdout, even if prompt() raises; otherwise
        # the rest of the process would keep writing into this buffer.
        sys.stdout = old_stdout

    response = collect_response.getvalue().strip()
    assert response == "GPT-J ERROR: prompt won't work with an unloaded model!"
def prompt_unloaded_llama():
    """Check that prompting an unloaded LLaMA model prints the expected error.

    ``sys.stdout`` is temporarily replaced with an in-memory buffer so that
    whatever ``prompt`` writes there can be compared against the expected
    message.

    NOTE(review): the name lacks the ``test_`` prefix, so pytest's default
    discovery will not collect this function -- confirm that is intentional.
    """
    llama = pyllmodel.LlamaModel()

    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response
    try:
        llama.prompt("hello there")
    finally:
        # Always restore the real stdout, even if prompt() raises; otherwise
        # the rest of the process would keep writing into this buffer.
        sys.stdout = old_stdout

    response = collect_response.getvalue().strip()
    assert response == "LLAMA ERROR: prompt won't work with an unloaded model!"