langchain/tests/unit_tests/client/test_langchain.py
Zander Chase 8dcad0f272
Add Support for Flexible Input Format for LLM and Chat Model Runs (#4805)
Previously, the client expected a strict 'prompt' or 'messages' format
and wouldn't permit running a chat model on prompts or an LLM on
messages.

Since many datasets may want to specify custom `key: string` pairs, relax this
requirement.
Also, add support for running a chat model on raw prompts and an LLM on
chat messages through their respective fallbacks.
2023-05-17 14:24:17 +00:00

318 lines
9.8 KiB
Python

"""Test the LangChain+ client."""
import uuid
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
from unittest import mock
import pytest
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.chains.base import Chain
from langchain.client.langchain import (
InputFormatError,
LangChainPlusClient,
_get_link_stem,
_is_localhost,
)
from langchain.client.models import Dataset, Example
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
# Fixed creation timestamp shared by every Dataset/Example fixture below.
_CREATED_AT = datetime(year=2015, month=1, day=1)
# Tenant id used for every client and fixture constructed in this module.
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
@pytest.mark.parametrize(
    ("api_url", "expected_url"),
    [
        ("http://localhost:8000", "http://localhost"),
        ("http://www.example.com", "http://www.example.com"),
        (
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
        ),
        ("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
    ],
)
def test_link_split(api_url: str, expected_url: str) -> None:
    """Check that ``_get_link_stem`` reduces an API URL to its link stem.

    The cases cover a localhost URL (port dropped), plain and hosted domains
    (returned unchanged), and a URL with a path (path dropped).
    """
    stem = _get_link_stem(api_url)
    assert stem == expected_url
def test_is_localhost() -> None:
    """``_is_localhost`` recognizes loopback/any-address hosts but not others."""
    local_urls = (
        "http://localhost:8000",
        "http://127.0.0.1:8000",
        "http://0.0.0.0:8000",
    )
    for url in local_urls:
        assert _is_localhost(url)
    assert not _is_localhost("http://example.com:8000")
def test_validate_api_key_if_hosted() -> None:
    """A non-local API URL requires an API key; a localhost URL does not.

    ``_get_seeded_tenant_id`` is patched to return a fixed tenant id so the
    test does not depend on the real tenant-id lookup.
    """

    def fake_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
        return _TENANT_ID

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=fake_seeded_tenant_id
    ):
        # A hosted (non-local) URL without a key must be rejected.
        with pytest.raises(ValueError, match="API key must be provided"):
            LangChainPlusClient(api_url="http://www.example.com")

        # Localhost is allowed to run without a key.
        local_client = LangChainPlusClient(api_url="http://localhost:8000")

    assert local_client.api_url == "http://localhost:8000"
    assert local_client.api_key is None
def test_headers() -> None:
    """``_headers`` carries ``x-api-key`` only when an API key was supplied."""

    def fake_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
        return _TENANT_ID

    # Each case: (extra constructor kwargs, expected request headers).
    cases = [
        ({"api_key": "123"}, {"x-api-key": "123"}),
        ({}, {}),
    ]
    for extra_kwargs, expected_headers in cases:
        with mock.patch.object(
            LangChainPlusClient, "_get_seeded_tenant_id", new=fake_seeded_tenant_id
        ):
            client = LangChainPlusClient(
                api_url="http://localhost:8000", **extra_kwargs
            )
        assert client._headers == expected_headers
@mock.patch("langchain.client.langchain.requests.post")
def test_upload_csv(mock_post: mock.Mock) -> None:
    """``upload_csv`` returns a ``Dataset`` built from the POST response.

    ``requests.post`` is patched, so the asserted dataset fields come from the
    canned JSON payload below rather than from a real server.
    """
    dataset_id = str(uuid.uuid4())
    # Two example rows matching the CSV body uploaded below.
    examples = [
        Example(
            id=str(uuid.uuid4()),
            created_at=_CREATED_AT,
            inputs={"input": input_value},
            outputs={"output": output_value},
            dataset_id=dataset_id,
        )
        for input_value, output_value in (("1", "2"), ("3", "4"))
    ]

    response = mock.Mock()
    response.json.return_value = {
        "id": dataset_id,
        "name": "test.csv",
        "description": "Test dataset",
        "owner_id": "the owner",
        "created_at": _CREATED_AT,
        "examples": examples,
        "tenant_id": _TENANT_ID,
    }
    mock_post.return_value = response

    client = LangChainPlusClient(
        api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
    )
    upload = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
    dataset = client.upload_csv(
        upload, "Test dataset", input_keys=["input"], output_keys=["output"]
    )

    assert dataset.id == uuid.UUID(dataset_id)
    assert dataset.name == "test.csv"
    assert dataset.description == "Test dataset"
@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
    """``arun_on_dataset`` maps every example id to its repeated results.

    Dataset reads, example listing, per-example execution, and tracer session
    creation are all patched, so no real dataset, model, or tracer session is
    involved.
    """
    dataset = Dataset(
        id=uuid.uuid4(),
        name="test",
        description="Test dataset",
        owner_id="owner",
        created_at=_CREATED_AT,
        tenant_id=_TENANT_ID,
    )
    uuids = [
        "0c193153-2309-4704-9a47-17aee4fb25c8",
        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
    ]
    # Example at position i has input str(2*i + 1) and output str(2*i + 2),
    # i.e. "1"/"2", "3"/"4", ..., "9"/"10".
    examples = [
        Example(
            id=example_id,
            created_at=_CREATED_AT,
            inputs={"input": str(2 * position + 1)},
            outputs={"output": str(2 * position + 2)},
            dataset_id=str(uuid.uuid4()),
        )
        for position, example_id in enumerate(uuids)
    ]

    def fake_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
        return dataset

    def fake_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
        return examples

    async def fake_arun_llm_or_chain(
        example: Example,
        tracer: Any,
        llm_or_chain: Union[BaseLanguageModel, Chain],
        n_repetitions: int,
    ) -> List[Dict[str, Any]]:
        # One canned result dict per repetition, tagged with the example id.
        message = f"Result for example {example.id}"
        return [{"result": message} for _ in range(n_repetitions)]

    def fake_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
        return TracerSession(
            name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4()
        )

    with mock.patch.object(
        LangChainPlusClient, "read_dataset", new=fake_read_dataset
    ), mock.patch.object(
        LangChainPlusClient, "list_examples", new=fake_list_examples
    ), mock.patch.object(
        LangChainPlusClient, "_arun_llm_or_chain", new=fake_arun_llm_or_chain
    ), mock.patch.object(
        LangChainTracer, "ensure_session", new=fake_ensure_session
    ):
        monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
        client = LangChainPlusClient(
            api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
        )
        chain = mock.MagicMock()
        num_repetitions = 3
        results = await client.arun_on_dataset(
            dataset_name="test",
            llm_or_chain_factory=lambda: chain,
            concurrency_level=2,
            session_name="test_session",
            num_repetitions=num_repetitions,
        )

        expected: Dict[str, List[Dict[str, str]]] = {}
        for example_id in uuids:
            text = f"Result for example {uuid.UUID(example_id)}"
            expected[example_id] = [{"result": text} for _ in range(num_repetitions)]
        assert results == expected
# One serialized human chat message, reused by the message-style fixtures.
_EXAMPLE_MESSAGE = dict(
    data=dict(content="Foo", example=False, additional_kwargs={}),
    type="human",
)
# Message-style inputs that ``_get_messages`` is expected to accept: the
# conventional "messages" key (single list, empty list, list of lists) and a
# lone arbitrary key holding messages.
_VALID_MESSAGES = [
    dict(messages=[_EXAMPLE_MESSAGE], other_key="value"),
    dict(messages=[], other_key="value"),
    dict(
        messages=[[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]],
        other_key="value",
    ),
    dict(any_key=[_EXAMPLE_MESSAGE]),
    dict(any_key=[[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]),
]
# Prompt-style inputs that ``_get_prompts`` is expected to accept: the
# conventional "prompts"/"prompt" keys and a lone arbitrary key holding a
# string or a list of strings.
_VALID_PROMPTS = [
    dict(prompts=["foo", "bar", "baz"], other_key="value"),
    dict(prompt="foo", other_key=["bar", "baz"]),
    dict(some_key="foo"),
    dict(some_key=["foo", "bar"]),
]
@pytest.mark.parametrize(
    "inputs",
    _VALID_MESSAGES,
)
def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
    """``_get_messages`` accepts every well-formed message-style input.

    The test passes when the call returns without raising (for example,
    without an ``InputFormatError``).
    """
    # The former bare ``{"messages": []}`` expression was a no-op and has been
    # removed; the empty-messages case is covered by ``_VALID_MESSAGES``.
    LangChainPlusClient._get_messages(inputs)
@pytest.mark.parametrize("inputs", _VALID_PROMPTS)
def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
    """``_get_prompts`` accepts every well-formed prompt-style input."""
    get_prompts = LangChainPlusClient._get_prompts
    get_prompts(inputs)
@pytest.mark.parametrize(
    "inputs",
    [
        {"prompts": "foo"},  # bare string under the plural "prompts" key
        {"prompt": ["foo"]},  # list under the singular "prompt" key
        {"some_key": 3},  # lone key with a non-string value
        {"some_key": "foo", "other_key": "bar"},  # two unconventional keys
    ],
)
def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
    """``_get_prompts`` rejects malformed prompt inputs with InputFormatError."""
    get_prompts = LangChainPlusClient._get_prompts
    with pytest.raises(InputFormatError):
        get_prompts(inputs)
@pytest.mark.parametrize(
    "inputs",
    [
        # Messages under an unconventional key, alongside a second key.
        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
        # Mixed nesting: a list of messages next to a bare message.
        {
            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
            "other_key": "value",
        },
        # Prompt-style input with no messages at all.
        {"prompts": "foo"},
        # No keys at all.
        {},
    ],
)
def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
    """``_get_messages`` rejects malformed message inputs with InputFormatError."""
    get_messages = LangChainPlusClient._get_messages
    with pytest.raises(InputFormatError):
        get_messages(inputs)
@pytest.mark.parametrize("inputs", [*_VALID_PROMPTS, *_VALID_MESSAGES])
def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
    """A plain LLM runs without raising on both prompt- and message-style inputs."""
    fake_llm = FakeLLM()
    tracer = mock.MagicMock()
    LangChainPlusClient.run_llm(fake_llm, inputs, tracer)
@pytest.mark.parametrize("inputs", [*_VALID_MESSAGES, *_VALID_PROMPTS])
def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
    """A chat model runs without raising on both message- and prompt-style inputs."""
    chat_model = FakeChatModel()
    tracer = mock.MagicMock()
    LangChainPlusClient.run_llm(chat_model, inputs, tracer)