diff --git a/authority/config.go b/authority/config.go index 16a10a77..029d5ebe 100644 --- a/authority/config.go +++ b/authority/config.go @@ -29,9 +29,9 @@ var ( maxTLSDur = 24 * time.Hour defaultTLSDur = 24 * time.Hour globalProvisionerClaims = ProvisionerClaims{ - MinTLSDur: (*duration)(&minTLSDur), - MaxTLSDur: (*duration)(&maxTLSDur), - DefaultTLSDur: (*duration)(&defaultTLSDur), + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, DisableRenewal: &defaultDisableRenewal, } ) diff --git a/authority/provisioner.go b/authority/provisioner.go index e3dc7d1a..6dd1b1ac 100644 --- a/authority/provisioner.go +++ b/authority/provisioner.go @@ -12,9 +12,9 @@ import ( // ProvisionerClaims so that individual provisioners can override global claims. type ProvisionerClaims struct { globalClaims *ProvisionerClaims - MinTLSDur *duration `json:"minTLSCertDuration,omitempty"` - MaxTLSDur *duration `json:"maxTLSCertDuration,omitempty"` - DefaultTLSDur *duration `json:"defaultTLSCertDuration,omitempty"` + MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` + MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` + DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` DisableRenewal *bool `json:"disableRenewal,omitempty"` } @@ -32,30 +32,30 @@ func (pc *ProvisionerClaims) Init(global *ProvisionerClaims) (*ProvisionerClaims // provisioner. If the default is not set within the provisioner, then the global // default from the authority configuration will be used. func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration { - if pc.DefaultTLSDur == nil || *pc.DefaultTLSDur == 0 { + if pc.DefaultTLSDur == nil || pc.DefaultTLSDur.Duration == 0 { return pc.globalClaims.DefaultTLSCertDuration() } - return time.Duration(*pc.DefaultTLSDur) + return pc.DefaultTLSDur.Duration } // MinTLSCertDuration returns the minimum TLS cert duration for the provisioner. // If the minimum is not set within the provisioner, then the global // minimum from the authority configuration will be used. func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration { - if pc.MinTLSDur == nil || *pc.MinTLSDur == 0 { + if pc.MinTLSDur == nil || pc.MinTLSDur.Duration == 0 { return pc.globalClaims.MinTLSCertDuration() } - return time.Duration(*pc.MinTLSDur) + return pc.MinTLSDur.Duration } // MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner. // If the maximum is not set within the provisioner, then the global // maximum from the authority configuration will be used. func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration { - if pc.MaxTLSDur == nil || *pc.MaxTLSDur == 0 { + if pc.MaxTLSDur == nil || pc.MaxTLSDur.Duration == 0 { return pc.globalClaims.MaxTLSCertDuration() } - return time.Duration(*pc.MaxTLSDur) + return pc.MaxTLSDur.Duration } // IsDisableRenewal returns if the renewal flow is disabled for the diff --git a/authority/types.go b/authority/types.go index a0c7661a..f0a781d5 100644 --- a/authority/types.go +++ b/authority/types.go @@ -8,15 +8,17 @@ import ( ) // Duration is a wrapper around Time.Duration to aid with marshal/unmarshal. -type duration time.Duration +type Duration struct { + time.Duration +} // MarshalJSON parses a duration string and sets it to the duration. // // A duration string is a possibly signed sequence of decimal numbers, each with // optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) MarshalJSON() ([]byte, error) { - return json.Marshal((*time.Duration)(d).String()) +func (d *Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.Duration.String()) } // UnmarshalJSON parses a duration string and sets it to the duration. @@ -24,7 +26,7 @@ func (d *duration) MarshalJSON() ([]byte, error) { // A duration string is a possibly signed sequence of decimal numbers, each with // optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) UnmarshalJSON(data []byte) (err error) { +func (d *Duration) UnmarshalJSON(data []byte) (err error) { var ( s string _d time.Duration @@ -38,7 +40,7 @@ func (d *duration) UnmarshalJSON(data []byte) (err error) { if _d, err = time.ParseDuration(s); err != nil { return errors.Wrapf(err, "error parsing %s as duration", s) } - *d = duration(_d) + d.Duration = _d return } diff --git a/authority/types_test.go b/authority/types_test.go index 6050409c..c49c368f 100644 --- a/authority/types_test.go +++ b/authority/types_test.go @@ -102,38 +102,32 @@ func Test_multiString_UnmarshalJSON(t *testing.T) { } } -func durPtr(_d time.Duration) *duration { - d := new(duration) - *d = duration(_d) - return d -} - -func Test_duration_UnmarshalJSON(t *testing.T) { +func TestDuration_UnmarshalJSON(t *testing.T) { type args struct { data []byte } tests := []struct { name string - d *duration + d *Duration args args - want *duration + want *Duration wantErr bool }{ - {"empty", new(duration), args{[]byte{}}, new(duration), true}, - {"bad type", new(duration), args{[]byte(`15`)}, new(duration), true}, - {"empty string", new(duration), args{[]byte(`""`)}, new(duration), true}, - {"non duration", new(duration), args{[]byte(`"15"`)}, new(duration), true}, - {"duration", new(duration), args{[]byte(`"15m30s"`)}, durPtr(15*time.Minute + 30*time.Second), false}, + {"empty", new(Duration), args{[]byte{}}, new(Duration), true}, + {"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true}, + {"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true}, + {"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true}, + {"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false}, {"nil", nil, args{nil}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { - t.Errorf("multiString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(tt.d, tt.want) { - t.Errorf("multiString.UnmarshalJSON() = %v, want %v", tt.d, tt.want) + t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want) } }) } @@ -142,21 +136,21 @@ func Test_duration_UnmarshalJSON(t *testing.T) { func Test_duration_MarshalJSON(t *testing.T) { tests := []struct { name string - d *duration + d *Duration want []byte wantErr bool }{ - {"string", durPtr(15*time.Minute + 30*time.Second), []byte(`"15m30s"`), false}, + {"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.d.MarshalJSON() if (err != nil) != tt.wantErr { - t.Errorf("duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("duration.MarshalJSON() = %v, want %v", got, tt.want) + t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want) } }) }