diff --git a/authority/provisioner/duration.go b/authority/provisioner/duration.go new file mode 100644 index 00000000..38d504a3 --- /dev/null +++ b/authority/provisioner/duration.go @@ -0,0 +1,45 @@ +package provisioner + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal. +type Duration struct { + time.Duration +} + +// MarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.Duration.String()) +} + +// UnmarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *Duration) UnmarshalJSON(data []byte) (err error) { + var ( + s string + _d time.Duration + ) + if d == nil { + return errors.New("duration cannot be nil") + } + if err = json.Unmarshal(data, &s); err != nil { + return errors.Wrapf(err, "error unmarshaling %s", data) + } + if _d, err = time.ParseDuration(s); err != nil { + return errors.Wrapf(err, "error parsing %s as duration", s) + } + d.Duration = _d + return +} diff --git a/authority/provisioner/duration_test.go b/authority/provisioner/duration_test.go new file mode 100644 index 00000000..4f7304a0 --- /dev/null +++ b/authority/provisioner/duration_test.go @@ -0,0 +1,61 @@ +package provisioner + +import ( + "reflect" + "testing" + "time" +) + +func TestDuration_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + d *Duration + args args + want *Duration + wantErr bool + }{ + {"empty", new(Duration), args{[]byte{}}, new(Duration), true}, + {"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true}, + {"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true}, + {"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true}, + {"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false}, + {"nil", nil, args{nil}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(tt.d, tt.want) { + t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want) + } + }) + } +} + +func TestDuration_MarshalJSON(t *testing.T) { + tests := []struct { + name string + d *Duration + want []byte + wantErr bool + }{ + {"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.d.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +}