mirror of https://github.com/hwchase17/langchain
parent
2d078c7767
commit
b2eb4ff0fc
@ -1,349 +0,0 @@
|
||||
"""Test the LangSmith evaluation helpers."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langsmith.client import Client
|
||||
from langsmith.schemas import Dataset, Example
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.transform import TransformChain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.smith.evaluation.runner_utils import (
|
||||
InputFormatError,
|
||||
_get_messages,
|
||||
_get_prompt,
|
||||
_run_llm,
|
||||
_run_llm_or_chain,
|
||||
_validate_example_inputs_for_chain,
|
||||
_validate_example_inputs_for_language_model,
|
||||
arun_on_dataset,
|
||||
)
|
||||
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
# Fixed values reused by the Dataset/Example fixtures built in this module.
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"

# One serialized human chat message.
_EXAMPLE_MESSAGE = {
    "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
    "type": "human",
}

# Example inputs that the message-based tests treat as valid.
_VALID_MESSAGES = [
    {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
    {"messages": [], "other_key": "value"},
    {
        "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE]],
        "other_key": "value",
    },
    {"any_key": [_EXAMPLE_MESSAGE]},
    {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE]]},
]

# Example inputs that the prompt-based tests treat as valid.
_VALID_PROMPTS = [
    {"prompts": ["foo"], "other_key": "value"},
    {"prompt": "foo", "other_key": ["bar", "baz"]},
    {"some_key": "foo"},
    {"some_key": ["foo"]},
]

# Example inputs that the prompt-based tests expect to be rejected.
# This must be a flat list of dicts: ``pytest.mark.parametrize`` creates one
# test case per element. It was previously a 1-tuple wrapping the list, which
# produced a single case whose ``inputs`` was the whole list, so none of the
# individual payloads below were actually exercised.
_INVALID_PROMPTS = [
    {"prompts": "foo"},
    {"prompt": ["foo"]},
    {"some_key": 3},
    {"some_key": "foo", "other_key": "bar"},
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("inputs", _VALID_MESSAGES)
def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
    """Check that ``_get_messages`` does not raise for each valid payload.

    Args:
        inputs: One entry of ``_VALID_MESSAGES``.
    """
    # A stray ``{"messages": []}`` expression statement used to sit here; it
    # had no effect and was removed.
    _get_messages(inputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("inputs", _VALID_PROMPTS)
def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
    """Check that ``_get_prompt`` does not raise for each valid payload.

    Args:
        inputs: One entry of ``_VALID_PROMPTS``.
    """
    _get_prompt(inputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("inputs", _VALID_PROMPTS)
def test__validate_example_inputs_for_language_model(inputs: Dict[str, Any]) -> None:
    """Check that validation does not raise for a valid prompt payload.

    Args:
        inputs: One entry of ``_VALID_PROMPTS``, exposed as ``example.inputs``.
    """
    # The validator is handed a mock example; only ``inputs`` is set on it.
    example = mock.MagicMock(inputs=inputs)
    _validate_example_inputs_for_language_model(example, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("inputs", _INVALID_PROMPTS)
def test__validate_example_inputs_for_language_model_invalid(
    inputs: Dict[str, Any],
) -> None:
    """Check that validation raises ``InputFormatError`` for invalid payloads.

    Args:
        inputs: One parametrized case drawn from ``_INVALID_PROMPTS``.
    """
    example = mock.MagicMock(inputs=inputs)
    with pytest.raises(InputFormatError):
        _validate_example_inputs_for_language_model(example, None)
|
||||
|
||||
|
||||
def test__validate_example_inputs_for_chain_single_input() -> None:
    """Check that a one-key example validates against a one-key chain.

    The example key (``"foo"``) differs from the chain's declared input key,
    and no input mapper is supplied; the call is expected not to raise.
    """
    example = mock.MagicMock(inputs={"foo": "bar"})
    chain = mock.MagicMock(input_keys=["def not foo"])
    _validate_example_inputs_for_chain(example, chain, None)
|
||||
|
||||
|
||||
def test__validate_example_inputs_for_chain_input_mapper() -> None:
    """Exercise chain input validation with three different input mappers.

    Two mappers are expected to trigger ``InputFormatError`` (a non-dict
    return value, and a dict missing one of the chain's input keys); the third
    returns every key the chain declares and is expected to validate cleanly.
    """
    example = mock.MagicMock(inputs={"foo": "bar", "baz": "qux"})
    chain = mock.MagicMock(input_keys=["not foo", "not baz", "not qux"])
    required_example_keys = {"foo", "baz"}

    def wrong_output_format(inputs: dict) -> str:
        # Returns a string rather than a dict.
        assert required_example_keys <= inputs.keys()
        return "hehe"

    def wrong_output_keys(inputs: dict) -> dict:
        # Returns a dict that lacks the chain's "not qux" key.
        assert required_example_keys <= inputs.keys()
        return {"not foo": "foo", "not baz": "baz"}

    def input_mapper(inputs: dict) -> dict:
        # Returns exactly the keys listed in ``chain.input_keys``.
        assert required_example_keys <= inputs.keys()
        mapped = {"not foo": inputs["foo"], "not baz": inputs["baz"]}
        mapped["not qux"] = "qux"
        return mapped

    with pytest.raises(InputFormatError, match="must be a dictionary"):
        _validate_example_inputs_for_chain(example, chain, wrong_output_format)

    with pytest.raises(InputFormatError, match="keys that match"):
        _validate_example_inputs_for_chain(example, chain, wrong_output_keys)

    _validate_example_inputs_for_chain(example, chain, input_mapper)
|
||||
|
||||
|
||||
def test__validate_example_inputs_for_chain_multi_io() -> None:
    """Check that a two-key example validates when its keys match the chain's."""
    example = mock.MagicMock(inputs={"foo": "bar", "baz": "qux"})
    chain = mock.MagicMock(input_keys=["foo", "baz"])
    _validate_example_inputs_for_chain(example, chain, None)
|
||||
|
||||
|
||||
def test__validate_example_inputs_for_chain_single_input_multi_expect() -> None:
    """Check that a one-key example is rejected by a chain declaring two keys.

    No input mapper is supplied; ``InputFormatError`` with the key-mismatch
    message is expected.
    """
    example = mock.MagicMock(inputs={"foo": "bar"})
    chain = mock.MagicMock(input_keys=["def not foo", "oh here is another"])
    expected_message = "Example inputs do not match chain input keys."
    with pytest.raises(InputFormatError, match=expected_message):
        _validate_example_inputs_for_chain(example, chain, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "inputs",
    _INVALID_PROMPTS,
)
def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
    """Check that ``_get_prompt`` raises ``InputFormatError`` for invalid input.

    Args:
        inputs: One parametrized case drawn from ``_INVALID_PROMPTS``.
    """
    with pytest.raises(InputFormatError):
        _get_prompt(inputs)
|
||||
|
||||
|
||||
def test_run_llm_or_chain_with_input_mapper() -> None:
    """Run ``_run_llm_or_chain`` with and without an input mapper.

    Three scenarios are checked against the same example:

    1. A chain plus a mapper that renames the example's key to the one the
       chain expects: the single result holds both the mapped input and the
       chain output.
    2. The same chain without a mapper: the single result contains ``"Error"``.
    3. A fake LLM plus a mapper that returns a prompt string: the single
       result is the string the fake LLM is configured to answer with.
    """
    example = Example(
        id=uuid.uuid4(),
        created_at=_CREATED_AT,
        inputs={"the wrong input": "1", "another key": "2"},
        outputs={"output": "2"},
        dataset_id=str(uuid.uuid4()),
    )

    def run_val(inputs: dict) -> dict:
        # The chain must only ever see the mapped key.
        assert "the right input" in inputs
        return {"output": "2"}

    def input_mapper(inputs: dict) -> dict:
        assert "the wrong input" in inputs
        return {"the right input": inputs["the wrong input"]}

    chain = TransformChain(
        input_variables=["the right input"],
        output_variables=["output"],
        transform=run_val,
    )

    def chain_factory() -> TransformChain:
        return chain

    # Scenario 1: chain with an input mapper.
    mapped_results = _run_llm_or_chain(
        example,
        chain_factory,
        n_repetitions=1,
        input_mapper=input_mapper,
    )
    assert len(mapped_results) == 1
    assert mapped_results[0] == {"output": "2", "the right input": "1"}

    # Scenario 2: same chain, no mapper.
    unmapped_results = _run_llm_or_chain(example, chain_factory, n_repetitions=1)
    assert len(unmapped_results) == 1
    assert "Error" in unmapped_results[0]

    # Scenario 3: an LLM with a mapper that yields a prompt string.
    def llm_input_mapper(inputs: dict) -> str:
        assert "the wrong input" in inputs
        return "the right input"

    fake_llm = FakeLLM(queries={"the right input": "somenumber"})
    llm_results = _run_llm_or_chain(
        example,
        fake_llm,
        n_repetitions=1,
        input_mapper=llm_input_mapper,
    )
    assert len(llm_results) == 1
    (llm_output,) = llm_results
    assert isinstance(llm_output, str)
    assert llm_output == "somenumber"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "inputs",
    [
        # Two keys, neither of them "messages".
        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
        # "messages" mixes a nested list with a bare message dict.
        {
            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
            "other_key": "value",
        },
        # A single key holding a plain string.
        {"prompts": "foo"},
        # Empty inputs.
        {},
    ],
)
def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
    """Check that ``_get_messages`` raises ``InputFormatError`` for bad input.

    Args:
        inputs: One of the malformed payloads listed in the decorator.
    """
    with pytest.raises(InputFormatError):
        _get_messages(inputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "inputs",
    _VALID_PROMPTS + _VALID_MESSAGES,
)
def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
    """Check that ``_run_llm`` runs a ``FakeLLM`` on every valid payload.

    Args:
        inputs: One entry of ``_VALID_PROMPTS`` or ``_VALID_MESSAGES``.
    """
    # The third argument is a throwaway mock.
    _run_llm(FakeLLM(), inputs, mock.MagicMock())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "inputs",
    _VALID_MESSAGES + _VALID_PROMPTS,
)
def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
    """Check that ``_run_llm`` runs a ``FakeChatModel`` on every valid payload.

    Args:
        inputs: One entry of ``_VALID_MESSAGES`` or ``_VALID_PROMPTS``.
    """
    # The third argument is a throwaway mock.
    _run_llm(FakeChatModel(), inputs, mock.MagicMock())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that ``arun_on_dataset`` returns one result list per example.

    ``Client.read_dataset``, ``Client.list_examples``, ``Client.create_project``
    and ``runner_utils._arun_llm_or_chain`` are replaced with local stubs. The
    test asserts that ``results["results"]`` maps each example id to the
    stubbed per-repetition outputs.
    """
    dataset = Dataset(
        id=uuid.uuid4(),
        name="test",
        description="Test dataset",
        owner_id="owner",
        created_at=_CREATED_AT,
        tenant_id=_TENANT_ID,
    )
    uuids = [
        "0c193153-2309-4704-9a47-17aee4fb25c8",
        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
    ]
    # Inputs are the odd numbers "1".."9"; each output is the following even
    # number "2".."10".
    examples = [
        Example(
            id=example_id,
            created_at=_CREATED_AT,
            inputs={"input": str(2 * index + 1)},
            outputs={"output": str(2 * index + 2)},
            dataset_id=str(uuid.uuid4()),
        )
        for index, example_id in enumerate(uuids)
    ]

    def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
        return dataset

    def mock_list_examples(*args: Any, **kwargs: Any) -> Iterator[Example]:
        return iter(examples)

    async def mock_arun_chain(
        example: Example,
        llm_or_chain: Union[BaseLanguageModel, Chain],
        n_repetitions: int,
        tags: Optional[List[str]] = None,
        callbacks: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        # One result dict per repetition, tagged with the example id.
        return [
            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
        ]

    def mock_create_project(*args: Any, **kwargs: Any) -> Any:
        project = mock.MagicMock()
        project.id = "123"
        return project

    # Build the patchers up front; they take effect only inside the ``with``.
    patch_read_dataset = mock.patch.object(
        Client, "read_dataset", new=mock_read_dataset
    )
    patch_list_examples = mock.patch.object(
        Client, "list_examples", new=mock_list_examples
    )
    patch_arun = mock.patch(
        "langchain.smith.evaluation.runner_utils._arun_llm_or_chain",
        new=mock_arun_chain,
    )
    patch_create_project = mock.patch.object(
        Client, "create_project", new=mock_create_project
    )

    with patch_read_dataset, patch_list_examples, patch_arun, patch_create_project:
        client = Client(api_url="http://localhost:1984", api_key="123")
        chain = mock.MagicMock()
        chain.input_keys = ["foothing"]
        num_repetitions = 3
        results = await arun_on_dataset(
            dataset_name="test",
            llm_or_chain_factory=lambda: chain,
            concurrency_level=2,
            project_name="test_project",
            num_repetitions=num_repetitions,
            client=client,
        )

        expected = {}
        for example_id in uuids:
            single_result = {"result": f"Result for example {uuid.UUID(example_id)}"}
            expected[example_id] = [dict(single_result) for _ in range(num_repetitions)]
        assert results["results"] == expected
|
Loading…
Reference in New Issue