langchain/libs/community/tests/integration_tests/llms/test_titan_takeoff.py
pjb157 479be3cc91
community[minor]: Unify Titan Takeoff Integrations and Adding Embedding Support (#18775)
**Community: Unify Titan Takeoff Integrations and Adding Embedding
Support**

 **Description:** 
Titan Takeoff is no longer reflected by either of the integrations in the
community folder. The two integrations (TitanTakeoffPro and
TitanTakeoff) were causing confusion with clients, so we have moved the code
into one place and created an alias for backwards compatibility. Added the
Takeoff Client Python package to do the bulk of the work with the
requests, because this package is actively updated with new
versions of Takeoff. So this integration will be far more robust and
will not degrade as badly over time.

**Issue:**
Fixes bugs in the old Titan integrations and unifies the code, with added
unit test coverage to avoid future problems.

**Dependencies:**
Added optional dependency takeoff-client. All imports still work without the
dependency, including the Titan Takeoff classes, but they will fail on
initialisation if takeoff-client is not pip installed.

**Twitter**
@MeryemArik9

Thanks all :)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
2024-04-17 01:43:35 +00:00

142 lines
4.7 KiB
Python

"""Test Titan Takeoff wrapper."""
import json
from typing import Any, Type, Union

import pytest

from langchain_community.llms import TitanTakeoff, TitanTakeoffPro
@pytest.mark.requires("takeoff_client")
@pytest.mark.requires("pytest_httpx")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_call(
    httpx_mock: Any,
    streaming: bool,
    takeoff_object: Type[Union[TitanTakeoff, TitanTakeoffPro]],
) -> None:
    """Test a valid call to Titan Takeoff, with and without streaming.

    A mocked POST response is registered for the ``generate`` (or
    ``generate_stream``) endpoint, then the LLM is exercised through both
    ``__call__`` and ``invoke``. When streaming, ``_stream`` is exercised too.

    Args:
        httpx_mock: The ``pytest_httpx`` fixture used to register responses
            and inspect the requests that were sent.
        streaming: Whether the LLM is created in streaming mode.
        takeoff_object: The LLM class under test (not an instance).
    """
    from pytest_httpx import IteratorStream

    port = 2345
    prompt = "What is 2 + 2?"
    endpoint = "generate_stream" if streaming else "generate"
    url = f"http://localhost:{port}/{endpoint}"

    if streaming:
        httpx_mock.add_response(
            method="POST",
            url=url,
            stream=IteratorStream([b"data: ask someone else\n\n"]),
        )
    else:
        httpx_mock.add_response(
            method="POST",
            url=url,
            json={"text": "ask someone else"},
        )

    llm = takeoff_object(port=port, streaming=streaming)

    number_of_calls = 0
    for function_call in [llm, llm.invoke]:
        number_of_calls += 1
        output = function_call(prompt)
        assert isinstance(output, str)

        requests = httpx_mock.get_requests()
        assert len(requests) == number_of_calls
        # Validate the request that was just sent (the last one); checking
        # index 0 every time would leave later requests unverified.
        assert requests[-1].url == url
        assert json.loads(requests[-1].content)["text"] == prompt

    if streaming:
        for chunk in llm._stream(prompt):
            assert isinstance(chunk.text, str)

        requests = httpx_mock.get_requests()
        assert len(requests) == number_of_calls + 1
        # Again inspect the newest request, i.e. the one made by _stream.
        assert requests[-1].url == url
        assert json.loads(requests[-1].content)["text"] == prompt
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_bad_call(
    httpx_mock: Any,
    streaming: bool,
    takeoff_object: Type[Union[TitanTakeoff, TitanTakeoffPro]],
) -> None:
    """Test that a failing call to Titan Takeoff raises ``TakeoffException``.

    The mocked endpoint answers with HTTP 400, and the test checks that the
    error surfaces as an exception while exactly one request is still sent
    with the prompt in its JSON body.

    Args:
        httpx_mock: The ``pytest_httpx`` fixture used to register responses
            and inspect the requests that were sent.
        streaming: Whether the LLM is created in streaming mode.
        takeoff_object: The LLM class under test (not an instance).
    """
    from takeoff_client import TakeoffException

    prompt = "What is 2 + 2?"
    # No port is passed to the LLM below; the test expects requests on 3000.
    url = (
        "http://localhost:3000/generate"
        if not streaming
        else "http://localhost:3000/generate_stream"
    )
    httpx_mock.add_response(
        method="POST", url=url, json={"text": "bad things"}, status_code=400
    )

    llm = takeoff_object(streaming=streaming)
    with pytest.raises(TakeoffException):
        llm(prompt)

    requests = httpx_mock.get_requests()
    assert len(requests) == 1
    assert requests[0].url == url
    assert json.loads(requests[0].content)["text"] == prompt
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_model_initialisation(
    httpx_mock: Any,
    takeoff_object: Type[Union[TitanTakeoff, TitanTakeoffPro]],
) -> None:
    """Test that passing ``models`` creates readers via the management API.

    Two reader configs are supplied at construction. The test expects one
    POST to the management ``/reader`` endpoint per config, in order,
    followed by a POST to the inference ``/generate`` endpoint for the prompt.

    Args:
        httpx_mock: The ``pytest_httpx`` fixture used to register responses
            and inspect the requests that were sent.
        takeoff_object: The LLM class under test (not an instance).
    """
    mgmt_port = 36452
    inf_port = 46253
    mgmt_url = f"http://localhost:{mgmt_port}/reader"
    gen_url = f"http://localhost:{inf_port}/generate"
    prompt = "What is 2 + 2?"

    reader_1 = {
        "model_name": "test",
        "device": "cpu",
        "consumer_group": "primary",
        "max_sequence_length": 512,
        "max_batch_size": 4,
        "tensor_parallel": 3,
    }
    reader_2 = reader_1.copy()
    reader_2["model_name"] = "test2"

    httpx_mock.add_response(
        method="POST", url=mgmt_url, json={"key": "value"}, status_code=201
    )
    httpx_mock.add_response(
        method="POST", url=gen_url, json={"text": "value"}, status_code=200
    )

    llm = takeoff_object(
        port=inf_port, mgmt_port=mgmt_port, models=[reader_1, reader_2]
    )
    output = llm(prompt)
    assert isinstance(output, str)

    requests = httpx_mock.get_requests()
    assert len(requests) == 3

    # The first two calls must hit the management API, spinning up reader 1
    # and then reader 2 with exactly the configuration that was supplied.
    for request, reader in zip(requests[:2], (reader_1, reader_2)):
        payload = json.loads(request.content)  # parse once per request
        for key, value in reader.items():
            assert payload[key] == value
        assert request.url == mgmt_url

    # The third call must be the inference request to the generate endpoint.
    assert requests[2].url == gen_url
    assert json.loads(requests[2].content)["text"] == prompt