diff --git a/libs/langchain/langchain/chains/query_constructor/parser.py b/libs/langchain/langchain/chains/query_constructor/parser.py index af65c99ff2..d0ac44cab7 100644 --- a/libs/langchain/langchain/chains/query_constructor/parser.py +++ b/libs/langchain/langchain/chains/query_constructor/parser.py @@ -35,6 +35,7 @@ GRAMMAR = r""" ?value: SIGNED_INT -> int | SIGNED_FLOAT -> float | DATE -> date + | DATETIME -> datetime | list | string | ("false" | "False" | "FALSE") -> false @@ -42,6 +43,7 @@ GRAMMAR = r""" args: expr ("," expr)* DATE.2: /["']?(\d{4}-[01]\d-[0-3]\d)["']?/ + DATETIME.2: /["']?\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]\d:[0-5]\d[Zz]?["']?/ string: /'[^']*'/ | ESCAPED_STRING list: "[" [args] "]" @@ -61,6 +63,13 @@ class ISO8601Date(TypedDict): type: Literal["date"] +class ISO8601DateTime(TypedDict): + """A datetime in ISO 8601 format (YYYY-MM-DDTHH:MM:SS).""" + + datetime: str + type: Literal["datetime"] + + @v_args(inline=True) class QueryTransformer(Transformer): """Transform a query string into an intermediate representation.""" @@ -149,6 +158,20 @@ class QueryTransformer(Transformer): ) return {"date": item, "type": "date"} + def datetime(self, item: Any) -> ISO8601DateTime: + item = str(item).strip("\"'") + try: + # Parse full ISO 8601 datetime format + datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S%z") + except ValueError: + try: + datetime.datetime.strptime(item, "%Y-%m-%dT%H:%M:%S") + except ValueError: + raise ValueError( + "Datetime values are expected to be in ISO 8601 format." + ) + return {"datetime": item, "type": "datetime"} + def string(self, item: Any) -> str: # Remove escaped quotes return str(item).strip("\"'") diff --git a/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py b/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py index a25eb3c0a5..9bcef121a5 100644 --- a/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py +++ b/libs/langchain/tests/unit_tests/chains/query_constructor/test_parser.py @@ -130,3 +130,40 @@ def test_parse_date_value(x: str) -> None: parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})')) actual = parsed.value["date"] assert actual == x.strip("'\"") + + +@pytest.mark.parametrize( + "x, expected", + [ + ( + '"2021-01-01T00:00:00"', + {"datetime": "2021-01-01T00:00:00", "type": "datetime"}, + ), + ( + '"2021-12-31T23:59:59Z"', + {"datetime": "2021-12-31T23:59:59Z", "type": "datetime"}, + ), + ( + '"invalid-datetime"', + None, # Expecting failure or handling of invalid input + ), + ], +) +def test_parse_datetime_value(x: str, expected: dict) -> None: + """Test parsing of datetime values with ISO 8601 format.""" + try: + parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("publishedAt", {x})')) + actual = parsed.value + assert actual == expected, f"Expected {expected}, got {actual}" + except ValueError as e: + # Handling the case where parsing should fail + if expected is None: + assert True # Correctly raised an error for invalid input + else: + pytest.fail(f"Unexpected error {e} for input {x}") + except Exception as e: + # If any other unexpected exception type is raised + if expected is None: + assert True # Correctly identified that input was invalid + else: + pytest.fail(f"Unhandled exception {e} for input {x}")