Additional Weaviate Filter Comparators (#10522)

### Description
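When using Weaviate self-query retrievers, several common filter comparators
generated from user queries were not implemented, which resulted in errors. This
PR implements some of them (NE, GT, GTE, LT, LTE) and picks the Weaviate value
key from the Python type of the compared value. All linting and formatting
commands have been run, and the tests pass.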
### Issue
#10474
### Dependencies
None beyond Python's standard-library `datetime` module (`date` and `datetime` are imported from it).

---------

Co-authored-by: Patrick Randell <prandell@deloitte.com.au>
pull/6605/head
Patrick Randell 1 year ago committed by GitHub
parent 79011f835f
commit 1d678f805f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,3 +1,4 @@
from datetime import date, datetime
from typing import Dict, Tuple, Union
from langchain.chains.query_constructor.ir import (
@ -16,12 +17,28 @@ class WeaviateTranslator(Visitor):
allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
# Comparators that `_format_func` maps to a Weaviate operator name.
allowed_comparators = [
    Comparator.EQ,
    Comparator.NE,
    Comparator.GTE,
    Comparator.LTE,
    Comparator.LT,
    Comparator.GT,
]
def _format_func(self, func: Union[Operator, Comparator]) -> str:
    """Return the Weaviate filter operator name for ``func``.

    ``func`` is first passed to ``self._validate_func``. It is then looked up
    in a fixed mapping, so a member without an entry raises ``KeyError``.

    Args:
        func: The logical operator or comparator to translate.

    Returns:
        The operator name used in a Weaviate ``where`` filter.
    """
    self._validate_func(func)
    # https://weaviate.io/developers/weaviate/api/graphql/filters
    map_dict = {
        Operator.AND: "And",
        Operator.OR: "Or",
        Comparator.EQ: "Equal",
        Comparator.NE: "NotEqual",
        Comparator.GTE: "GreaterThanEqual",
        Comparator.LTE: "LessThanEqual",
        Comparator.LT: "LessThan",
        Comparator.GT: "GreaterThan",
    }
    return map_dict[func]
def visit_operation(self, operation: Operation) -> Dict:
@ -29,11 +46,25 @@ class WeaviateTranslator(Visitor):
return {"operator": self._format_func(operation.operator), "operands": args}
def visit_comparison(self, comparison: Comparison) -> Dict:
    """Translate ``comparison`` into a Weaviate ``where`` filter dict.

    The key that carries the value is chosen from the Python type of
    ``comparison.value``: ``valueBoolean`` for ``bool``, ``valueNumber`` for
    ``float``, ``valueInt`` for ``int``, ``valueDate`` for ``date`` and
    ``datetime`` (rendered as a timestamp string), and ``valueText`` for
    anything else.

    ``comparison`` itself is left unmodified, so translating the same object
    more than once always yields the same filter.

    Args:
        comparison: The comparison to translate.

    Returns:
        A dict with ``path``, ``operator`` and exactly one typed value key.
    """
    value = comparison.value
    value_type = "valueText"
    # bool is tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        value_type = "valueBoolean"
    elif isinstance(value, float):
        value_type = "valueNumber"
    elif isinstance(value, int):
        value_type = "valueInt"
    elif isinstance(value, (datetime, date)):
        value_type = "valueDate"
        # ISO 8601 timestamp, formatted as RFC3339
        # NOTE(review): the "Z" suffix is literal; a timezone-aware datetime
        # is not converted to UTC before formatting - confirm callers only
        # pass naive/UTC values.
        value = value.strftime("%Y-%m-%dT%H:%M:%SZ")
    return {
        "path": [comparison.attribute],
        "operator": self._format_func(comparison.comparator),
        value_type: value,
    }
def visit_structured_query(
self, structured_query: StructuredQuery

@ -1,3 +1,4 @@
from datetime import date, datetime
from typing import Dict, Tuple
from langchain.chains.query_constructor.ir import (
@ -19,18 +20,75 @@ def test_visit_comparison() -> None:
assert expected == actual
def test_visit_comparison_integer() -> None:
    """An ``int`` value is emitted under the ``valueInt`` key."""
    comparison = Comparison(comparator=Comparator.GTE, attribute="foo", value=1)
    translated = DEFAULT_TRANSLATOR.visit_comparison(comparison)
    assert {
        "operator": "GreaterThanEqual",
        "path": ["foo"],
        "valueInt": 1,
    } == translated
def test_visit_comparison_number() -> None:
    """A ``float`` value is emitted under the ``valueNumber`` key."""
    comparison = Comparison(comparator=Comparator.GT, attribute="foo", value=1.4)
    translated = DEFAULT_TRANSLATOR.visit_comparison(comparison)
    assert {
        "operator": "GreaterThan",
        "path": ["foo"],
        "valueNumber": 1.4,
    } == translated
def test_visit_comparison_boolean() -> None:
    """A ``bool`` value is emitted under the ``valueBoolean`` key."""
    comparison = Comparison(comparator=Comparator.NE, attribute="foo", value=False)
    translated = DEFAULT_TRANSLATOR.visit_comparison(comparison)
    assert {
        "operator": "NotEqual",
        "path": ["foo"],
        "valueBoolean": False,
    } == translated
def test_visit_comparison_datetime() -> None:
    """A ``datetime`` value is emitted as a timestamp string under ``valueDate``."""
    moment = datetime(2023, 9, 13, 4, 20, 0)
    comparison = Comparison(comparator=Comparator.LTE, attribute="foo", value=moment)
    translated = DEFAULT_TRANSLATOR.visit_comparison(comparison)
    assert {
        "operator": "LessThanEqual",
        "path": ["foo"],
        "valueDate": "2023-09-13T04:20:00Z",
    } == translated
def test_visit_comparison_date() -> None:
    """A ``date`` value is emitted as a midnight timestamp under ``valueDate``."""
    day = date(2023, 9, 13)
    comparison = Comparison(comparator=Comparator.LT, attribute="foo", value=day)
    translated = DEFAULT_TRANSLATOR.visit_comparison(comparison)
    assert {
        "operator": "LessThan",
        "path": ["foo"],
        "valueDate": "2023-09-13T00:00:00Z",
    } == translated
def test_visit_operation() -> None:
op = Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
Comparison(comparator=Comparator.EQ, attribute="foo", value="hello"),
Comparison(
comparator=Comparator.GTE, attribute="bar", value=date(2023, 9, 13)
),
Comparison(comparator=Comparator.LTE, attribute="abc", value=1.4),
],
)
expected = {
"operands": [
{"operator": "Equal", "path": ["foo"], "valueText": 2},
{"operator": "Equal", "path": ["bar"], "valueText": "baz"},
{"operator": "Equal", "path": ["foo"], "valueText": "hello"},
{
"operator": "GreaterThanEqual",
"path": ["bar"],
"valueDate": "2023-09-13T00:00:00Z",
},
{"operator": "LessThanEqual", "path": ["abc"], "valueNumber": 1.4},
],
"operator": "And",
}
@ -78,7 +136,7 @@ def test_visit_structured_query() -> None:
"where_filter": {
"operator": "And",
"operands": [
{"path": ["foo"], "operator": "Equal", "valueText": 2},
{"path": ["foo"], "operator": "Equal", "valueInt": 2},
{"path": ["bar"], "operator": "Equal", "valueText": "baz"},
],
}

Loading…
Cancel
Save