langchain/libs/community/tests/integration_tests/llms/test_titan_takeoff.py

142 lines
4.7 KiB
Python
Raw Normal View History

"""Test Titan Takeoff wrapper."""
import json
from typing import Any, Union
import pytest
from langchain_community.llms import TitanTakeoff, TitanTakeoffPro
@pytest.mark.requires("takeoff_client")
@pytest.mark.requires("pytest_httpx")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_call(
    httpx_mock: Any,
    streaming: bool,
    takeoff_object: Union[TitanTakeoff, TitanTakeoffPro],
) -> None:
    """Test that successful calls hit the expected Titan Takeoff endpoint.

    The wrapper is driven through ``__call__`` and ``invoke`` and, when
    ``streaming`` is set, once more through ``_stream``. After every call the
    request that was just sent is checked for its URL and prompt payload.

    Args:
        httpx_mock: ``pytest_httpx`` fixture used to register the mocked
            response and to inspect the requests that were sent.
        streaming: Whether the wrapper is built with ``streaming=True``; this
            also selects ``/generate_stream`` over ``/generate`` for the mock.
        takeoff_object: The wrapper class under test; it is called here to
            construct the LLM instance.
    """
    from pytest_httpx import IteratorStream

    port = 2345
    prompt = "What is 2 + 2?"
    url = (
        f"http://localhost:{port}/generate_stream"
        if streaming
        else f"http://localhost:{port}/generate"
    )

    if streaming:
        httpx_mock.add_response(
            method="POST",
            url=url,
            stream=IteratorStream([b"data: ask someone else\n\n"]),
        )
    else:
        httpx_mock.add_response(
            method="POST",
            url=url,
            json={"text": "ask someone else"},
        )

    def assert_latest_request(expected_count: int) -> None:
        """Check the request count and validate the newest recorded request."""
        requests = httpx_mock.get_requests()
        assert len(requests) == expected_count
        # Inspect the last entry rather than the first, so that every call is
        # verified instead of re-checking the very first request each time.
        latest = requests[-1]
        assert latest.url == url
        assert json.loads(latest.content)["text"] == prompt

    llm = takeoff_object(port=port, streaming=streaming)
    number_of_calls = 0
    for function_call in (llm, llm.invoke):
        number_of_calls += 1
        output = function_call(prompt)
        assert isinstance(output, str)
        assert_latest_request(number_of_calls)

    if streaming:
        for chunk in llm._stream(prompt):
            assert isinstance(chunk.text, str)
        # The direct _stream call must have produced exactly one more request.
        assert_latest_request(number_of_calls + 1)
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_bad_call(
    httpx_mock: Any,
    streaming: bool,
    takeoff_object: Union[TitanTakeoff, TitanTakeoffPro],
) -> None:
    """Test that a 400 response makes the wrapper raise ``TakeoffException``.

    Args:
        httpx_mock: ``pytest_httpx`` fixture used to register the failing
            response and to inspect the request that was sent.
        streaming: Whether the wrapper is built with ``streaming=True``; this
            also selects which endpoint the failing response is mocked on.
        takeoff_object: The wrapper class under test; it is called here to
            construct the LLM instance.
    """
    from takeoff_client import TakeoffException

    # The wrapper is built without an explicit port below; the mock is
    # registered on port 3000.
    if streaming:
        url = "http://localhost:3000/generate_stream"
    else:
        url = "http://localhost:3000/generate"

    httpx_mock.add_response(
        method="POST",
        url=url,
        json={"text": "bad things"},
        status_code=400,
    )

    llm = takeoff_object(streaming=streaming)
    with pytest.raises(TakeoffException):
        llm("What is 2 + 2?")

    # Exactly one request should have gone out, carrying the prompt.
    sent_requests = httpx_mock.get_requests()
    assert len(sent_requests) == 1
    only_request = sent_requests[0]
    assert only_request.url == url
    assert json.loads(only_request.content)["text"] == "What is 2 + 2?"
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_model_initialisation(
    httpx_mock: Any,
    takeoff_object: Union[TitanTakeoff, TitanTakeoffPro],
) -> None:
    """Test that passing ``models`` creates each reader before generating.

    Two reader configs are supplied at construction time. The test expects
    one management-API request per reader, in order, followed by a single
    request to the generate endpoint.

    Args:
        httpx_mock: ``pytest_httpx`` fixture used to register the mocked
            responses and to inspect the requests that were sent.
        takeoff_object: The wrapper class under test; it is called here to
            construct the LLM instance.
    """
    mgnt_port = 36452
    inf_port = 46253

    mgnt_url = f"http://localhost:{mgnt_port}/reader"
    gen_url = f"http://localhost:{inf_port}/generate"

    reader_1 = {
        "model_name": "test",
        "device": "cpu",
        "consumer_group": "primary",
        "max_sequence_length": 512,
        "max_batch_size": 4,
        "tensor_parallel": 3,
    }
    # Same config as reader_1 apart from the model name.
    reader_2 = {**reader_1, "model_name": "test2"}

    httpx_mock.add_response(
        method="POST",
        url=mgnt_url,
        json={"key": "value"},
        status_code=201,
    )
    httpx_mock.add_response(
        method="POST",
        url=gen_url,
        json={"text": "value"},
        status_code=200,
    )

    llm = takeoff_object(
        port=inf_port,
        mgmt_port=mgnt_port,
        models=[reader_1, reader_2],
    )

    output = llm("What is 2 + 2?")
    assert isinstance(output, str)

    # Two management calls (one per reader) plus one generate call.
    recorded = httpx_mock.get_requests()
    assert len(recorded) == 3

    # The first two requests must spin up reader 1 and reader 2, in that order.
    for request, reader in zip(recorded[:2], (reader_1, reader_2)):
        payload = json.loads(request.content)
        for key, value in reader.items():
            assert payload[key] == value
        assert request.url == mgnt_url

    # The third request is the inference call to the generate endpoint.
    generate_request = recorded[2]
    assert generate_request.url == gen_url
    assert json.loads(generate_request.content)["text"] == "What is 2 + 2?"