langchain/tests/unit_tests/retrievers/self_query/test_myscale.py
Harrison Chase 9bf5b0defa
Harrison/myscale self query (#6376)
Co-authored-by: Fangrui Liu <fangruil@moqi.ai>
Co-authored-by: 刘 方瑞 <fangrui.liu@outlook.com>
Co-authored-by: Fangrui.Liu <fangrui.liu@ubc.ca>
2023-06-18 16:53:10 -07:00

45 lines
1.3 KiB
Python

from typing import Any, Tuple
import pytest
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
)
from langchain.retrievers.self_query.myscale import MyScaleTranslator
# One translator instance, constructed with its default arguments, shared by
# every test in this module.
DEFAULT_TRANSLATOR: MyScaleTranslator = MyScaleTranslator()
@pytest.mark.parametrize(
    "triplet",
    [
        # EQ was previously only covered indirectly via test_visit_operation.
        (Comparator.EQ, 2, "metadata.foo = 2"),
        # Non-LIKE comparators wrap string values in single quotes.
        (Comparator.EQ, "bar", "metadata.foo = 'bar'"),
        (Comparator.LT, 2, "metadata.foo < 2"),
        (Comparator.LTE, 2, "metadata.foo <= 2"),
        (Comparator.GT, 2, "metadata.foo > 2"),
        (Comparator.GTE, 2, "metadata.foo >= 2"),
        (Comparator.CONTAIN, 2, "has(metadata.foo,2)"),
        (Comparator.LIKE, "bar", "metadata.foo ILIKE '%bar%'"),
    ],
)
def test_visit_comparison(triplet: Tuple[Comparator, Any, str]) -> None:
    """Check that a single comparison on ``foo`` renders to the expected filter.

    ``triplet`` is ``(comparator, value, expected_filter)``; the attribute is
    always ``foo``, which the translator prefixes with ``metadata.``.
    """
    comparator, value, expected = triplet
    comp = Comparison(comparator=comparator, attribute="foo", value=value)
    actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
    assert expected == actual
def test_visit_operation() -> None:
    """Check that an AND of two comparisons joins their rendered filters."""
    less_than = Comparison(comparator=Comparator.LT, attribute="foo", value=2)
    equals = Comparison(comparator=Comparator.EQ, attribute="bar", value="baz")
    conjunction = Operation(operator=Operator.AND, arguments=[less_than, equals])

    rendered = DEFAULT_TRANSLATOR.visit_operation(conjunction)

    assert rendered == "metadata.foo < 2 AND metadata.bar = 'baz'"