2023-05-23 22:59:08 +00:00
|
|
|
"""Test MosaicML API wrapper."""
|
2023-08-25 05:13:17 +00:00
|
|
|
import re
|
|
|
|
|
2023-05-23 22:59:08 +00:00
|
|
|
import pytest
|
|
|
|
|
2023-12-11 21:53:30 +00:00
|
|
|
from langchain_community.llms.mosaicml import PROMPT_FOR_GENERATION_FORMAT, MosaicML
|
2023-05-23 22:59:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_mosaicml_llm_call() -> None:
    """Call MosaicML with empty model kwargs and check a string comes back."""
    client = MosaicML(model_kwargs={})
    completion = client("Say foo:")
    assert isinstance(completion, str)
|
|
|
|
|
|
|
|
|
|
|
|
def test_mosaicml_endpoint_change() -> None:
    """Check that a custom endpoint URL is stored and still yields a string."""
    custom_endpoint = (
        "https://models.hosted-on.mosaicml.hosting/mpt-30b-instruct/v1/predict"
    )
    client = MosaicML(endpoint_url=custom_endpoint)
    # The wrapper must keep the URL it was given rather than its default.
    assert llm_endpoint_matches(client, custom_endpoint)
    completion = client("Say foo:")
    assert isinstance(completion, str)


def llm_endpoint_matches(client: MosaicML, expected_url: str) -> bool:
    """Return whether ``client.endpoint_url`` equals ``expected_url``."""
    return client.endpoint_url == expected_url
|
|
|
|
|
|
|
|
|
|
|
|
def test_mosaicml_extra_kwargs() -> None:
    """Check that model kwargs are stored and that one new token is honoured."""
    generation_kwargs = {"max_new_tokens": 1}
    client = MosaicML(model_kwargs=generation_kwargs)
    assert client.model_kwargs == {"max_new_tokens": 1}

    completion = client("Say foo:")
    assert isinstance(completion, str)

    # A single generated token may be a newline or whitespace, so the
    # completion splits into at most one word.
    words = completion.split()
    assert len(words) <= 1
|
|
|
|
|
|
|
|
|
|
|
|
def test_instruct_prompt() -> None:
    """Check the instruction template is applied and the LLM returns a string."""
    client = MosaicML(
        inject_instruction_format=True,
        model_kwargs={"max_new_tokens": 10},
    )
    request = "Repeat the word foo"
    formatted = client._transform_prompt(request)
    assert formatted == PROMPT_FOR_GENERATION_FORMAT.format(instruction=request)

    # NOTE(review): the already-formatted prompt is sent to an LLM created with
    # inject_instruction_format=True; if the wrapper formats again on call the
    # template is applied twice -- confirm against the wrapper implementation.
    completion = client(formatted)
    assert isinstance(completion, str)
|
|
|
|
|
|
|
|
|
|
|
|
def test_retry_logic() -> None:
    """Check that two back-to-back queries both return strings.

    Two consecutive queries would usually exceed the rate limit, so both
    succeeding indicates that the wrapper's retry handling works.
    """
    client = MosaicML(
        inject_instruction_format=True,
        model_kwargs={"max_new_tokens": 10},
    )
    request = "Repeat the word foo"
    formatted = client._transform_prompt(request)
    assert formatted == PROMPT_FOR_GENERATION_FORMAT.format(instruction=request)

    # Issue the same query twice in a row; each call must yield a string.
    for _ in range(2):
        completion = client(formatted)
        assert isinstance(completion, str)
|
|
|
|
|
|
|
|
|
|
|
|
def test_short_retry_does_not_loop() -> None:
    """Check that a short retry sleep ends in an error, not an infinite loop.

    Up to ten queries are sent with ``retry_sleep=0.1``; the test expects a
    ``ValueError`` carrying the rate-limit message to be raised along the way.
    """
    client = MosaicML(
        inject_instruction_format=True,
        model_kwargs={"do_sample": False},
        retry_sleep=0.1,
    )
    request = "Repeat the word foo"
    formatted = client._transform_prompt(request)
    assert formatted == PROMPT_FOR_GENERATION_FORMAT.format(instruction=request)

    # Escaped before matching because pytest treats ``match`` as a regex.
    rate_limit_message = (
        "Error raised by inference API: rate limit exceeded.\nResponse: You have "
        "reached maximum request limit.\n"
    )
    with pytest.raises(ValueError, match=re.escape(rate_limit_message)):
        for _ in range(10):
            completion = client(formatted)
            assert isinstance(completion, str)
|