Relax Validation in Eval (#8902)

Just check for missing keys
pull/8932/head
William FH 1 year ago committed by GitHub
parent 2d078c7767
commit b2eb4ff0fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -314,28 +314,28 @@ def _validate_example_inputs_for_chain(
"""Validate that the example inputs match the chain input keys."""
if input_mapper:
first_inputs = input_mapper(first_example.inputs)
missing_keys = set(chain.input_keys).difference(first_inputs)
if not isinstance(first_inputs, dict):
raise InputFormatError(
"When using an input_mapper to prepare dataset example"
" inputs for a chain, the mapped value must be a dictionary."
f"\nGot: {first_inputs} of type {type(first_inputs)}."
)
if not set(first_inputs.keys()) == set(chain.input_keys):
if missing_keys:
raise InputFormatError(
"When using an input_mapper to prepare dataset example inputs"
" for a chain mapped value must have keys that match the chain's"
" expected input keys."
"Missing keys after loading example using input_mapper."
f"\nExpected: {chain.input_keys}. Got: {first_inputs.keys()}"
)
else:
first_inputs = first_example.inputs
missing_keys = set(chain.input_keys).difference(first_inputs)
if len(first_inputs) == 1 and len(chain.input_keys) == 1:
# We can pass this through the run method.
# Refrain from calling to validate.
pass
elif not set(first_inputs.keys()) == set(chain.input_keys):
elif missing_keys:
raise InputFormatError(
"Example inputs do not match chain input keys."
"Example inputs missing expected chain input keys."
" Please provide an input_mapper to convert the example.inputs"
" to a compatible format for the chain you wish to evaluate."
f"Expected: {chain.input_keys}. "

@ -124,7 +124,7 @@ def test__validate_example_inputs_for_chain_input_mapper() -> None:
assert "baz" in inputs
return {"not foo": "foo", "not baz": "baz"}
with pytest.raises(InputFormatError, match="keys that match"):
with pytest.raises(InputFormatError, match="Missing keys after loading example"):
_validate_example_inputs_for_chain(mock_, chain, wrong_output_keys)
def input_mapper(inputs: dict) -> dict:
@ -148,9 +148,7 @@ def test__validate_example_inputs_for_chain_single_input_multi_expect() -> None:
mock_.inputs = {"foo": "bar"}
chain = mock.MagicMock()
chain.input_keys = ["def not foo", "oh here is another"]
with pytest.raises(
InputFormatError, match="Example inputs do not match chain input keys."
):
with pytest.raises(InputFormatError, match="Example inputs missing expected"):
_validate_example_inputs_for_chain(mock_, chain, None)

@ -1,349 +0,0 @@
"""Test the LangSmith evaluation helpers."""
import uuid
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional, Union
from unittest import mock
import pytest
from langsmith.client import Client
from langsmith.schemas import Dataset, Example
from langchain.chains.base import Chain
from langchain.chains.transform import TransformChain
from langchain.schema.language_model import BaseLanguageModel
from langchain.smith.evaluation.runner_utils import (
InputFormatError,
_get_messages,
_get_prompt,
_run_llm,
_run_llm_or_chain,
_validate_example_inputs_for_chain,
_validate_example_inputs_for_language_model,
arun_on_dataset,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
# Fixed timestamp and tenant id shared by the Dataset/Example fixtures below.
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"

# A single serialized human message, reused to build the message fixtures.
_EXAMPLE_MESSAGE = {
    "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
    "type": "human",
}

# Input payloads the message-based tests expect to be accepted.
_VALID_MESSAGES = [
    {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
    {"messages": [], "other_key": "value"},
    {
        "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE]],
        "other_key": "value",
    },
    {"any_key": [_EXAMPLE_MESSAGE]},
    {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE]]},
]

# Input payloads the prompt-based tests expect to be accepted.
_VALID_PROMPTS = [
    {"prompts": ["foo"], "other_key": "value"},
    {"prompt": "foo", "other_key": ["bar", "baz"]},
    {"some_key": "foo"},
    {"some_key": ["foo"]},
]

# Input payloads the prompt-based tests expect to raise InputFormatError.
# This must be a flat list of dicts: it was previously wrapped in a 1-tuple
# (a stray trailing comma after the list), so ``parametrize`` produced one
# case whose ``inputs`` was the entire list rather than one case per dict.
_INVALID_PROMPTS = [
    {"prompts": "foo"},
    {"prompt": ["foo"]},
    {"some_key": 3},
    {"some_key": "foo", "other_key": "bar"},
]
@pytest.mark.parametrize("inputs", _VALID_MESSAGES)
def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
    """Check that ``_get_messages`` accepts each entry of ``_VALID_MESSAGES``.

    There is no assertion on the return value; the test passes when the
    call completes without raising.
    """
    # A stray ``{"messages": []}`` expression statement used to sit here; it
    # built a dict and discarded it, so it has been removed.
    _get_messages(inputs)
@pytest.mark.parametrize("inputs", _VALID_PROMPTS)
def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
    """Run ``_get_prompt`` on each valid prompt payload.

    Passing means the call returned without raising; the result itself is
    not inspected.
    """
    _get_prompt(inputs)
@pytest.mark.parametrize("inputs", _VALID_PROMPTS)
def test__validate_example_inputs_for_language_model(inputs: Dict[str, Any]) -> None:
    """Validate each valid prompt payload with no input mapper.

    The payload is exposed through the ``inputs`` attribute of a mock
    example; the test passes when validation does not raise.
    """
    # MagicMock keyword arguments are set as attributes on the new mock.
    example = mock.MagicMock(inputs=inputs)
    _validate_example_inputs_for_language_model(example, None)
@pytest.mark.parametrize("inputs", _INVALID_PROMPTS)
def test__validate_example_inputs_for_language_model_invalid(
    inputs: Dict[str, Any]
) -> None:
    """Expect ``InputFormatError`` when validating invalid prompt payloads.

    The payload is exposed through the ``inputs`` attribute of a mock
    example and validated without an input mapper.
    """
    # MagicMock keyword arguments are set as attributes on the new mock.
    example = mock.MagicMock(inputs=inputs)
    with pytest.raises(InputFormatError):
        _validate_example_inputs_for_language_model(example, None)
def test__validate_example_inputs_for_chain_single_input() -> None:
    """Validate one example input against a chain with one input key.

    The example key ("foo") differs from the chain key ("def not foo");
    the test passes when validation does not raise.
    """
    # MagicMock keyword arguments are set as attributes on the new mock.
    example = mock.MagicMock(inputs={"foo": "bar"})
    single_key_chain = mock.MagicMock(input_keys=["def not foo"])
    _validate_example_inputs_for_chain(example, single_key_chain, None)
def test__validate_example_inputs_for_chain_input_mapper() -> None:
    """Exercise chain input validation when an ``input_mapper`` is supplied.

    Three mappers are tried against a chain expecting the keys
    "not foo", "not baz" and "not qux":

    * one returning a non-dict, expected to raise ``InputFormatError``
      mentioning "must be a dictionary";
    * one returning a dict that lacks "not qux", expected to raise
      ``InputFormatError`` with the missing-keys message;
    * one returning all three keys, expected to validate without raising.
    """
    mock_ = mock.MagicMock()
    mock_.inputs = {"foo": "bar", "baz": "qux"}
    chain = mock.MagicMock()
    chain.input_keys = ["not foo", "not baz", "not qux"]

    def wrong_output_format(inputs: dict) -> str:
        assert "foo" in inputs
        assert "baz" in inputs
        return "hehe"

    with pytest.raises(InputFormatError, match="must be a dictionary"):
        _validate_example_inputs_for_chain(mock_, chain, wrong_output_format)

    def wrong_output_keys(inputs: dict) -> dict:
        assert "foo" in inputs
        assert "baz" in inputs
        return {"not foo": "foo", "not baz": "baz"}

    # The validator now reports only *missing* keys; its message reads
    # "Missing keys after loading example using input_mapper." rather than
    # the former "...keys that match..." wording, so match on the new text.
    with pytest.raises(InputFormatError, match="Missing keys after loading example"):
        _validate_example_inputs_for_chain(mock_, chain, wrong_output_keys)

    def input_mapper(inputs: dict) -> dict:
        assert "foo" in inputs
        assert "baz" in inputs
        return {"not foo": inputs["foo"], "not baz": inputs["baz"], "not qux": "qux"}

    _validate_example_inputs_for_chain(mock_, chain, input_mapper)
def test__validate_example_inputs_for_chain_multi_io() -> None:
    """Validate a two-key example against a chain expecting the same two keys.

    No input mapper is used; the test passes when validation does not raise.
    """
    # MagicMock keyword arguments are set as attributes on the new mock.
    example = mock.MagicMock(inputs={"foo": "bar", "baz": "qux"})
    two_key_chain = mock.MagicMock(input_keys=["foo", "baz"])
    _validate_example_inputs_for_chain(example, two_key_chain, None)
def test__validate_example_inputs_for_chain_single_input_multi_expect() -> None:
    """Expect a missing-keys error for one input against a two-key chain.

    The example provides only "foo" while the chain expects two other keys
    and no input mapper is given, so validation must raise
    ``InputFormatError``.
    """
    mock_ = mock.MagicMock()
    mock_.inputs = {"foo": "bar"}
    chain = mock.MagicMock()
    chain.input_keys = ["def not foo", "oh here is another"]
    # The validator message now begins "Example inputs missing expected chain
    # input keys." instead of "Example inputs do not match chain input keys.",
    # so match on the new wording.
    with pytest.raises(InputFormatError, match="Example inputs missing expected"):
        _validate_example_inputs_for_chain(mock_, chain, None)
@pytest.mark.parametrize(
    "inputs",
    _INVALID_PROMPTS,
)
def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
    """Expect ``_get_prompt`` to raise ``InputFormatError`` for invalid payloads."""
    with pytest.raises(InputFormatError):
        _get_prompt(inputs)
def test_run_llm_or_chain_with_input_mapper() -> None:
    """Run ``_run_llm_or_chain`` with and without an input mapper.

    A ``TransformChain`` expecting "the right input" is run against an
    example whose keys do not include it:

    * with a mapper that renames the key, one result with the chain output
      is expected;
    * without a mapper, one result containing an "Error" entry is expected.

    Finally a ``FakeLLM`` is run with a mapper that returns a prompt string,
    and the canned completion is expected back.
    """
    example = Example(
        id=uuid.uuid4(),
        created_at=_CREATED_AT,
        inputs={"the wrong input": "1", "another key": "2"},
        outputs={"output": "2"},
        dataset_id=str(uuid.uuid4()),
    )

    def run_val(inputs: dict) -> dict:
        assert "the right input" in inputs
        return {"output": "2"}

    def input_mapper(inputs: dict) -> dict:
        assert "the wrong input" in inputs
        return {"the right input": inputs["the wrong input"]}

    mock_chain = TransformChain(
        input_variables=["the right input"],
        output_variables=["output"],
        transform=run_val,
    )

    def chain_factory() -> TransformChain:
        # Both chain runs below hand out the same chain instance.
        return mock_chain

    # Chain run with the mapper supplying the expected key.
    mapped_results = _run_llm_or_chain(
        example, chain_factory, n_repetitions=1, input_mapper=input_mapper
    )
    assert len(mapped_results) == 1
    assert mapped_results[0] == {"output": "2", "the right input": "1"}

    # Chain run without a mapper: the result is expected to carry an error.
    unmapped_results = _run_llm_or_chain(
        example,
        chain_factory,
        n_repetitions=1,
    )
    assert len(unmapped_results) == 1
    assert "Error" in unmapped_results[0]

    # Try with LLM
    def llm_input_mapper(inputs: dict) -> str:
        assert "the wrong input" in inputs
        return "the right input"

    mock_llm = FakeLLM(queries={"the right input": "somenumber"})
    llm_results = _run_llm_or_chain(
        example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
    )
    assert len(llm_results) == 1
    completion = llm_results[0]
    assert isinstance(completion, str)
    assert completion == "somenumber"
@pytest.mark.parametrize(
    "inputs",
    [
        # Two keys, neither of them "messages".
        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
        # "messages" mixing a nested list with a bare message dict.
        {
            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
            "other_key": "value",
        },
        # Single key holding a plain string.
        {"prompts": "foo"},
        # Empty payload.
        {},
    ],
)
def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
    """Expect ``_get_messages`` to raise ``InputFormatError`` for each payload."""
    with pytest.raises(InputFormatError):
        _get_messages(inputs)
@pytest.mark.parametrize(
    "inputs",
    _VALID_PROMPTS + _VALID_MESSAGES,
)
def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
    """Run ``_run_llm`` with a ``FakeLLM`` over every valid prompt/message payload.

    A ``MagicMock`` is passed as the third argument; the test passes when
    the call does not raise.
    """
    fake_llm = FakeLLM()
    third_arg = mock.MagicMock()
    _run_llm(fake_llm, inputs, third_arg)
@pytest.mark.parametrize(
    "inputs",
    _VALID_MESSAGES + _VALID_PROMPTS,
)
def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
    """Run ``_run_llm`` with a ``FakeChatModel`` over every valid payload.

    A ``MagicMock`` is passed as the third argument; the test passes when
    the call does not raise.
    """
    fake_chat_model = FakeChatModel()
    third_arg = mock.MagicMock()
    _run_llm(fake_chat_model, inputs, third_arg)
@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
    """Run ``arun_on_dataset`` against a patched client and check the results.

    ``Client.read_dataset``, ``Client.list_examples``,
    ``Client.create_project`` and the module-level ``_arun_llm_or_chain``
    are replaced with local stand-ins. The returned ``results`` mapping is
    expected to contain, for every example id, ``num_repetitions`` copies
    of the stand-in's result dict.
    """
    dataset = Dataset(
        id=uuid.uuid4(),
        name="test",
        description="Test dataset",
        owner_id="owner",
        created_at=_CREATED_AT,
        tenant_id=_TENANT_ID,
    )
    uuids = [
        "0c193153-2309-4704-9a47-17aee4fb25c8",
        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
    ]
    # Example k (0-based) maps input str(2k + 1) to output str(2k + 2),
    # i.e. "1"->"2", "3"->"4", ..., "9"->"10".
    examples = [
        Example(
            id=example_id,
            created_at=_CREATED_AT,
            inputs={"input": str(2 * position + 1)},
            outputs={"output": str(2 * position + 2)},
            dataset_id=str(uuid.uuid4()),
        )
        for position, example_id in enumerate(uuids)
    ]

    def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
        return dataset

    def mock_list_examples(*args: Any, **kwargs: Any) -> Iterator[Example]:
        return iter(examples)

    async def mock_arun_chain(
        example: Example,
        llm_or_chain: Union[BaseLanguageModel, Chain],
        n_repetitions: int,
        tags: Optional[List[str]] = None,
        callbacks: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        return [
            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
        ]

    def mock_create_project(*args: Any, **kwargs: Any) -> Any:
        proj = mock.MagicMock()
        proj.id = "123"
        return proj

    # Build the patchers up front; nothing is patched until the ``with`` below.
    patch_read_dataset = mock.patch.object(
        Client, "read_dataset", new=mock_read_dataset
    )
    patch_list_examples = mock.patch.object(
        Client, "list_examples", new=mock_list_examples
    )
    patch_arun = mock.patch(
        "langchain.smith.evaluation.runner_utils._arun_llm_or_chain",
        new=mock_arun_chain,
    )
    patch_create_project = mock.patch.object(
        Client, "create_project", new=mock_create_project
    )

    with patch_read_dataset, patch_list_examples, patch_arun, patch_create_project:
        client = Client(api_url="http://localhost:1984", api_key="123")
        chain = mock.MagicMock()
        chain.input_keys = ["foothing"]
        num_repetitions = 3
        results = await arun_on_dataset(
            dataset_name="test",
            llm_or_chain_factory=lambda: chain,
            concurrency_level=2,
            project_name="test_project",
            num_repetitions=num_repetitions,
            client=client,
        )

        expected = {}
        for uuid_ in uuids:
            expected[uuid_] = [
                {"result": f"Result for example {uuid.UUID(uuid_)}"}
                for _ in range(num_repetitions)
            ]
        assert results["results"] == expected
Loading…
Cancel
Save