diff --git a/authority/config.go b/authority/config.go index a6a78523..3bc8e810 100644 --- a/authority/config.go +++ b/authority/config.go @@ -26,9 +26,9 @@ var ( } defaultDisableRenewal = false globalProvisionerClaims = ProvisionerClaims{ - MinTLSDur: &duration{5 * time.Minute}, - MaxTLSDur: &duration{24 * time.Hour}, - DefaultTLSDur: &duration{24 * time.Hour}, + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, DisableRenewal: &defaultDisableRenewal, } ) diff --git a/authority/provisioner.go b/authority/provisioner.go index 53372f11..6dd1b1ac 100644 --- a/authority/provisioner.go +++ b/authority/provisioner.go @@ -12,9 +12,9 @@ import ( // ProvisionerClaims so that individual provisioners can override global claims. type ProvisionerClaims struct { globalClaims *ProvisionerClaims - MinTLSDur *duration `json:"minTLSCertDuration,omitempty"` - MaxTLSDur *duration `json:"maxTLSCertDuration,omitempty"` - DefaultTLSDur *duration `json:"defaultTLSCertDuration,omitempty"` + MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` + MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` + DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` DisableRenewal *bool `json:"disableRenewal,omitempty"` } diff --git a/authority/types.go b/authority/types.go index d9120f59..f0a781d5 100644 --- a/authority/types.go +++ b/authority/types.go @@ -7,7 +7,8 @@ import ( "github.com/pkg/errors" ) -type duration struct { +// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal. +type Duration struct { time.Duration } @@ -16,8 +17,8 @@ type duration struct { // A duration string is a possibly signed sequence of decimal numbers, each with // optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) MarshalJSON() ([]byte, error) { - return json.Marshal(d.String()) +func (d *Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.Duration.String()) } // UnmarshalJSON parses a duration string and sets it to the duration. @@ -25,14 +26,21 @@ func (d *duration) MarshalJSON() ([]byte, error) { // A duration string is a possibly signed sequence of decimal numbers, each with // optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) UnmarshalJSON(data []byte) (err error) { - var s string +func (d *Duration) UnmarshalJSON(data []byte) (err error) { + var ( + s string + _d time.Duration + ) + if d == nil { + return errors.New("duration cannot be nil") + } if err = json.Unmarshal(data, &s); err != nil { return errors.Wrapf(err, "error unmarshalling %s", data) } - if d.Duration, err = time.ParseDuration(s); err != nil { + if _d, err = time.ParseDuration(s); err != nil { return errors.Wrapf(err, "error parsing %s as duration", s) } + d.Duration = _d return } diff --git a/authority/types_test.go b/authority/types_test.go index 36877dcc..c49c368f 100644 --- a/authority/types_test.go +++ b/authority/types_test.go @@ -3,6 +3,7 @@ package authority import ( "reflect" "testing" + "time" ) func Test_multiString_First(t *testing.T) { @@ -71,7 +72,6 @@ func Test_multiString_MarshalJSON(t *testing.T) { } func Test_multiString_UnmarshalJSON(t *testing.T) { - type args struct { data []byte } @@ -101,3 +101,57 @@ func Test_multiString_UnmarshalJSON(t *testing.T) { }) } } + +func TestDuration_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + d *Duration + args args + want *Duration + wantErr bool + }{ + {"empty", new(Duration), args{[]byte{}}, new(Duration), true}, + {"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true}, + {"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true}, + {"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true}, + {"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false}, + {"nil", nil, args{nil}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(tt.d, tt.want) { + t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want) + } + }) + } +} + +func Test_duration_MarshalJSON(t *testing.T) { + tests := []struct { + name string + d *Duration + want []byte + wantErr bool + }{ + {"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.d.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +}