From 99845d38bbda1310e5ef8c7e74cf372db9503938 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 4 Jan 2022 12:06:44 -0800 Subject: [PATCH] Add some extra unit tests for nebula. --- authority/provisioner/nebula_test.go | 281 ++++++++++++++++++++++++++- 1 file changed, 276 insertions(+), 5 deletions(-) diff --git a/authority/provisioner/nebula_test.go b/authority/provisioner/nebula_test.go index 4ed72171..84bf2926 100644 --- a/authority/provisioner/nebula_test.go +++ b/authority/provisioner/nebula_test.go @@ -5,6 +5,7 @@ import ( "crypto" "crypto/ed25519" "crypto/rand" + "crypto/x509" "net" "reflect" "strings" @@ -16,6 +17,7 @@ import ( "go.step.sm/crypto/randutil" "go.step.sm/crypto/x25519" "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ssh" ) func mustNebulaIPNet(t *testing.T, s string) *net.IPNet { @@ -98,11 +100,14 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25 if err != nil { t.Fatal(err) } - + bTrue := true p := &Nebula{ Type: TypeNebula.String(), Name: "nebulous", Roots: ncPem, + Claims: &Claims{ + EnableSSHCA: &bTrue, + }, } if err := p.Init(Config{ Claims: globalProvisionerClaims, @@ -157,6 +162,54 @@ func mustNebulaToken(t *testing.T, sub, iss, aud string, iat time.Time, sans []s return tok } +func mustNebulaSSHToken(t *testing.T, sub, iss, aud string, iat time.Time, opts *SignSSHOptions, nc *cert.NebulaCertificate, key crypto.Signer) string { + t.Helper() + ncPEM, err := nc.MarshalToPEM() + if err != nil { + t.Fatal(err) + } + + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader(NebulaCertHeader, ncPEM) + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.XEdDSA, Key: key}, so) + if err != nil { + t.Fatal(err) + } + + id, err := randutil.ASCII(64) + if err != nil { + t.Fatal(err) + } + + claims := struct { + jose.Claims + Step *stepPayload `json:"step,omitempty"` + }{ + Claims: jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + }, + } + if opts != nil { + claims.Step = &stepPayload{ + SSH: opts, + } + } + + tok, err := jose.Signed(sig).Claims(claims).CompactSerialize() + if err != nil { + t.Fatal(err) + } + return tok +} + func TestNebula_Init(t *testing.T) { nc, _ := mustNebulaCA(t) ncPem, err := nc.MarshalToPEM() @@ -375,6 +428,7 @@ func TestNebula_AuthorizeSign(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv) pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions.caPool = p.caPool @@ -395,6 +449,7 @@ func TestNebula_AuthorizeSign(t *testing.T) { wantErr bool }{ {"ok", p, args{ctx, ok}, false}, + {"ok no sans", p, args{ctx, okNoSANs}, false}, {"fail token", p, args{ctx, "token"}, true}, {"fail template", pBadOptions, args{ctx, ok}, true}, } @@ -410,6 +465,44 @@ func TestNebula_AuthorizeSign(t *testing.T) { } func TestNebula_AuthorizeSSHSign(t *testing.T) { + ctx := context.TODO() + // Ok provisioner + p, ca, signer := mustNebulaProvisioner(t) + crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + CertType: "host", + KeyID: "test.lan", + Principals: []string{"test.lan", "10.1.0.1"}, + }, crt, priv) + okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv) + okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)), + ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)), + }, crt, priv) + failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + CertType: "user", + }, crt, priv) + failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + CertType: "host", + KeyID: "test.lan", + Principals: []string{"test.lan", "10.1.0.1", "foo.bar"}, + }, crt, priv) + + // Provisioner with SSH disabled + var bFalse bool + pDisabled, _, _ := mustNebulaProvisioner(t) + pDisabled.caPool = p.caPool + pDisabled.Claims.EnableSSHCA = &bFalse + + // Provisioner with bad templates + pBadOptions, _, _ := mustNebulaProvisioner(t) + pBadOptions.caPool = p.caPool + pBadOptions.Options = &Options{ + SSH: &SSHOptions{ + TemplateData: []byte(`{""}`), + }, + } + type args struct { ctx context.Context token string @@ -418,20 +511,198 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) { name string p *Nebula args args - want []SignOption wantErr bool }{ - // TODO: Add test cases. + {"ok", p, args{ctx, ok}, false}, + {"ok no options", p, args{ctx, okNoOptions}, false}, + {"ok with validity", p, args{ctx, okWithValidity}, false}, + {"fail token", p, args{ctx, "token"}, true}, + {"fail user", p, args{ctx, failUserCert}, true}, + {"fail principals", p, args{ctx, failPrincipals}, true}, + {"fail disabled", pDisabled, args{ctx, ok}, true}, + {"fail template", pBadOptions, args{ctx, ok}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.p.AuthorizeSSHSign(tt.args.ctx, tt.args.token) + _, err := tt.p.AuthorizeSSHSign(tt.args.ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Nebula.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } + }) + } +} + +func TestNebula_AuthorizeRenew(t *testing.T) { + ctx := context.TODO() + // Ok provisioner + p, _, _ := mustNebulaProvisioner(t) + + // Provisioner with renewal disabled + bTrue := true + pDisabled, _, _ := mustNebulaProvisioner(t) + pDisabled.Claims.DisableRenewal = &bTrue + + type args struct { + ctx context.Context + crt *x509.Certificate + } + tests := []struct { + name string + p *Nebula + args args + wantErr bool + }{ + {"ok", p, args{ctx, &x509.Certificate{}}, false}, + {"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.p.AuthorizeRenew(tt.args.ctx, tt.args.crt); (err != nil) != tt.wantErr { + t.Errorf("Nebula.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestNebula_AuthorizeRevoke(t *testing.T) { + ctx := context.TODO() + // Ok provisioner + p, ca, signer := mustNebulaProvisioner(t) + crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) + ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + + // Fail different CA + nc, signer := mustNebulaCA(t) + crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) + failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + + type args struct { + ctx context.Context + token string + } + tests := []struct { + name string + p *Nebula + args args + wantErr bool + }{ + {"ok", p, args{ctx, ok}, false}, + {"fail token", p, args{ctx, failToken}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.p.AuthorizeRevoke(tt.args.ctx, tt.args.token); (err != nil) != tt.wantErr { + t.Errorf("Nebula.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestNebula_AuthorizeSSHRevoke(t *testing.T) { + ctx := context.TODO() + // Ok provisioner + p, ca, signer := mustNebulaProvisioner(t) + crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + + // Fail different CA + nc, signer := mustNebulaCA(t) + crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) + failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + + // Provisioner with SSH disabled + var bFalse bool + pDisabled, _, _ := mustNebulaProvisioner(t) + pDisabled.caPool = p.caPool + pDisabled.Claims.EnableSSHCA = &bFalse + + type args struct { + ctx context.Context + token string + } + tests := []struct { + name string + p *Nebula + args args + wantErr bool + }{ + {"ok", p, args{ctx, ok}, false}, + {"fail token", p, args{ctx, failToken}, true}, + {"fail disabled", pDisabled, args{ctx, ok}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.p.AuthorizeSSHRevoke(tt.args.ctx, tt.args.token); (err != nil) != tt.wantErr { + t.Errorf("Nebula.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestNebula_AuthorizeSSHRenew(t *testing.T) { + p, ca, signer := mustNebulaProvisioner(t) + crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv) + + type args struct { + ctx context.Context + token string + } + tests := []struct { + name string + p *Nebula + args args + want *ssh.Certificate + wantErr bool + }{ + {"fail", p, args{context.TODO(), t1}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.p.AuthorizeSSHRenew(tt.args.ctx, tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("Nebula.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Nebula.AuthorizeSSHRenew() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNebula_AuthorizeSSHRekey(t *testing.T) { + p, ca, signer := mustNebulaProvisioner(t) + crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv) + + type args struct { + ctx context.Context + token string + } + tests := []struct { + name string + p *Nebula + args args + want *ssh.Certificate + want1 []SignOption + wantErr bool + }{ + {"fail", p, args{context.TODO(), t1}, nil, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1, err := tt.p.AuthorizeSSHRekey(tt.args.ctx, tt.args.token) + if (err != nil) != tt.wantErr { + t.Errorf("Nebula.AuthorizeSSHRekey() error = %v, wantErr %v", err, tt.wantErr) + return + } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Nebula.AuthorizeSSHSign() = %v, want %v", got, tt.want) + t.Errorf("Nebula.AuthorizeSSHRekey() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(got1, tt.want1) { + t.Errorf("Nebula.AuthorizeSSHRekey() got1 = %v, want %v", got1, tt.want1) } }) }