diff --git a/acme/api/order.go b/acme/api/order.go index 679fe32f..2927a620 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -13,6 +13,7 @@ import ( "github.com/go-chi/chi" "go.step.sm/crypto/randutil" + "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" @@ -33,12 +34,20 @@ func (n *NewOrderRequest) Validate() error { return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty") } for _, id := range n.Identifiers { - if !(id.Type == acme.DNS || id.Type == acme.IP) { + switch id.Type { + case acme.IP: + if net.ParseIP(id.Value) == nil { + return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value) + } + case acme.DNS: + value, _ := trimIfWildcard(id.Value) + if _, err := x509util.SanitizeName(value); err != nil { + return acme.NewError(acme.ErrorMalformedType, "invalid DNS name: %s", id.Value) + } + default: return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type) } - if id.Type == acme.IP && net.ParseIP(id.Value) == nil { - return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value) - } + // TODO(hs): add some validations for DNS domains? // TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1 } @@ -218,13 +227,19 @@ func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error return policy.NewX509PolicyEngine(eak.Policy) } +func trimIfWildcard(value string) (string, bool) { + if strings.HasPrefix(value, "*.") { + return strings.TrimPrefix(value, "*."), true + } + return value, false +} + func newAuthorization(ctx context.Context, az *acme.Authorization) error { - if strings.HasPrefix(az.Identifier.Value, "*.") { - az.Wildcard = true - az.Identifier = acme.Identifier{ - Value: strings.TrimPrefix(az.Identifier.Value, "*."), - Type: az.Identifier.Type, - } + value, isWildcard := trimIfWildcard(az.Identifier.Value) + az.Wildcard = isWildcard + az.Identifier = acme.Identifier{ + Value: value, + Type: az.Identifier.Type, } chTypes := challengeTypes(az) diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 7f67c72e..724357d8 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -49,6 +49,36 @@ func TestNewOrderRequest_Validate(t *testing.T) { err: acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: foo"), } }, + "fail/bad-identifier/bad-dns": func(t *testing.T) test { + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "xn--bücher.example.com"}, + }, + }, + err: acme.NewError(acme.ErrorMalformedType, "invalid DNS name: xn--bücher.example.com"), + } + }, + "fail/bad-identifier/dns-port": func(t *testing.T) test { + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com:8080"}, + }, + }, + err: acme.NewError(acme.ErrorMalformedType, "invalid DNS name: example.com:8080"), + } + }, + "fail/bad-identifier/dns-wildcard-port": func(t *testing.T) test { + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "*.example.com:8080"}, + }, + }, + err: acme.NewError(acme.ErrorMalformedType, "invalid DNS name: *.example.com:8080"), + } + }, "fail/bad-ip": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) @@ -72,7 +102,7 @@ func TestNewOrderRequest_Validate(t *testing.T) { nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + {Type: "dns", Value: "*.bar.com"}, }, NotAfter: naf, NotBefore: nbf, @@ -2097,3 +2127,32 @@ func TestHandler_challengeTypes(t *testing.T) { }) } } + +func TestTrimIfWildcard(t *testing.T) { + tests := []struct { + name string + arg string + wantValue string + wantBool bool + }{ + { + name: "no trim", + arg: "smallstep.com", + wantValue: "smallstep.com", + wantBool: false, + }, + { + name: "trim", + arg: "*.smallstep.com", + wantValue: "smallstep.com", + wantBool: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v, ok := trimIfWildcard(tt.arg) + assert.Equals(t, v, tt.wantValue) + assert.Equals(t, ok, tt.wantBool) + }) + } +}