Add query parsing unit tests (#3672)

fix_agent_callbacks
Davis Chase 1 year ago committed by GitHub
parent 03c05b15f6
commit b807a114e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,7 +33,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
parsed = parse_json_markdown(text, expected_keys) parsed = parse_json_markdown(text, expected_keys)
if len(parsed["query"]) == 0: if len(parsed["query"]) == 0:
parsed["query"] = " " parsed["query"] = " "
if parsed["filter"] == "NO_FILTER": if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
parsed["filter"] = None parsed["filter"] = None
else: else:
parsed["filter"] = self.ast_parse(parsed["filter"]) parsed["filter"] = self.ast_parse(parsed["filter"])

@ -20,19 +20,21 @@ GRAMMAR = """
func_call: CNAME "(" [args] ")" func_call: CNAME "(" [args] ")"
?value: SIGNED_NUMBER -> number ?value: SIGNED_INT -> int
| SIGNED_FLOAT -> float
| list | list
| string | string
| "false" -> false | ("false" | "False" | "FALSE") -> false
| "true" -> true | ("true" | "True" | "TRUE") -> true
args: expr ("," expr)* args: expr ("," expr)*
string: ESCAPED_STRING string: /'[^']*'/ | ESCAPED_STRING
list: "[" [args] "]" list: "[" [args] "]"
%import common.CNAME %import common.CNAME
%import common.SIGNED_NUMBER
%import common.ESCAPED_STRING %import common.ESCAPED_STRING
%import common.SIGNED_FLOAT
%import common.SIGNED_INT
%import common.WS %import common.WS
%ignore WS %ignore WS
""" """
@ -44,7 +46,7 @@ class QueryTransformer(Transformer):
self, self,
*args: Any, *args: Any,
allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]], allowed_operators: Optional[Sequence[Operator]] = None,
**kwargs: Any, **kwargs: Any,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -93,9 +95,14 @@ class QueryTransformer(Transformer):
return True return True
def list(self, item: Any) -> list: def list(self, item: Any) -> list:
if item is None:
return []
return list(item) return list(item)
def number(self, item: Any) -> float: def int(self, item: Any) -> int:
return int(item)
def float(self, item: Any) -> float:
return float(item) return float(item)
def string(self, item: Any) -> str: def string(self, item: Any) -> str:

@ -32,7 +32,7 @@ FULL_ANSWER = """\
{{ {{
"query": "teenager love", "query": "teenager love",
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \ "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\gg\"))" lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
}}""" }}"""
NO_FILTER_ANSWER = """\ NO_FILTER_ANSWER = """\

10
poetry.lock generated

@ -571,7 +571,7 @@ name = "azure-core"
version = "1.26.4" version = "1.26.4"
description = "Microsoft Azure Core Library for Python" description = "Microsoft Azure Core Library for Python"
category = "main" category = "main"
optional = false optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "azure-core-1.26.4.zip", hash = "sha256:075fe06b74c3007950dd93d49440c2f3430fd9b4a5a2756ec8c79454afc989c6"}, {file = "azure-core-1.26.4.zip", hash = "sha256:075fe06b74c3007950dd93d49440c2f3430fd9b4a5a2756ec8c79454afc989c6"},
@ -3488,7 +3488,7 @@ name = "lark"
version = "1.1.5" version = "1.1.5"
description = "a modern parsing library" description = "a modern parsing library"
category = "main" category = "main"
optional = true optional = false
python-versions = "*" python-versions = "*"
files = [ files = [
{file = "lark-1.1.5-py3-none-any.whl", hash = "sha256:8476f9903e93fbde4f6c327f74d79e9b4bd0ed9294c5dfa3164ab8c581b5de2a"}, {file = "lark-1.1.5-py3-none-any.whl", hash = "sha256:8476f9903e93fbde4f6c327f74d79e9b4bd0ed9294c5dfa3164ab8c581b5de2a"},
@ -9393,8 +9393,8 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"] cffi = ["cffi (>=1.11)"]
[extras] [extras]
all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "lancedb", "lark", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-cosmos", "azure-identity", "openai"] azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"]
cohere = ["cohere"] cohere = ["cohere"]
embeddings = ["sentence-transformers"] embeddings = ["sentence-transformers"]
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"] llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
@ -9404,4 +9404,4 @@ qdrant = ["qdrant-client"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "2979794d110362d851c1ef78075f6f394c62cbe97f7a331eeacd0d111e823b40" content-hash = "f7ff48dfce65630ea5c67287e91d923be83b9d0d9dd68639afcbc29f5f6f9c5f"

@ -99,6 +99,7 @@ pytest-watcher = "^0.2.6"
freezegun = "^1.2.2" freezegun = "^1.2.2"
responses = "^0.22.0" responses = "^0.22.0"
pytest-asyncio = "^0.20.3" pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
[tool.poetry.group.test_integration] [tool.poetry.group.test_integration]
optional = true optional = true

@ -0,0 +1,116 @@
"""Test LLM-generated structured query parsing."""
from typing import Any, cast
import lark
import pytest
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
)
from langchain.chains.query_constructor.parser import get_parser
DEFAULT_PARSER = get_parser()
@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")'))
def test_parse_invalid_grammar(x: str) -> None:
with pytest.raises((ValueError, lark.exceptions.UnexpectedToken)):
DEFAULT_PARSER.parse(x)
def test_parse_comparison() -> None:
comp = 'gte("foo", 2)'
expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
for input in (
comp,
comp.replace('"', "'"),
comp.replace(" ", ""),
comp.replace(" ", " "),
comp.replace("(", " ("),
comp.replace(",", ", "),
comp.replace("2", "2.0"),
):
actual = DEFAULT_PARSER.parse(input)
assert expected == actual
def test_parse_operation() -> None:
op = 'and(eq("foo", "bar"), lt("baz", 1995.25))'
eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar")
lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25)
expected = Operation(operator=Operator.AND, arguments=[eq, lt])
for input in (
op,
op.replace('"', "'"),
op.replace(" ", ""),
op.replace(" ", " "),
op.replace("(", " ("),
op.replace(",", ", "),
op.replace("25", "250"),
):
actual = DEFAULT_PARSER.parse(input)
assert expected == actual
def test_parse_nested_operation() -> None:
op = 'and(or(eq("a", "b"), eq("a", "c"), eq("a", "d")), not(eq("z", "foo")))'
eq1 = Comparison(comparator=Comparator.EQ, attribute="a", value="b")
eq2 = Comparison(comparator=Comparator.EQ, attribute="a", value="c")
eq3 = Comparison(comparator=Comparator.EQ, attribute="a", value="d")
eq4 = Comparison(comparator=Comparator.EQ, attribute="z", value="foo")
_not = Operation(operator=Operator.NOT, arguments=[eq4])
_or = Operation(operator=Operator.OR, arguments=[eq1, eq2, eq3])
expected = Operation(operator=Operator.AND, arguments=[_or, _not])
actual = DEFAULT_PARSER.parse(op)
assert expected == actual
def test_parse_disallowed_comparator() -> None:
parser = get_parser(allowed_comparators=[Comparator.EQ])
with pytest.raises(ValueError):
parser.parse('gt("a", 2)')
def test_parse_disallowed_operator() -> None:
parser = get_parser(allowed_operators=[Operator.AND])
with pytest.raises(ValueError):
parser.parse('not(gt("a", 2))')
def _test_parse_value(x: Any) -> None:
parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})')))
actual = parsed.value
assert actual == x
@pytest.mark.parametrize("x", (-1, 0, 1_000_000))
def test_parse_int_value(x: int) -> None:
_test_parse_value(x)
@pytest.mark.parametrize("x", (-1.001, 0.00000002, 1_234_567.6543210))
def test_parse_float_value(x: float) -> None:
_test_parse_value(x)
@pytest.mark.parametrize("x", ([], [1, "b", "true"]))
def test_parse_list_value(x: list) -> None:
_test_parse_value(x)
@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'"))
def test_parse_string_value(x: str) -> None:
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value
assert actual == x[1:-1]
@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE"))
def test_parse_bool_value(x: str) -> None:
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value
expected = x.lower() == "true"
assert actual == expected
Loading…
Cancel
Save