diff --git a/scep/api/api.go b/scep/api/api.go index c3159a71..6747b7cc 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -171,10 +171,7 @@ func decodeRequest(r *http.Request) (request, error) { }, nil case opnPKIOperation: message := query.Get("message") - if message == "" { - return request{}, errors.New("message must not be empty") - } - decodedMessage, err := base64.StdEncoding.DecodeString(message) + decodedMessage, err := decodeMessage(message, r) if err != nil { return request{}, fmt.Errorf("failed decoding message: %w", err) } @@ -199,6 +196,76 @@ func decodeRequest(r *http.Request) (request, error) { } } +func decodeMessage(message string, r *http.Request) ([]byte, error) { + if message == "" { + return nil, errors.New("message must not be empty") + } + + // decode the message, which should be base64 standard encoded. Any characters that + // were escaped in the original query, were unescaped as part of url.ParseQuery, so + // that doesn't need to be performed here. Return early if successfull. + decodedMessage, err := base64.StdEncoding.DecodeString(message) + if err == nil { + return decodedMessage, nil + } + + // only interested in corrupt input errors below this. This type of error is the + // most likely to return, but better safe than sorry. + if _, ok := err.(base64.CorruptInputError); !ok { + return nil, fmt.Errorf("failed base64 decoding message: %w", err) + } + + // the below code is a workaround for macOS when it sends a GET PKIOperation, which seems to result + // in a query with the '+' and '/' not being percent encoded; only the padding ('=') is encoded. + // When that is unescaped in the code before this, this results in invalid base64. The workaround + // is to obtain the original query, extract the message, apply transformation(s) to make it valid + // base64 and try decoding it again. If it succeeds, the happy path can be followed with the patched + // message. Otherwise we still return an error. + rawQuery, err := parseRawQuery(r.URL.RawQuery) + if err != nil { + return nil, fmt.Errorf("failed to parse raw query: %w", err) + } + + rawMessage := rawQuery.Get("message") + if rawMessage == "" { + return nil, errors.New("no message in raw query") + } + + rawMessage = strings.ReplaceAll(rawMessage, "%3D", "=") // apparently the padding arrives encoded; the others (+, /) not? + decodedMessage, err = base64.StdEncoding.DecodeString(rawMessage) + if err != nil { + return nil, fmt.Errorf("failed decoding raw message: %w", err) + } + + return decodedMessage, nil +} + +// parseRawQuery parses a URL query into url.Values. It skips +// unescaping keys and values. This code is based on url.ParseQuery. +func parseRawQuery(query string) (url.Values, error) { + m := make(url.Values) + err := parseRawQueryWithoutUnescaping(m, query) + return m, err +} + +// parseRawQueryWithoutUnescaping parses the raw query into url.Values, skipping +// unescaping of the parts. This code is based on url.parseQuery. +func parseRawQueryWithoutUnescaping(m url.Values, query string) (err error) { + for query != "" { + var key string + key, query, _ = strings.Cut(query, "&") + if strings.Contains(key, ";") { + return errors.New("invalid semicolon separator in query") + } + if key == "" { + continue + } + key, value, _ := strings.Cut(key, "=") + m[key] = append(m[key], value) + } + return err +} + // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { diff --git a/scep/api/api_test.go b/scep/api/api_test.go index 2a26f534..a1782933 100644 --- a/scep/api/api_test.go +++ b/scep/api/api_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "testing/iotest" @@ -20,6 +21,9 @@ func Test_decodeRequest(t *testing.T) { randomB64 := "wx/1mQ49TpdLRfvVjQhXNSe8RB3hjZEarqYp5XVIxpSbvOhQSs8hP2TgucID1IputbA8JC6CbsUpcVae3+8hRNqs5pTsSHP2aNxsw8AHGSX9dZVymSclkUV8irk+ztfEfs7aLA==" expectedRandom, err := base64.StdEncoding.DecodeString(randomB64) require.NoError(t, err) + weirdMacOSCase := "wx/1mQ49TpdLRfvVjQhXNSe8RB3hjZEarqYp5XVIxpSbvOhQSs8hP2TgucID1IputbA8JC6CbsUpcVae3+8hRNqs5pTsSHP2aNxsw8AHGSX9dZVymSclkUV8irk+ztfEfs7aLA%3D%3D" + expectedWeirdMacOSCase, err := base64.StdEncoding.DecodeString(strings.ReplaceAll(weirdMacOSCase, "%3D", "=")) + require.NoError(t, err) type args struct { r *http.Request } @@ -77,14 +81,6 @@ func Test_decodeRequest(t *testing.T) { want: request{}, wantErr: true, }, - { - name: "fail/get-PKIOperation-not-escaped", - args: args{ - r: httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://scep:8080/?operation=PKIOperation&message=%s", randomB64), http.NoBody), - }, - want: request{}, - wantErr: true, - }, { name: "fail/post-PKIOperation", args: args{ @@ -137,6 +133,28 @@ func Test_decodeRequest(t *testing.T) { }, wantErr: false, }, + { + name: "ok/get-PKIOperation-not-escaped", // bit of a special case, but this is supported because of the macOS case now + args: args{ + r: httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://scep:8080/?operation=PKIOperation&message=%s", randomB64), http.NoBody), + }, + want: request{ + Operation: "PKIOperation", + Message: expectedRandom, + }, + wantErr: false, + }, + { + name: "ok/get-PKIOperation-weird-macos-case", // a special case for macOS, which seems to result in the message not arriving fully percent-encoded + args: args{ + r: httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://scep:8080/?operation=PKIOperation&message=%s", weirdMacOSCase), http.NoBody), + }, + want: request{ + Operation: "PKIOperation", + Message: expectedWeirdMacOSCase, + }, + wantErr: false, + }, { name: "ok/post-PKIOperation", args: args{