forked from Archives/langchain
Add query parsing unit tests (#3672)
parent
03c05b15f6
commit
b807a114e4
@ -0,0 +1,116 @@
|
||||
"""Test LLM-generated structured query parsing."""
|
||||
from typing import Any, cast
|
||||
|
||||
import lark
|
||||
import pytest
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
)
|
||||
from langchain.chains.query_constructor.parser import get_parser
|
||||
|
||||
DEFAULT_PARSER = get_parser()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")'))
|
||||
def test_parse_invalid_grammar(x: str) -> None:
|
||||
with pytest.raises((ValueError, lark.exceptions.UnexpectedToken)):
|
||||
DEFAULT_PARSER.parse(x)
|
||||
|
||||
|
||||
def test_parse_comparison() -> None:
|
||||
comp = 'gte("foo", 2)'
|
||||
expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
|
||||
for input in (
|
||||
comp,
|
||||
comp.replace('"', "'"),
|
||||
comp.replace(" ", ""),
|
||||
comp.replace(" ", " "),
|
||||
comp.replace("(", " ("),
|
||||
comp.replace(",", ", "),
|
||||
comp.replace("2", "2.0"),
|
||||
):
|
||||
actual = DEFAULT_PARSER.parse(input)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_parse_operation() -> None:
|
||||
op = 'and(eq("foo", "bar"), lt("baz", 1995.25))'
|
||||
eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar")
|
||||
lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25)
|
||||
expected = Operation(operator=Operator.AND, arguments=[eq, lt])
|
||||
for input in (
|
||||
op,
|
||||
op.replace('"', "'"),
|
||||
op.replace(" ", ""),
|
||||
op.replace(" ", " "),
|
||||
op.replace("(", " ("),
|
||||
op.replace(",", ", "),
|
||||
op.replace("25", "250"),
|
||||
):
|
||||
actual = DEFAULT_PARSER.parse(input)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_parse_nested_operation() -> None:
|
||||
op = 'and(or(eq("a", "b"), eq("a", "c"), eq("a", "d")), not(eq("z", "foo")))'
|
||||
eq1 = Comparison(comparator=Comparator.EQ, attribute="a", value="b")
|
||||
eq2 = Comparison(comparator=Comparator.EQ, attribute="a", value="c")
|
||||
eq3 = Comparison(comparator=Comparator.EQ, attribute="a", value="d")
|
||||
eq4 = Comparison(comparator=Comparator.EQ, attribute="z", value="foo")
|
||||
_not = Operation(operator=Operator.NOT, arguments=[eq4])
|
||||
_or = Operation(operator=Operator.OR, arguments=[eq1, eq2, eq3])
|
||||
expected = Operation(operator=Operator.AND, arguments=[_or, _not])
|
||||
actual = DEFAULT_PARSER.parse(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_parse_disallowed_comparator() -> None:
|
||||
parser = get_parser(allowed_comparators=[Comparator.EQ])
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse('gt("a", 2)')
|
||||
|
||||
|
||||
def test_parse_disallowed_operator() -> None:
|
||||
parser = get_parser(allowed_operators=[Operator.AND])
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse('not(gt("a", 2))')
|
||||
|
||||
|
||||
def _test_parse_value(x: Any) -> None:
|
||||
parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})')))
|
||||
actual = parsed.value
|
||||
assert actual == x
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", (-1, 0, 1_000_000))
|
||||
def test_parse_int_value(x: int) -> None:
|
||||
_test_parse_value(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", (-1.001, 0.00000002, 1_234_567.6543210))
|
||||
def test_parse_float_value(x: float) -> None:
|
||||
_test_parse_value(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", ([], [1, "b", "true"]))
|
||||
def test_parse_list_value(x: list) -> None:
|
||||
_test_parse_value(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'"))
|
||||
def test_parse_string_value(x: str) -> None:
|
||||
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
actual = parsed.value
|
||||
assert actual == x[1:-1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE"))
|
||||
def test_parse_bool_value(x: str) -> None:
|
||||
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
|
||||
actual = parsed.value
|
||||
expected = x.lower() == "true"
|
||||
assert actual == expected
|
Loading…
Reference in New Issue