langchain/tests/unit_tests/chains/query_constructor/test_parser.py

117 lines
3.6 KiB
Python
Raw Normal View History

2023-04-27 20:42:12 +00:00
"""Test LLM-generated structured query parsing."""
from typing import Any, cast
import lark
import pytest
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
)
from langchain.chains.query_constructor.parser import get_parser
# Parser built once by calling get_parser() with no arguments; the tests in
# this module that do not need a restricted parser reuse this instance.
DEFAULT_PARSER = get_parser()
@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")'))
def test_parse_invalid_grammar(x: str) -> None:
    """Parsing each listed input must raise ValueError or lark's UnexpectedToken."""
    rejection_errors = (ValueError, lark.exceptions.UnexpectedToken)
    with pytest.raises(rejection_errors):
        DEFAULT_PARSER.parse(x)
def test_parse_comparison() -> None:
    """Check that formatting variants of one comparison parse identically.

    Each variant is derived from ``gte("foo", 2)`` by changing only quoting,
    whitespace, or the spelling of the number, and every one must parse to the
    same ``Comparison``.
    """
    comp = 'gte("foo", 2)'
    expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
    variants = (
        comp,
        comp.replace('"', "'"),  # single-quoted strings
        comp.replace(" ", ""),  # all spaces removed
        # Doubled spaces. Replacing a space with a single space would be a
        # no-op that merely repeats the first variant.
        comp.replace(" ", "  "),
        comp.replace("(", " ("),  # space before the opening parenthesis
        comp.replace(",", ", "),  # additional space after the comma
        comp.replace("2", "2.0"),  # same number spelled as a float
    )
    # `query` rather than `input`, so the builtin is not shadowed.
    for query in variants:
        actual = DEFAULT_PARSER.parse(query)
        # Name the failing variant; the bare comparison would not show it.
        assert expected == actual, f"unexpected parse result for {query!r}"
def test_parse_operation() -> None:
    """Check that formatting variants of one ``and`` operation parse identically.

    Each variant is derived from the same two-argument ``and(...)`` query by
    changing only quoting, whitespace, or the spelling of the float, and every
    one must parse to the same ``Operation``.
    """
    op = 'and(eq("foo", "bar"), lt("baz", 1995.25))'
    eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar")
    lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25)
    expected = Operation(operator=Operator.AND, arguments=[eq, lt])
    variants = (
        op,
        op.replace('"', "'"),  # single-quoted strings
        op.replace(" ", ""),  # all spaces removed
        # Doubled spaces. Replacing a space with a single space would be a
        # no-op that merely repeats the first variant.
        op.replace(" ", "  "),
        op.replace("(", " ("),  # space before each opening parenthesis
        op.replace(",", ", "),  # additional space after each comma
        op.replace("25", "250"),  # trailing zero: 1995.250 is the same float
    )
    # `query` rather than `input`, so the builtin is not shadowed.
    for query in variants:
        actual = DEFAULT_PARSER.parse(query)
        # Name the failing variant; the bare comparison would not show it.
        assert expected == actual, f"unexpected parse result for {query!r}"
def test_parse_nested_operation() -> None:
    """A query nesting ``or`` and ``not`` inside ``and`` parses to the same tree."""
    query = 'and(or(eq("a", "b"), eq("a", "c"), eq("a", "d")), not(eq("z", "foo")))'
    # The three alternatives for attribute "a", in the order they appear.
    a_alternatives = [
        Comparison(comparator=Comparator.EQ, attribute="a", value=candidate)
        for candidate in ("b", "c", "d")
    ]
    z_is_foo = Comparison(comparator=Comparator.EQ, attribute="z", value="foo")
    expected = Operation(
        operator=Operator.AND,
        arguments=[
            Operation(operator=Operator.OR, arguments=a_alternatives),
            Operation(operator=Operator.NOT, arguments=[z_is_foo]),
        ],
    )
    assert expected == DEFAULT_PARSER.parse(query)
def test_parse_disallowed_comparator() -> None:
    """A parser allowing only EQ must raise ValueError for a ``gt`` query."""
    eq_only_parser = get_parser(allowed_comparators=[Comparator.EQ])
    query = 'gt("a", 2)'
    with pytest.raises(ValueError):
        eq_only_parser.parse(query)
def test_parse_disallowed_operator() -> None:
    """A parser allowing only AND must raise ValueError for a ``not`` query."""
    and_only_parser = get_parser(allowed_operators=[Operator.AND])
    query = 'not(gt("a", 2))'
    with pytest.raises(ValueError):
        and_only_parser.parse(query)
def _test_parse_value(x: Any) -> None:
    """Format ``x`` into an ``eq("x", ...)`` query and check the parsed value equals it."""
    query = f'eq("x", {x})'
    comparison = cast(Comparison, DEFAULT_PARSER.parse(query))
    assert comparison.value == x
@pytest.mark.parametrize(
    "x",
    (-1, 0, 1_000_000),
)
def test_parse_int_value(x: int) -> None:
    """Negative, zero, and large integers go through the shared value check."""
    _test_parse_value(x)
@pytest.mark.parametrize(
    "x",
    (-1.001, 0.00000002, 1_234_567.6543210),
)
def test_parse_float_value(x: float) -> None:
    """Floats go through the shared value check.

    Python formats 0.00000002 as ``2e-08``, so that case reaches the parser
    in exponent notation.
    """
    _test_parse_value(x)
@pytest.mark.parametrize(
    "x",
    ([], [1, "b", "true"]),
)
def test_parse_list_value(x: list) -> None:
    """An empty list and a mixed int/str list go through the shared value check."""
    _test_parse_value(x)
@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'"))
def test_parse_string_value(x: str) -> None:
    """A quoted literal must parse to its text without the surrounding quotes."""
    comparison = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
    # Drop the first and last character, i.e. the quote marks around the literal.
    unquoted = x[1:-1]
    assert comparison.value == unquoted
@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE"))
def test_parse_bool_value(x: str) -> None:
    """Each listed casing of true/false must parse to the matching boolean."""
    comparison = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
    # The literal stands for True exactly when its lowercase form is "true".
    assert comparison.value == (x.lower() == "true")