mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
481d3855dc
- `llm(prompt)` -> `llm.invoke(prompt)` - `llm(prompt=prompt)` -> `llm.invoke(prompt)` (same with `messages=`) - `llm(prompt, callbacks=callbacks)` -> `llm.invoke(prompt, config={"callbacks": callbacks})` - `llm(prompt, **kwargs)` -> `llm.invoke(prompt, **kwargs)`
142 lines
4.7 KiB
Python
142 lines
4.7 KiB
Python
"""Test Titan Takeoff wrapper."""
|
|
import json
|
|
from typing import Any, Union
|
|
|
|
import pytest
|
|
|
|
from langchain_community.llms import TitanTakeoff, TitanTakeoffPro
|
|
|
|
|
|
@pytest.mark.requires("takeoff_client")
@pytest.mark.requires("pytest_httpx")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_call(
    httpx_mock: Any,
    streaming: bool,
    takeoff_object: Union[TitanTakeoff, TitanTakeoffPro],
) -> None:
    """Test valid call to Titan Takeoff.

    A mocked POST response is registered for the ``generate`` endpoint (or
    ``generate_stream`` when ``streaming`` is true) on port 2345. The LLM is
    then called both directly and through ``invoke``; each call must return a
    ``str`` and must send the prompt in the ``"text"`` field of the JSON body.
    When streaming, ``_stream`` is exercised as well.

    Args:
        httpx_mock: The ``pytest_httpx`` fixture used to register responses
            and to inspect the recorded requests.
        streaming: Whether the wrapper is created with ``streaming=True``.
        takeoff_object: The wrapper class under test.
    """
    from pytest_httpx import IteratorStream

    prompt = "What is 2 + 2?"
    port = 2345
    url = (
        f"http://localhost:{port}/generate_stream"
        if streaming
        else f"http://localhost:{port}/generate"
    )

    if streaming:
        httpx_mock.add_response(
            method="POST",
            url=url,
            stream=IteratorStream([b"data: ask someone else\n\n"]),
        )
    else:
        httpx_mock.add_response(
            method="POST",
            url=url,
            json={"text": "ask someone else"},
        )

    llm = takeoff_object(port=port, streaming=streaming)
    number_of_calls = 0
    for function_call in [llm, llm.invoke]:
        number_of_calls += 1
        output = function_call(prompt)
        assert isinstance(output, str)

        requests = httpx_mock.get_requests()
        assert len(requests) == number_of_calls
        # Inspect the request produced by *this* call (the newest one);
        # always looking at index 0 would leave later calls unverified.
        latest_request = requests[-1]
        assert latest_request.url == url
        assert json.loads(latest_request.content)["text"] == prompt

    if streaming:
        for chunk in llm._stream(prompt):
            assert isinstance(chunk.text, str)

        requests = httpx_mock.get_requests()
        # The _stream call adds exactly one request on top of the loop above.
        assert len(requests) == number_of_calls + 1
        latest_request = requests[-1]
        assert latest_request.url == url
        assert json.loads(latest_request.content)["text"] == prompt
|
|
|
|
|
|
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_bad_call(
    httpx_mock: Any,
    streaming: bool,
    takeoff_object: Union[TitanTakeoff, TitanTakeoffPro],
) -> None:
    """Test that an error response from Titan Takeoff raises TakeoffException.

    The mocked endpoint answers with HTTP 400, so ``invoke`` is expected to
    raise. Exactly one request must still have been sent, to the expected
    URL, carrying the prompt in the ``"text"`` field of its JSON body.

    Args:
        httpx_mock: The ``pytest_httpx`` fixture used to register the failing
            response and to inspect the recorded requests.
        streaming: Whether the wrapper is created with ``streaming=True``.
        takeoff_object: The wrapper class under test.
    """
    from takeoff_client import TakeoffException

    # No port is passed to the wrapper below; the mocked URL uses port 3000
    # (presumably the client's default port - confirm against the wrapper).
    if streaming:
        url = "http://localhost:3000/generate_stream"
    else:
        url = "http://localhost:3000/generate"
    httpx_mock.add_response(
        method="POST",
        url=url,
        json={"text": "bad things"},
        status_code=400,
    )

    llm = takeoff_object(streaming=streaming)
    with pytest.raises(TakeoffException):
        llm.invoke("What is 2 + 2?")

    recorded = httpx_mock.get_requests()
    assert len(recorded) == 1
    only_request = recorded[0]
    assert only_request.url == url
    assert json.loads(only_request.content)["text"] == "What is 2 + 2?"
|
|
|
|
|
|
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_model_initialisation(
    httpx_mock: Any,
    takeoff_object: Union[TitanTakeoff, TitanTakeoffPro],
) -> None:
    """Test that passing ``models`` creates readers before inference.

    Two reader configs are handed to the wrapper. The test expects three
    requests in order: one POST per reader to the management ``/reader``
    endpoint, followed by one POST to the inference ``/generate`` endpoint.

    Args:
        httpx_mock: The ``pytest_httpx`` fixture used to register responses
            and to inspect the recorded requests.
        takeoff_object: The wrapper class under test.
    """
    mgnt_port = 36452
    inf_port = 46253
    mgnt_url = f"http://localhost:{mgnt_port}/reader"
    gen_url = f"http://localhost:{inf_port}/generate"

    reader_1 = {
        "model_name": "test",
        "device": "cpu",
        "consumer_group": "primary",
        "max_sequence_length": 512,
        "max_batch_size": 4,
        "tensor_parallel": 3,
    }
    # Second reader differs from the first only in its model name.
    reader_2 = {**reader_1, "model_name": "test2"}

    httpx_mock.add_response(
        method="POST", url=mgnt_url, json={"key": "value"}, status_code=201
    )
    httpx_mock.add_response(
        method="POST", url=gen_url, json={"text": "value"}, status_code=200
    )

    llm = takeoff_object(
        port=inf_port, mgmt_port=mgnt_port, models=[reader_1, reader_2]
    )
    output = llm.invoke("What is 2 + 2?")
    assert isinstance(output, str)

    # Ensure the management API was called once per reader, then inference.
    recorded = httpx_mock.get_requests()
    assert len(recorded) == 3
    first_reader_request, second_reader_request, generate_request = recorded

    # The first two calls should spin up reader 1 and reader 2 respectively.
    expected_readers = (
        (first_reader_request, reader_1),
        (second_reader_request, reader_2),
    )
    for request, reader in expected_readers:
        payload = json.loads(request.content)
        for key, value in reader.items():
            assert payload[key] == value
        assert request.url == mgnt_url

    # Ensure the third call goes to the generate endpoint for inference.
    assert generate_request.url == gen_url
    assert json.loads(generate_request.content)["text"] == "What is 2 + 2?"
|