mirror of
https://github.com/miguelmota/cointop
synced 2024-11-10 13:10:26 +00:00
616 lines
14 KiB
Go
616 lines
14 KiB
Go
package checker
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"github.com/antonmedv/expr/ast"
|
|
"github.com/antonmedv/expr/conf"
|
|
"github.com/antonmedv/expr/file"
|
|
"github.com/antonmedv/expr/parser"
|
|
)
|
|
|
|
// errorType is the reflect.Type describing the built-in error interface.
var errorType = reflect.TypeOf(new(error)).Elem()
|
|
|
|
func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) {
|
|
v := &visitor{
|
|
collections: make([]reflect.Type, 0),
|
|
}
|
|
if config != nil {
|
|
v.types = config.Types
|
|
v.operators = config.Operators
|
|
v.expect = config.Expect
|
|
v.strict = config.Strict
|
|
v.defaultType = config.DefaultType
|
|
}
|
|
|
|
t := v.visit(tree.Node)
|
|
|
|
if v.expect != reflect.Invalid {
|
|
switch v.expect {
|
|
case reflect.Int64, reflect.Float64:
|
|
if !isNumber(t) {
|
|
return nil, fmt.Errorf("expected %v, but got %v", v.expect, t)
|
|
}
|
|
default:
|
|
if t.Kind() != v.expect {
|
|
return nil, fmt.Errorf("expected %v, but got %v", v.expect, t)
|
|
}
|
|
}
|
|
}
|
|
|
|
if v.err != nil {
|
|
return t, v.err.Bind(tree.Source)
|
|
}
|
|
|
|
return t, nil
|
|
}
|
|
|
|
type visitor struct {
|
|
types conf.TypesTable
|
|
operators conf.OperatorsTable
|
|
expect reflect.Kind
|
|
collections []reflect.Type
|
|
strict bool
|
|
defaultType reflect.Type
|
|
err *file.Error
|
|
}
|
|
|
|
func (v *visitor) visit(node ast.Node) reflect.Type {
|
|
var t reflect.Type
|
|
switch n := node.(type) {
|
|
case *ast.NilNode:
|
|
t = v.NilNode(n)
|
|
case *ast.IdentifierNode:
|
|
t = v.IdentifierNode(n)
|
|
case *ast.IntegerNode:
|
|
t = v.IntegerNode(n)
|
|
case *ast.FloatNode:
|
|
t = v.FloatNode(n)
|
|
case *ast.BoolNode:
|
|
t = v.BoolNode(n)
|
|
case *ast.StringNode:
|
|
t = v.StringNode(n)
|
|
case *ast.ConstantNode:
|
|
t = v.ConstantNode(n)
|
|
case *ast.UnaryNode:
|
|
t = v.UnaryNode(n)
|
|
case *ast.BinaryNode:
|
|
t = v.BinaryNode(n)
|
|
case *ast.MatchesNode:
|
|
t = v.MatchesNode(n)
|
|
case *ast.PropertyNode:
|
|
t = v.PropertyNode(n)
|
|
case *ast.IndexNode:
|
|
t = v.IndexNode(n)
|
|
case *ast.SliceNode:
|
|
t = v.SliceNode(n)
|
|
case *ast.MethodNode:
|
|
t = v.MethodNode(n)
|
|
case *ast.FunctionNode:
|
|
t = v.FunctionNode(n)
|
|
case *ast.BuiltinNode:
|
|
t = v.BuiltinNode(n)
|
|
case *ast.ClosureNode:
|
|
t = v.ClosureNode(n)
|
|
case *ast.PointerNode:
|
|
t = v.PointerNode(n)
|
|
case *ast.ConditionalNode:
|
|
t = v.ConditionalNode(n)
|
|
case *ast.ArrayNode:
|
|
t = v.ArrayNode(n)
|
|
case *ast.MapNode:
|
|
t = v.MapNode(n)
|
|
case *ast.PairNode:
|
|
t = v.PairNode(n)
|
|
default:
|
|
panic(fmt.Sprintf("undefined node type (%T)", node))
|
|
}
|
|
node.SetType(t)
|
|
return t
|
|
}
|
|
|
|
func (v *visitor) error(node ast.Node, format string, args ...interface{}) reflect.Type {
|
|
if v.err == nil { // show first error
|
|
v.err = &file.Error{
|
|
Location: node.Location(),
|
|
Message: fmt.Sprintf(format, args...),
|
|
}
|
|
}
|
|
return interfaceType // interface represent undefined type
|
|
}
|
|
|
|
func (v *visitor) NilNode(*ast.NilNode) reflect.Type {
|
|
return nilType
|
|
}
|
|
|
|
func (v *visitor) IdentifierNode(node *ast.IdentifierNode) reflect.Type {
|
|
if v.types == nil {
|
|
return interfaceType
|
|
}
|
|
if t, ok := v.types[node.Value]; ok {
|
|
if t.Ambiguous {
|
|
return v.error(node, "ambiguous identifier %v", node.Value)
|
|
}
|
|
return t.Type
|
|
}
|
|
if !v.strict {
|
|
if v.defaultType != nil {
|
|
return v.defaultType
|
|
}
|
|
return interfaceType
|
|
}
|
|
if !node.NilSafe {
|
|
return v.error(node, "unknown name %v", node.Value)
|
|
}
|
|
return nilType
|
|
}
|
|
|
|
func (v *visitor) IntegerNode(*ast.IntegerNode) reflect.Type {
|
|
return integerType
|
|
}
|
|
|
|
func (v *visitor) FloatNode(*ast.FloatNode) reflect.Type {
|
|
return floatType
|
|
}
|
|
|
|
func (v *visitor) BoolNode(*ast.BoolNode) reflect.Type {
|
|
return boolType
|
|
}
|
|
|
|
func (v *visitor) StringNode(*ast.StringNode) reflect.Type {
|
|
return stringType
|
|
}
|
|
|
|
func (v *visitor) ConstantNode(node *ast.ConstantNode) reflect.Type {
|
|
return reflect.TypeOf(node.Value)
|
|
}
|
|
|
|
func (v *visitor) UnaryNode(node *ast.UnaryNode) reflect.Type {
|
|
t := v.visit(node.Node)
|
|
|
|
switch node.Operator {
|
|
|
|
case "!", "not":
|
|
if isBool(t) {
|
|
return boolType
|
|
}
|
|
|
|
case "+", "-":
|
|
if isNumber(t) {
|
|
return t
|
|
}
|
|
|
|
default:
|
|
return v.error(node, "unknown operator (%v)", node.Operator)
|
|
}
|
|
|
|
return v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t)
|
|
}
|
|
|
|
func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type {
|
|
l := v.visit(node.Left)
|
|
r := v.visit(node.Right)
|
|
|
|
// check operator overloading
|
|
if fns, ok := v.operators[node.Operator]; ok {
|
|
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.types, l, r)
|
|
if ok {
|
|
return t
|
|
}
|
|
}
|
|
|
|
switch node.Operator {
|
|
case "==", "!=":
|
|
if isNumber(l) && isNumber(r) {
|
|
return boolType
|
|
}
|
|
if isComparable(l, r) {
|
|
return boolType
|
|
}
|
|
|
|
case "or", "||", "and", "&&":
|
|
if isBool(l) && isBool(r) {
|
|
return boolType
|
|
}
|
|
|
|
case "in", "not in":
|
|
if isString(l) && isStruct(r) {
|
|
return boolType
|
|
}
|
|
if isMap(r) {
|
|
return boolType
|
|
}
|
|
if isArray(r) {
|
|
return boolType
|
|
}
|
|
|
|
case "<", ">", ">=", "<=":
|
|
if isNumber(l) && isNumber(r) {
|
|
return boolType
|
|
}
|
|
if isString(l) && isString(r) {
|
|
return boolType
|
|
}
|
|
|
|
case "/", "-", "*":
|
|
if isNumber(l) && isNumber(r) {
|
|
return combined(l, r)
|
|
}
|
|
|
|
case "**":
|
|
if isNumber(l) && isNumber(r) {
|
|
return floatType
|
|
}
|
|
|
|
case "%":
|
|
if isInteger(l) && isInteger(r) {
|
|
return combined(l, r)
|
|
}
|
|
|
|
case "+":
|
|
if isNumber(l) && isNumber(r) {
|
|
return combined(l, r)
|
|
}
|
|
if isString(l) && isString(r) {
|
|
return stringType
|
|
}
|
|
|
|
case "contains", "startsWith", "endsWith":
|
|
if isString(l) && isString(r) {
|
|
return boolType
|
|
}
|
|
|
|
case "..":
|
|
if isInteger(l) && isInteger(r) {
|
|
return reflect.SliceOf(integerType)
|
|
}
|
|
|
|
default:
|
|
return v.error(node, "unknown operator (%v)", node.Operator)
|
|
|
|
}
|
|
|
|
return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r)
|
|
}
|
|
|
|
func (v *visitor) MatchesNode(node *ast.MatchesNode) reflect.Type {
|
|
l := v.visit(node.Left)
|
|
r := v.visit(node.Right)
|
|
|
|
if isString(l) && isString(r) {
|
|
return boolType
|
|
}
|
|
|
|
return v.error(node, `invalid operation: matches (mismatched types %v and %v)`, l, r)
|
|
}
|
|
|
|
func (v *visitor) PropertyNode(node *ast.PropertyNode) reflect.Type {
|
|
t := v.visit(node.Node)
|
|
if t, ok := fieldType(t, node.Property); ok {
|
|
return t
|
|
}
|
|
if !node.NilSafe {
|
|
return v.error(node, "type %v has no field %v", t, node.Property)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (v *visitor) IndexNode(node *ast.IndexNode) reflect.Type {
|
|
t := v.visit(node.Node)
|
|
i := v.visit(node.Index)
|
|
|
|
if t, ok := indexType(t); ok {
|
|
if !isInteger(i) && !isString(i) {
|
|
return v.error(node, "invalid operation: cannot use %v as index to %v", i, t)
|
|
}
|
|
return t
|
|
}
|
|
|
|
return v.error(node, "invalid operation: type %v does not support indexing", t)
|
|
}
|
|
|
|
func (v *visitor) SliceNode(node *ast.SliceNode) reflect.Type {
|
|
t := v.visit(node.Node)
|
|
|
|
_, isIndex := indexType(t)
|
|
|
|
if isIndex || isString(t) {
|
|
if node.From != nil {
|
|
from := v.visit(node.From)
|
|
if !isInteger(from) {
|
|
return v.error(node.From, "invalid operation: non-integer slice index %v", from)
|
|
}
|
|
}
|
|
if node.To != nil {
|
|
to := v.visit(node.To)
|
|
if !isInteger(to) {
|
|
return v.error(node.To, "invalid operation: non-integer slice index %v", to)
|
|
}
|
|
}
|
|
return t
|
|
}
|
|
|
|
return v.error(node, "invalid operation: cannot slice %v", t)
|
|
}
|
|
|
|
func (v *visitor) FunctionNode(node *ast.FunctionNode) reflect.Type {
|
|
if f, ok := v.types[node.Name]; ok {
|
|
if fn, ok := isFuncType(f.Type); ok {
|
|
|
|
inputParamsCount := 1 // for functions
|
|
if f.Method {
|
|
inputParamsCount = 2 // for methods
|
|
}
|
|
|
|
if !isInterface(fn) &&
|
|
fn.IsVariadic() &&
|
|
fn.NumIn() == inputParamsCount &&
|
|
((fn.NumOut() == 1 && // Function with one return value
|
|
fn.Out(0).Kind() == reflect.Interface) ||
|
|
(fn.NumOut() == 2 && // Function with one return value and an error
|
|
fn.Out(0).Kind() == reflect.Interface &&
|
|
fn.Out(1) == errorType)) {
|
|
rest := fn.In(fn.NumIn() - 1) // function has only one param for functions and two for methods
|
|
if rest.Kind() == reflect.Slice && rest.Elem().Kind() == reflect.Interface {
|
|
node.Fast = true
|
|
}
|
|
}
|
|
|
|
return v.checkFunc(fn, f.Method, node, node.Name, node.Arguments)
|
|
}
|
|
}
|
|
if !v.strict {
|
|
if v.defaultType != nil {
|
|
return v.defaultType
|
|
}
|
|
return interfaceType
|
|
}
|
|
return v.error(node, "unknown func %v", node.Name)
|
|
}
|
|
|
|
func (v *visitor) MethodNode(node *ast.MethodNode) reflect.Type {
|
|
t := v.visit(node.Node)
|
|
if f, method, ok := methodType(t, node.Method); ok {
|
|
if fn, ok := isFuncType(f); ok {
|
|
return v.checkFunc(fn, method, node, node.Method, node.Arguments)
|
|
}
|
|
}
|
|
if !node.NilSafe {
|
|
return v.error(node, "type %v has no method %v", t, node.Method)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// checkFunc checks func arguments and returns "return type" of func or method.
|
|
func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name string, arguments []ast.Node) reflect.Type {
|
|
if isInterface(fn) {
|
|
return interfaceType
|
|
}
|
|
|
|
if fn.NumOut() == 0 {
|
|
return v.error(node, "func %v doesn't return value", name)
|
|
}
|
|
if numOut := fn.NumOut(); numOut > 2 {
|
|
return v.error(node, "func %v returns more then two values", name)
|
|
}
|
|
|
|
numIn := fn.NumIn()
|
|
|
|
// If func is method on an env, first argument should be a receiver,
|
|
// and actual arguments less then numIn by one.
|
|
if method {
|
|
numIn--
|
|
}
|
|
|
|
if fn.IsVariadic() {
|
|
if len(arguments) < numIn-1 {
|
|
return v.error(node, "not enough arguments to call %v", name)
|
|
}
|
|
} else {
|
|
if len(arguments) > numIn {
|
|
return v.error(node, "too many arguments to call %v", name)
|
|
}
|
|
if len(arguments) < numIn {
|
|
return v.error(node, "not enough arguments to call %v", name)
|
|
}
|
|
}
|
|
|
|
offset := 0
|
|
|
|
// Skip first argument in case of the receiver.
|
|
if method {
|
|
offset = 1
|
|
}
|
|
|
|
for i, arg := range arguments {
|
|
t := v.visit(arg)
|
|
|
|
var in reflect.Type
|
|
if fn.IsVariadic() && i >= numIn-1 {
|
|
// For variadic arguments fn(xs ...int), go replaces type of xs (int) with ([]int).
|
|
// As we compare arguments one by one, we need underling type.
|
|
in = fn.In(fn.NumIn() - 1)
|
|
in, _ = indexType(in)
|
|
} else {
|
|
in = fn.In(i + offset)
|
|
}
|
|
|
|
if isIntegerOrArithmeticOperation(arg) {
|
|
t = in
|
|
setTypeForIntegers(arg, t)
|
|
}
|
|
|
|
if t == nil {
|
|
continue
|
|
}
|
|
|
|
if !t.AssignableTo(in) && t.Kind() != reflect.Interface {
|
|
return v.error(arg, "cannot use %v as argument (type %v) to call %v ", t, in, name)
|
|
}
|
|
}
|
|
|
|
return fn.Out(0)
|
|
}
|
|
|
|
func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
|
|
switch node.Name {
|
|
|
|
case "len":
|
|
param := v.visit(node.Arguments[0])
|
|
if isArray(param) || isMap(param) || isString(param) {
|
|
return integerType
|
|
}
|
|
return v.error(node, "invalid argument for len (type %v)", param)
|
|
|
|
case "all", "none", "any", "one":
|
|
collection := v.visit(node.Arguments[0])
|
|
if !isArray(collection) {
|
|
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
|
|
}
|
|
|
|
v.collections = append(v.collections, collection)
|
|
closure := v.visit(node.Arguments[1])
|
|
v.collections = v.collections[:len(v.collections)-1]
|
|
|
|
if isFunc(closure) &&
|
|
closure.NumOut() == 1 &&
|
|
closure.NumIn() == 1 && isInterface(closure.In(0)) {
|
|
|
|
if !isBool(closure.Out(0)) {
|
|
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
|
|
}
|
|
return boolType
|
|
}
|
|
return v.error(node.Arguments[1], "closure should has one input and one output param")
|
|
|
|
case "filter":
|
|
collection := v.visit(node.Arguments[0])
|
|
if !isArray(collection) {
|
|
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
|
|
}
|
|
|
|
v.collections = append(v.collections, collection)
|
|
closure := v.visit(node.Arguments[1])
|
|
v.collections = v.collections[:len(v.collections)-1]
|
|
|
|
if isFunc(closure) &&
|
|
closure.NumOut() == 1 &&
|
|
closure.NumIn() == 1 && isInterface(closure.In(0)) {
|
|
|
|
if !isBool(closure.Out(0)) {
|
|
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
|
|
}
|
|
if isInterface(collection) {
|
|
return arrayType
|
|
}
|
|
return reflect.SliceOf(collection.Elem())
|
|
}
|
|
return v.error(node.Arguments[1], "closure should has one input and one output param")
|
|
|
|
case "map":
|
|
collection := v.visit(node.Arguments[0])
|
|
if !isArray(collection) {
|
|
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
|
|
}
|
|
|
|
v.collections = append(v.collections, collection)
|
|
closure := v.visit(node.Arguments[1])
|
|
v.collections = v.collections[:len(v.collections)-1]
|
|
|
|
if isFunc(closure) &&
|
|
closure.NumOut() == 1 &&
|
|
closure.NumIn() == 1 && isInterface(closure.In(0)) {
|
|
|
|
return reflect.SliceOf(closure.Out(0))
|
|
}
|
|
return v.error(node.Arguments[1], "closure should has one input and one output param")
|
|
|
|
case "count":
|
|
collection := v.visit(node.Arguments[0])
|
|
if !isArray(collection) {
|
|
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
|
|
}
|
|
|
|
v.collections = append(v.collections, collection)
|
|
closure := v.visit(node.Arguments[1])
|
|
v.collections = v.collections[:len(v.collections)-1]
|
|
|
|
if isFunc(closure) &&
|
|
closure.NumOut() == 1 &&
|
|
closure.NumIn() == 1 && isInterface(closure.In(0)) {
|
|
if !isBool(closure.Out(0)) {
|
|
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
|
|
}
|
|
|
|
return integerType
|
|
}
|
|
return v.error(node.Arguments[1], "closure should has one input and one output param")
|
|
|
|
default:
|
|
return v.error(node, "unknown builtin %v", node.Name)
|
|
}
|
|
}
|
|
|
|
func (v *visitor) ClosureNode(node *ast.ClosureNode) reflect.Type {
|
|
t := v.visit(node.Node)
|
|
return reflect.FuncOf([]reflect.Type{interfaceType}, []reflect.Type{t}, false)
|
|
}
|
|
|
|
func (v *visitor) PointerNode(node *ast.PointerNode) reflect.Type {
|
|
if len(v.collections) == 0 {
|
|
return v.error(node, "cannot use pointer accessor outside closure")
|
|
}
|
|
|
|
collection := v.collections[len(v.collections)-1]
|
|
|
|
if t, ok := indexType(collection); ok {
|
|
return t
|
|
}
|
|
return v.error(node, "cannot use %v as array", collection)
|
|
}
|
|
|
|
func (v *visitor) ConditionalNode(node *ast.ConditionalNode) reflect.Type {
|
|
c := v.visit(node.Cond)
|
|
if !isBool(c) {
|
|
return v.error(node.Cond, "non-bool expression (type %v) used as condition", c)
|
|
}
|
|
|
|
t1 := v.visit(node.Exp1)
|
|
t2 := v.visit(node.Exp2)
|
|
|
|
if t1 == nil && t2 != nil {
|
|
return t2
|
|
}
|
|
if t1 != nil && t2 == nil {
|
|
return t1
|
|
}
|
|
if t1 == nil && t2 == nil {
|
|
return nilType
|
|
}
|
|
if t1.AssignableTo(t2) {
|
|
return t1
|
|
}
|
|
return interfaceType
|
|
}
|
|
|
|
func (v *visitor) ArrayNode(node *ast.ArrayNode) reflect.Type {
|
|
for _, node := range node.Nodes {
|
|
_ = v.visit(node)
|
|
}
|
|
return arrayType
|
|
}
|
|
|
|
func (v *visitor) MapNode(node *ast.MapNode) reflect.Type {
|
|
for _, pair := range node.Pairs {
|
|
v.visit(pair)
|
|
}
|
|
return mapType
|
|
}
|
|
|
|
func (v *visitor) PairNode(node *ast.PairNode) reflect.Type {
|
|
v.visit(node.Key)
|
|
v.visit(node.Value)
|
|
return nilType
|
|
}
|