diff --git a/langchain/chains/query_constructor/parser.py b/langchain/chains/query_constructor/parser.py index a0c7f229..4c560bfe 100644 --- a/langchain/chains/query_constructor/parser.py +++ b/langchain/chains/query_constructor/parser.py @@ -68,11 +68,14 @@ class QueryTransformer(Transformer): def program(self, *items: Any) -> tuple: return items - def func_call(self, func_name: Any, *args: Any) -> FilterDirective: + def func_call(self, func_name: Any, args: list) -> FilterDirective: func = self._match_func_name(str(func_name)) if isinstance(func, Comparator): - return Comparison(comparator=func, attribute=args[0][0], value=args[0][1]) - return Operation(operator=func, arguments=args[0]) + return Comparison(comparator=func, attribute=args[0], value=args[1]) + elif len(args) == 1 and func in (Operator.AND, Operator.OR): + return args[0] + else: + return Operation(operator=func, arguments=args) def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]: if func_name in set(Comparator): diff --git a/tests/unit_tests/chains/query_constructor/test_parser.py b/tests/unit_tests/chains/query_constructor/test_parser.py index f4d68224..4b8c48b3 100644 --- a/tests/unit_tests/chains/query_constructor/test_parser.py +++ b/tests/unit_tests/chains/query_constructor/test_parser.py @@ -114,3 +114,11 @@ def test_parse_bool_value(x: str) -> None: actual = parsed.value expected = x.lower() == "true" assert actual == expected + + +@pytest.mark.parametrize("op", ("and", "or")) +@pytest.mark.parametrize("arg", ('eq("foo", 2)', 'and(eq("foo", 2), lte("bar", 1.1))')) +def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None: + expected = DEFAULT_PARSER.parse(arg) + actual = DEFAULT_PARSER.parse(f"{op}({arg})") + assert expected == actual