mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
479be3cc91
**Community: Unify Titan Takeoff Integrations and Adding Embedding Support** **Description:** Titan Takeoff no longer reflects either of the integrations in the community folder. The two integrations (TitanTakeoffPro and TitanTakeoff) were causing confusion with clients, so the code has been moved into one place and an alias created for backwards compatibility. Added the Takeoff Client Python package to do the bulk of the work with the requests, because this package is actively updated alongside new versions of Takeoff; as a result this integration will be far more robust and will not degrade as badly over time. **Issue:** Fixes bugs in the old Titan integrations and unifies the code, with added unit test coverage to avoid future problems. **Dependencies:** Added optional dependency takeoff-client; all imports still work without it (including the Titan Takeoff classes), but initialisation will fail if takeoff-client is not pip installed. **Twitter** @MeryemArik9 Thanks all :) --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
142 lines
4.7 KiB
Python
142 lines
4.7 KiB
Python
"""Test Titan Takeoff wrapper."""
|
|
import json
|
|
from typing import Any, Union
|
|
|
|
import pytest
|
|
|
|
from langchain_community.llms import TitanTakeoff, TitanTakeoffPro
|
|
|
|
|
|
@pytest.mark.requires("takeoff_client")
@pytest.mark.requires("pytest_httpx")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_call(
    httpx_mock: Any,
    streaming: bool,
    takeoff_object: Union[TitanTakeoff, TitanTakeoffPro],
) -> None:
    """Test valid call to Titan Takeoff.

    Each entry point (``llm(...)``, ``llm.invoke`` and, when streaming,
    ``llm._stream``) must send exactly one POST to the generate endpoint
    matching the streaming mode, carrying the prompt in the JSON ``text`` field.
    """
    from pytest_httpx import IteratorStream

    prompt = "What is 2 + 2?"
    port = 2345
    endpoint = "generate_stream" if streaming else "generate"
    url = f"http://localhost:{port}/{endpoint}"

    if streaming:
        httpx_mock.add_response(
            method="POST",
            url=url,
            stream=IteratorStream([b"data: ask someone else\n\n"]),
        )
    else:
        httpx_mock.add_response(
            method="POST",
            url=url,
            json={"text": "ask someone else"},
        )

    def assert_latest_request(expected_count: int) -> None:
        """Check the request count and validate the most recent request."""
        requests = httpx_mock.get_requests()
        assert len(requests) == expected_count
        # Inspect the newest request rather than requests[0]; otherwise every
        # call after the first would only re-check the first request.
        latest = requests[-1]
        assert latest.url == url
        assert json.loads(latest.content)["text"] == prompt

    llm = takeoff_object(port=port, streaming=streaming)
    entry_points = [llm, llm.invoke]
    for call_count, function_call in enumerate(entry_points, start=1):
        output = function_call(prompt)
        assert isinstance(output, str)
        assert_latest_request(call_count)

    if streaming:
        # Consume the stream fully before inspecting captured requests, in
        # case _stream only issues its request once iterated.
        chunks = list(llm._stream(prompt))
        for chunk in chunks:
            assert isinstance(chunk.text, str)
        assert_latest_request(len(entry_points) + 1)
|
|
|
|
|
|
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_bad_call(
    httpx_mock: Any,
    streaming: bool,
    takeoff_object: Union[TitanTakeoff, TitanTakeoffPro],
) -> None:
    """Test that an error response from Titan Takeoff raises TakeoffException.

    The mocked generate endpoint answers with HTTP 400. The call must raise,
    and exactly one request (carrying the prompt) must have been sent.
    """
    from takeoff_client import TakeoffException

    port = 3000
    endpoint = "generate_stream" if streaming else "generate"
    url = f"http://localhost:{port}/{endpoint}"
    httpx_mock.add_response(
        method="POST", url=url, json={"text": "bad things"}, status_code=400
    )

    # Pass the port explicitly so the mocked URL does not silently depend on
    # the wrapper's default port.
    llm = takeoff_object(port=port, streaming=streaming)
    with pytest.raises(TakeoffException):
        llm("What is 2 + 2?")

    requests = httpx_mock.get_requests()
    assert len(requests) == 1
    assert requests[0].url == url
    assert json.loads(requests[0].content)["text"] == "What is 2 + 2?"
|
|
|
|
|
|
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_model_initialisation(
    httpx_mock: Any,
    takeoff_object: Union[TitanTakeoff, TitanTakeoffPro],
) -> None:
    """Test that readers passed via ``models`` are created through the mgmt API.

    With two reader configs, the two configs must be POSTed, in order, to the
    management ``/reader`` endpoint before the prompt is POSTed to the
    inference ``/generate`` endpoint.
    """
    mgmt_port = 36452
    inf_port = 46253
    mgmt_url = f"http://localhost:{mgmt_port}/reader"
    gen_url = f"http://localhost:{inf_port}/generate"
    reader_1 = {
        "model_name": "test",
        "device": "cpu",
        "consumer_group": "primary",
        "max_sequence_length": 512,
        "max_batch_size": 4,
        "tensor_parallel": 3,
    }
    reader_2 = {**reader_1, "model_name": "test2"}

    httpx_mock.add_response(
        method="POST", url=mgmt_url, json={"key": "value"}, status_code=201
    )
    httpx_mock.add_response(
        method="POST", url=gen_url, json={"text": "value"}, status_code=200
    )

    llm = takeoff_object(
        port=inf_port, mgmt_port=mgmt_port, models=[reader_1, reader_2]
    )
    output = llm("What is 2 + 2?")

    assert isinstance(output, str)
    # Expect two management-API calls (one per reader) followed by one
    # inference call.
    requests = httpx_mock.get_requests()
    assert len(requests) == 3
    for request, expected_reader in zip(requests[:2], [reader_1, reader_2]):
        assert request.url == mgmt_url
        body = json.loads(request.content)  # parse once per request
        for key, value in expected_reader.items():
            assert body[key] == value
    # The third call must go to the generate endpoint for inference.
    assert requests[2].url == gen_url
    assert json.loads(requests[2].content)["text"] == "What is 2 + 2?"
|