From b807a114e43a2403022246ba8eb26186f7c248e6 Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Thu, 27 Apr 2023 13:42:12 -0700 Subject: [PATCH] Add query parsing unit tests (#3672) --- langchain/chains/query_constructor/base.py | 2 +- langchain/chains/query_constructor/parser.py | 21 ++-- langchain/chains/query_constructor/prompt.py | 2 +- poetry.lock | 10 +- pyproject.toml | 1 + .../chains/query_constructor/__init__.py | 0 .../chains/query_constructor/test_parser.py | 116 ++++++++++++++++++ 7 files changed, 138 insertions(+), 14 deletions(-) create mode 100644 tests/unit_tests/chains/query_constructor/__init__.py create mode 100644 tests/unit_tests/chains/query_constructor/test_parser.py diff --git a/langchain/chains/query_constructor/base.py b/langchain/chains/query_constructor/base.py index 1ff730fc..33d64ab7 100644 --- a/langchain/chains/query_constructor/base.py +++ b/langchain/chains/query_constructor/base.py @@ -33,7 +33,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): parsed = parse_json_markdown(text, expected_keys) if len(parsed["query"]) == 0: parsed["query"] = " " - if parsed["filter"] == "NO_FILTER": + if parsed["filter"] == "NO_FILTER" or not parsed["filter"]: parsed["filter"] = None else: parsed["filter"] = self.ast_parse(parsed["filter"]) diff --git a/langchain/chains/query_constructor/parser.py b/langchain/chains/query_constructor/parser.py index e59ec057..83672e6c 100644 --- a/langchain/chains/query_constructor/parser.py +++ b/langchain/chains/query_constructor/parser.py @@ -20,19 +20,21 @@ GRAMMAR = """ func_call: CNAME "(" [args] ")" - ?value: SIGNED_NUMBER -> number + ?value: SIGNED_INT -> int + | SIGNED_FLOAT -> float | list | string - | "false" -> false - | "true" -> true + | ("false" | "False" | "FALSE") -> false + | ("true" | "True" | "TRUE") -> true args: expr ("," expr)* - string: ESCAPED_STRING + string: /'[^']*'/ | ESCAPED_STRING list: "[" [args] "]" %import common.CNAME - %import common.SIGNED_NUMBER %import common.ESCAPED_STRING + %import common.SIGNED_FLOAT + %import common.SIGNED_INT %import common.WS %ignore WS """ @@ -44,7 +46,7 @@ class QueryTransformer(Transformer): self, *args: Any, allowed_comparators: Optional[Sequence[Comparator]] = None, - allowed_operators: Optional[Sequence[Operator]], + allowed_operators: Optional[Sequence[Operator]] = None, **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -93,9 +95,14 @@ class QueryTransformer(Transformer): return True def list(self, item: Any) -> list: + if item is None: + return [] return list(item) - def number(self, item: Any) -> float: + def int(self, item: Any) -> int: + return int(item) + + def float(self, item: Any) -> float: return float(item) def string(self, item: Any) -> str: diff --git a/langchain/chains/query_constructor/prompt.py b/langchain/chains/query_constructor/prompt.py index 89d8a60a..6282cc4d 100644 --- a/langchain/chains/query_constructor/prompt.py +++ b/langchain/chains/query_constructor/prompt.py @@ -32,7 +32,7 @@ FULL_ANSWER = """\ {{ "query": "teenager love", "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \ -lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\gg\"))" +lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))" }}""" NO_FILTER_ANSWER = """\ diff --git a/poetry.lock b/poetry.lock index 6db2e611..a5ae60c1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -571,7 +571,7 @@ name = "azure-core" version = "1.26.4" description = "Microsoft Azure Core Library for Python" category = "main" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "azure-core-1.26.4.zip", hash = "sha256:075fe06b74c3007950dd93d49440c2f3430fd9b4a5a2756ec8c79454afc989c6"}, @@ -3488,7 +3488,7 @@ name = "lark" version = "1.1.5" description = "a modern parsing library" category = "main" -optional = true +optional = false python-versions = "*" files = [ {file = "lark-1.1.5-py3-none-any.whl", hash = "sha256:8476f9903e93fbde4f6c327f74d79e9b4bd0ed9294c5dfa3164ab8c581b5de2a"}, @@ -9393,8 +9393,8 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] -azure = ["azure-cosmos", "azure-identity", "openai"] +all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "lancedb", "lark", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] +azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"] cohere = ["cohere"] embeddings = ["sentence-transformers"] llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"] @@ -9404,4 +9404,4 @@ qdrant = ["qdrant-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "2979794d110362d851c1ef78075f6f394c62cbe97f7a331eeacd0d111e823b40" +content-hash = "f7ff48dfce65630ea5c67287e91d923be83b9d0d9dd68639afcbc29f5f6f9c5f" diff --git a/pyproject.toml b/pyproject.toml index 3e7a1067..b277debe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,6 +99,7 @@ pytest-watcher = "^0.2.6" freezegun = "^1.2.2" responses = "^0.22.0" pytest-asyncio = "^0.20.3" +lark = "^1.1.5" [tool.poetry.group.test_integration] optional = true diff --git a/tests/unit_tests/chains/query_constructor/__init__.py b/tests/unit_tests/chains/query_constructor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/chains/query_constructor/test_parser.py b/tests/unit_tests/chains/query_constructor/test_parser.py new file mode 100644 index 00000000..f4d68224 --- /dev/null +++ b/tests/unit_tests/chains/query_constructor/test_parser.py @@ -0,0 +1,116 @@ +"""Test LLM-generated structured query parsing.""" +from typing import Any, cast + +import lark +import pytest + +from langchain.chains.query_constructor.ir import ( + Comparator, + Comparison, + Operation, + Operator, +) +from langchain.chains.query_constructor.parser import get_parser + +DEFAULT_PARSER = get_parser() + + +@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")')) +def test_parse_invalid_grammar(x: str) -> None: + with pytest.raises((ValueError, lark.exceptions.UnexpectedToken)): + DEFAULT_PARSER.parse(x) + + +def test_parse_comparison() -> None: + comp = 'gte("foo", 2)' + expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2) + for input in ( + comp, + comp.replace('"', "'"), + comp.replace(" ", ""), + comp.replace(" ", " "), + comp.replace("(", " ("), + comp.replace(",", ", "), + comp.replace("2", "2.0"), + ): + actual = DEFAULT_PARSER.parse(input) + assert expected == actual + + +def test_parse_operation() -> None: + op = 'and(eq("foo", "bar"), lt("baz", 1995.25))' + eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar") + lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25) + expected = Operation(operator=Operator.AND, arguments=[eq, lt]) + for input in ( + op, + op.replace('"', "'"), + op.replace(" ", ""), + op.replace(" ", " "), + op.replace("(", " ("), + op.replace(",", ", "), + op.replace("25", "250"), + ): + actual = DEFAULT_PARSER.parse(input) + assert expected == actual + + +def test_parse_nested_operation() -> None: + op = 'and(or(eq("a", "b"), eq("a", "c"), eq("a", "d")), not(eq("z", "foo")))' + eq1 = Comparison(comparator=Comparator.EQ, attribute="a", value="b") + eq2 = Comparison(comparator=Comparator.EQ, attribute="a", value="c") + eq3 = Comparison(comparator=Comparator.EQ, attribute="a", value="d") + eq4 = Comparison(comparator=Comparator.EQ, attribute="z", value="foo") + _not = Operation(operator=Operator.NOT, arguments=[eq4]) + _or = Operation(operator=Operator.OR, arguments=[eq1, eq2, eq3]) + expected = Operation(operator=Operator.AND, arguments=[_or, _not]) + actual = DEFAULT_PARSER.parse(op) + assert expected == actual + + +def test_parse_disallowed_comparator() -> None: + parser = get_parser(allowed_comparators=[Comparator.EQ]) + with pytest.raises(ValueError): + parser.parse('gt("a", 2)') + + +def test_parse_disallowed_operator() -> None: + parser = get_parser(allowed_operators=[Operator.AND]) + with pytest.raises(ValueError): + parser.parse('not(gt("a", 2))') + + +def _test_parse_value(x: Any) -> None: + parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})'))) + actual = parsed.value + assert actual == x + + +@pytest.mark.parametrize("x", (-1, 0, 1_000_000)) +def test_parse_int_value(x: int) -> None: + _test_parse_value(x) + + +@pytest.mark.parametrize("x", (-1.001, 0.00000002, 1_234_567.6543210)) +def test_parse_float_value(x: float) -> None: + _test_parse_value(x) + + +@pytest.mark.parametrize("x", ([], [1, "b", "true"])) +def test_parse_list_value(x: list) -> None: + _test_parse_value(x) + + +@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'")) +def test_parse_string_value(x: str) -> None: + parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})')) + actual = parsed.value + assert actual == x[1:-1] + + +@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE")) +def test_parse_bool_value(x: str) -> None: + parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})')) + actual = parsed.value + expected = x.lower() == "true" + assert actual == expected