langchain[patch]: support more comparators in Milvus self-querying retriever (#16076)

- **Description:** Support IN and LIKE comparators in Milvus
self-querying retriever, based on [Boolean Expression
Rules](https://milvus.io/docs/boolean.md)
  - **Issue:** No
  - **Dependencies:** No
  - **Twitter handle:** No

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
pull/16062/head
ChengZi 5 months ago committed by GitHub
parent 9c2f1f07a0
commit 8597484195
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -16,28 +16,36 @@ COMPARATOR_TO_BER = {
Comparator.GTE: ">=",
Comparator.LT: "<",
Comparator.LTE: "<=",
Comparator.IN: "in",
Comparator.LIKE: "like",
}
UNARY_OPERATORS = [Operator.NOT]
def process_value(value: Union[int, float, str]) -> str:
def process_value(value: Union[int, float, str], comparator: Comparator) -> str:
"""Convert a value to a string and add double quotes if it is a string.
It required for comparators involving strings.
Args:
value: The value to convert.
comparator: The comparator.
Returns:
The converted value as a string.
"""
#
if isinstance(value, str):
# If the value is already a string, add double quotes
return f'"{value}"'
if comparator is Comparator.LIKE:
# If the comparator is LIKE, add a percent sign after it for prefix matching
# and add double quotes
return f'"{value}%"'
else:
# If the value is already a string, add double quotes
return f'"{value}"'
else:
# If the valueis not a string, convert it to a string without double quotes
# If the value is not a string, convert it to a string without double quotes
return str(value)
@ -54,6 +62,8 @@ class MilvusTranslator(Visitor):
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.IN,
Comparator.LIKE,
]
def _format_func(self, func: Union[Operator, Comparator]) -> str:
@ -78,7 +88,7 @@ class MilvusTranslator(Visitor):
def visit_comparison(self, comparison: Comparison) -> str:
comparator = self._format_func(comparison.comparator)
processed_value = process_value(comparison.value)
processed_value = process_value(comparison.value, comparison.comparator)
attribute = comparison.attribute
return "( " + attribute + " " + comparator + " " + processed_value + " )"

@ -1,4 +1,6 @@
from typing import Dict, Tuple
from typing import Any, Dict, Tuple
import pytest
from langchain.chains.query_constructor.ir import (
Comparator,
@ -12,11 +14,22 @@ from langchain.retrievers.self_query.milvus import MilvusTranslator
DEFAULT_TRANSLATOR = MilvusTranslator()
def test_visit_comparison() -> None:
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=4)
expected = "( foo < 4 )"
@pytest.mark.parametrize(
"triplet",
[
(Comparator.EQ, 2, "( foo == 2 )"),
(Comparator.GT, 2, "( foo > 2 )"),
(Comparator.GTE, 2, "( foo >= 2 )"),
(Comparator.LT, 2, "( foo < 2 )"),
(Comparator.LTE, 2, "( foo <= 2 )"),
(Comparator.IN, ["bar", "abc"], "( foo in ['bar', 'abc'] )"),
(Comparator.LIKE, "bar", '( foo like "bar%" )'),
],
)
def test_visit_comparison(triplet: Tuple[Comparator, Any, str]) -> None:
comparator, value, expected = triplet
comp = Comparison(comparator=comparator, attribute="foo", value=value)
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
assert expected == actual

Loading…
Cancel
Save