You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/chains/query_constructor/parser.py

133 lines
3.9 KiB
Python

from typing import Any, Optional, Sequence, Union
try:
    import lark
    from packaging import version

    # This module requires lark >= 1.1.5. An older install is reported as an
    # error instead of silently falling back to the placeholders below.
    if version.parse(lark.__version__) < version.parse("1.1.5"):
        raise ValueError(
            f"Lark should be at least version 1.1.5, got {lark.__version__}"
        )
    from lark import Lark, Transformer, v_args
except ImportError:
    # lark (or packaging) is not installed. Bind placeholder names so the rest
    # of this module can still be imported.

    def v_args(*args: Any, **kwargs: Any) -> Any:  # type: ignore
        """Stand-in for ``lark.v_args``.

        The decorator it returns replaces whatever it decorates with ``None``.
        """

        def _discard(_target: Any) -> None:
            return None

        return _discard

    Transformer = Lark = object  # type: ignore
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
FilterDirective,
Operation,
Operator,
)
# Lark grammar for the structured-query filter language.
#
# ``get_parser`` builds an LALR parser from this text with ``program`` as the
# start rule. A complete input is therefore a single ``name(arg, ...)`` call,
# whose arguments are nested calls or literal values.
#
# - Rules prefixed with ``?`` are inlined by lark when they have one child.
# - ``-> int`` / ``-> float`` / ``-> false`` / ``-> true`` rename an alternative
#   so that it is handled by the ``QueryTransformer`` method of the same name.
# - Literal values: signed ints, signed floats, ``[...]`` lists, and strings.
#   Strings are either single-quoted (regex ``'[^']*'``, which allows no
#   escapes) or double-quoted (lark's ``ESCAPED_STRING``).
# - Booleans accept ``false``/``False``/``FALSE`` and ``true``/``True``/``TRUE``.
# - Whitespace is ignored everywhere.
GRAMMAR = """
?program: func_call
?expr: func_call
| value
func_call: CNAME "(" [args] ")"
?value: SIGNED_INT -> int
| SIGNED_FLOAT -> float
| list
| string
| ("false" | "False" | "FALSE") -> false
| ("true" | "True" | "TRUE") -> true
args: expr ("," expr)*
string: /'[^']*'/ | ESCAPED_STRING
list: "[" [args] "]"
%import common.CNAME
%import common.ESCAPED_STRING
%import common.SIGNED_FLOAT
%import common.SIGNED_INT
%import common.WS
%ignore WS
"""
@v_args(inline=True)
class QueryTransformer(Transformer):
    """Turn a parsed query into ``FilterDirective`` objects.

    Each public method is the lark callback for the grammar rule, or ``->``
    alias, that shares its name. The class is decorated with
    ``v_args(inline=True)``, so every callback receives the rule's children as
    positional arguments rather than as one list.

    Args:
        allowed_comparators: If not ``None``, the only comparators a query may
            use. Any other comparator raises ``ValueError``.
        allowed_operators: If not ``None``, the only operators a query may
            use. Any other operator raises ``ValueError``.
    """

    def __init__(
        self,
        *args: Any,
        allowed_comparators: Optional[Sequence[Comparator]] = None,
        allowed_operators: Optional[Sequence[Operator]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.allowed_comparators = allowed_comparators
        self.allowed_operators = allowed_operators

    def program(self, *items: Any) -> tuple:
        """Return the children of a ``program`` node as a tuple.

        In the current grammar ``?program`` always has exactly one child, so
        lark inlines it and this callback is not normally reached.
        """
        return items

    def func_call(self, func_name: Any, *args: Any) -> FilterDirective:
        """Build a ``Comparison`` or an ``Operation`` from ``name(args...)``.

        ``args`` carries the result of the optional ``[args]`` grammar slot.
        With lark's default ``maybe_placeholders`` this is ``None`` when the
        call was written with no arguments.

        Raises:
            ValueError: If the function name is unknown or not allowed, or if
                a comparator is given fewer than two arguments.
        """
        func = self._match_func_name(str(func_name))
        call_args = args[0] if args else None
        if isinstance(func, Comparator):
            # A comparison needs (attribute, value). Without this check,
            # ``eq()`` or ``eq("x")`` crashes with an opaque TypeError or
            # IndexError.
            n_args = 0 if call_args is None else len(call_args)
            if n_args < 2:
                raise ValueError(
                    f"Comparator {func_name} requires two arguments "
                    f"(attribute, value), got {n_args}."
                )
            return Comparison(
                comparator=func, attribute=call_args[0], value=call_args[1]
            )
        return Operation(operator=func, arguments=call_args)

    def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
        """Resolve ``func_name`` to a ``Comparator`` or ``Operator`` member.

        Comparators are checked first. The membership tests compare a plain
        string against enum members, which presumably relies on ``Comparator``
        and ``Operator`` being ``str``-valued enums (defined in ``ir``; not
        visible here).

        Raises:
            ValueError: If the name is in neither enum, or if it is not in the
                configured allow-list.
        """
        if func_name in set(Comparator):
            if self.allowed_comparators is not None:
                if func_name not in self.allowed_comparators:
                    raise ValueError(
                        f"Received disallowed comparator {func_name}. Allowed "
                        f"comparators are {self.allowed_comparators}"
                    )
            return Comparator(func_name)
        elif func_name in set(Operator):
            if self.allowed_operators is not None:
                if func_name not in self.allowed_operators:
                    raise ValueError(
                        f"Received disallowed operator {func_name}. Allowed operators"
                        f" are {self.allowed_operators}"
                    )
            return Operator(func_name)
        else:
            raise ValueError(
                f"Received unrecognized function {func_name}. Valid functions are "
                f"{list(Operator) + list(Comparator)}"
            )

    def args(self, *items: Any) -> tuple:
        """Collect comma-separated call/list arguments into a tuple."""
        return items

    def false(self) -> bool:
        return False

    def true(self) -> bool:
        return True

    def list(self, item: Any) -> list:
        """Convert a ``[...]`` literal to a list. ``[]`` arrives as ``None``."""
        if item is None:
            return []
        return list(item)

    def int(self, item: Any) -> int:
        return int(item)

    def float(self, item: Any) -> float:
        return float(item)

    def string(self, item: Any) -> str:
        # Strip the surrounding quote characters. ``str.strip`` removes every
        # leading and trailing ' or " (not just one delimiter), so quotes next
        # to the delimiters are dropped too. Backslash escapes inside
        # ESCAPED_STRING are left untouched.
        return str(item).strip("\"'")
def get_parser(
    allowed_comparators: Optional[Sequence[Comparator]] = None,
    allowed_operators: Optional[Sequence[Operator]] = None,
) -> Lark:
    """Return an LALR ``Lark`` parser for structured-query filter strings.

    The parser is built from ``GRAMMAR`` with ``program`` as the start rule and
    a ``QueryTransformer`` attached. As a result, ``parser.parse(text)`` yields
    the transformed filter objects instead of a raw parse tree.

    Args:
        allowed_comparators: Optional allow-list of comparators, forwarded to
            ``QueryTransformer``.
        allowed_operators: Optional allow-list of operators, forwarded to
            ``QueryTransformer``.

    Returns:
        A configured ``Lark`` parser instance.

    Raises:
        ImportError: If ``lark`` (or ``packaging``) could not be imported.
    """
    # When lark is unavailable, the fallback ``v_args`` decorator replaces the
    # class with ``None``. Fail with an actionable message instead of
    # "'NoneType' object is not callable".
    if QueryTransformer is None:
        raise ImportError(
            "Cannot import lark, please install it with 'pip install lark'."
        )
    transformer = QueryTransformer(
        allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
    )
    return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program")