Add query parsing unit tests (#3672)

fix_agent_callbacks
Davis Chase 1 year ago committed by GitHub
parent 03c05b15f6
commit b807a114e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,7 +33,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
parsed = parse_json_markdown(text, expected_keys)
if len(parsed["query"]) == 0:
parsed["query"] = " "
if parsed["filter"] == "NO_FILTER":
if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
parsed["filter"] = None
else:
parsed["filter"] = self.ast_parse(parsed["filter"])

@ -20,19 +20,21 @@ GRAMMAR = """
func_call: CNAME "(" [args] ")"
?value: SIGNED_NUMBER -> number
?value: SIGNED_INT -> int
| SIGNED_FLOAT -> float
| list
| string
| "false" -> false
| "true" -> true
| ("false" | "False" | "FALSE") -> false
| ("true" | "True" | "TRUE") -> true
args: expr ("," expr)*
string: ESCAPED_STRING
string: /'[^']*'/ | ESCAPED_STRING
list: "[" [args] "]"
%import common.CNAME
%import common.SIGNED_NUMBER
%import common.ESCAPED_STRING
%import common.SIGNED_FLOAT
%import common.SIGNED_INT
%import common.WS
%ignore WS
"""
@ -44,7 +46,7 @@ class QueryTransformer(Transformer):
self,
*args: Any,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]],
allowed_operators: Optional[Sequence[Operator]] = None,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
@ -93,9 +95,14 @@ class QueryTransformer(Transformer):
return True
def list(self, item: Any) -> list:
if item is None:
return []
return list(item)
def number(self, item: Any) -> float:
def int(self, item: Any) -> int:
return int(item)
def float(self, item: Any) -> float:
return float(item)
def string(self, item: Any) -> str:

@ -32,7 +32,7 @@ FULL_ANSWER = """\
{{
"query": "teenager love",
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\gg\"))"
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
}}"""
NO_FILTER_ANSWER = """\

10
poetry.lock generated

@ -571,7 +571,7 @@ name = "azure-core"
version = "1.26.4"
description = "Microsoft Azure Core Library for Python"
category = "main"
optional = false
optional = true
python-versions = ">=3.7"
files = [
{file = "azure-core-1.26.4.zip", hash = "sha256:075fe06b74c3007950dd93d49440c2f3430fd9b4a5a2756ec8c79454afc989c6"},
@ -3488,7 +3488,7 @@ name = "lark"
version = "1.1.5"
description = "a modern parsing library"
category = "main"
optional = true
optional = false
python-versions = "*"
files = [
{file = "lark-1.1.5-py3-none-any.whl", hash = "sha256:8476f9903e93fbde4f6c327f74d79e9b4bd0ed9294c5dfa3164ab8c581b5de2a"},
@ -9393,8 +9393,8 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-cosmos", "azure-identity", "openai"]
all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "lancedb", "lark", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"]
cohere = ["cohere"]
embeddings = ["sentence-transformers"]
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
@ -9404,4 +9404,4 @@ qdrant = ["qdrant-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "2979794d110362d851c1ef78075f6f394c62cbe97f7a331eeacd0d111e823b40"
content-hash = "f7ff48dfce65630ea5c67287e91d923be83b9d0d9dd68639afcbc29f5f6f9c5f"

@ -99,6 +99,7 @@ pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
[tool.poetry.group.test_integration]
optional = true

@ -0,0 +1,116 @@
"""Test LLM-generated structured query parsing."""
from typing import Any, cast
import lark
import pytest
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
)
from langchain.chains.query_constructor.parser import get_parser
DEFAULT_PARSER = get_parser()
@pytest.mark.parametrize("x", ("", "foo", 'foo("bar", "baz")'))
def test_parse_invalid_grammar(x: str) -> None:
with pytest.raises((ValueError, lark.exceptions.UnexpectedToken)):
DEFAULT_PARSER.parse(x)
def test_parse_comparison() -> None:
comp = 'gte("foo", 2)'
expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
for input in (
comp,
comp.replace('"', "'"),
comp.replace(" ", ""),
comp.replace(" ", " "),
comp.replace("(", " ("),
comp.replace(",", ", "),
comp.replace("2", "2.0"),
):
actual = DEFAULT_PARSER.parse(input)
assert expected == actual
def test_parse_operation() -> None:
op = 'and(eq("foo", "bar"), lt("baz", 1995.25))'
eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar")
lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25)
expected = Operation(operator=Operator.AND, arguments=[eq, lt])
for input in (
op,
op.replace('"', "'"),
op.replace(" ", ""),
op.replace(" ", " "),
op.replace("(", " ("),
op.replace(",", ", "),
op.replace("25", "250"),
):
actual = DEFAULT_PARSER.parse(input)
assert expected == actual
def test_parse_nested_operation() -> None:
op = 'and(or(eq("a", "b"), eq("a", "c"), eq("a", "d")), not(eq("z", "foo")))'
eq1 = Comparison(comparator=Comparator.EQ, attribute="a", value="b")
eq2 = Comparison(comparator=Comparator.EQ, attribute="a", value="c")
eq3 = Comparison(comparator=Comparator.EQ, attribute="a", value="d")
eq4 = Comparison(comparator=Comparator.EQ, attribute="z", value="foo")
_not = Operation(operator=Operator.NOT, arguments=[eq4])
_or = Operation(operator=Operator.OR, arguments=[eq1, eq2, eq3])
expected = Operation(operator=Operator.AND, arguments=[_or, _not])
actual = DEFAULT_PARSER.parse(op)
assert expected == actual
def test_parse_disallowed_comparator() -> None:
parser = get_parser(allowed_comparators=[Comparator.EQ])
with pytest.raises(ValueError):
parser.parse('gt("a", 2)')
def test_parse_disallowed_operator() -> None:
parser = get_parser(allowed_operators=[Operator.AND])
with pytest.raises(ValueError):
parser.parse('not(gt("a", 2))')
def _test_parse_value(x: Any) -> None:
parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})')))
actual = parsed.value
assert actual == x
@pytest.mark.parametrize("x", (-1, 0, 1_000_000))
def test_parse_int_value(x: int) -> None:
_test_parse_value(x)
@pytest.mark.parametrize("x", (-1.001, 0.00000002, 1_234_567.6543210))
def test_parse_float_value(x: float) -> None:
_test_parse_value(x)
@pytest.mark.parametrize("x", ([], [1, "b", "true"]))
def test_parse_list_value(x: list) -> None:
_test_parse_value(x)
@pytest.mark.parametrize("x", ('""', '" "', '"foo"', "'foo'"))
def test_parse_string_value(x: str) -> None:
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value
assert actual == x[1:-1]
@pytest.mark.parametrize("x", ("true", "True", "TRUE", "false", "False", "FALSE"))
def test_parse_bool_value(x: str) -> None:
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value
expected = x.lower() == "true"
assert actual == expected
Loading…
Cancel
Save