mirror of
https://github.com/smallstep/certificates.git
synced 2024-11-17 15:29:21 +00:00
395 lines
9.8 KiB
Go
395 lines
9.8 KiB
Go
package nosql
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/smallstep/certificates/acme"
|
|
certificatesdb "github.com/smallstep/certificates/db"
|
|
"github.com/smallstep/nosql"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestDB_GetDpopToken(t *testing.T) {
|
|
type test struct {
|
|
db *DB
|
|
orderID string
|
|
expected map[string]any
|
|
expectedErr error
|
|
}
|
|
var tests = map[string]func(t *testing.T) test{
|
|
"fail/acme-not-found": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
expectedErr: &acme.Error{
|
|
Type: "urn:ietf:params:acme:error:malformed",
|
|
Status: 400,
|
|
Detail: "The request message was malformed",
|
|
Err: errors.New(`dpop token "orderID" not found`),
|
|
},
|
|
}
|
|
},
|
|
"fail/unmarshal-error": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
token := dbDpopToken{
|
|
ID: "orderID",
|
|
Content: []byte("{}"),
|
|
CreatedAt: time.Now(),
|
|
}
|
|
b, err := json.Marshal(token)
|
|
require.NoError(t, err)
|
|
err = db.Set(wireDpopTokenTable, []byte("orderID"), b[1:]) // start at index 1; corrupt JSON data
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
expectedErr: errors.New(`error unmarshaling dpop "orderID" into dbDpopToken: invalid character ':' after top-level value`),
|
|
}
|
|
},
|
|
"fail/db.Get": func(t *testing.T) test {
|
|
db := &certificatesdb.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equal(t, wireDpopTokenTable, bucket)
|
|
assert.Equal(t, []byte("orderID"), key)
|
|
return nil, errors.New("fail")
|
|
},
|
|
}
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
expectedErr: errors.New(`error loading dpop "orderID": fail`),
|
|
}
|
|
},
|
|
"ok": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
token := dbDpopToken{
|
|
ID: "orderID",
|
|
Content: []byte(`{"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com"}`),
|
|
CreatedAt: time.Now(),
|
|
}
|
|
b, err := json.Marshal(token)
|
|
require.NoError(t, err)
|
|
err = db.Set(wireDpopTokenTable, []byte("orderID"), b)
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
expected: map[string]any{
|
|
"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
|
|
},
|
|
}
|
|
},
|
|
}
|
|
for name, run := range tests {
|
|
tc := run(t)
|
|
t.Run(name, func(t *testing.T) {
|
|
got, err := tc.db.GetDpopToken(context.Background(), tc.orderID)
|
|
if tc.expectedErr != nil {
|
|
assert.EqualError(t, err, tc.expectedErr.Error())
|
|
ae := &acme.Error{}
|
|
if errors.As(err, &ae) {
|
|
ee := &acme.Error{}
|
|
require.True(t, errors.As(tc.expectedErr, &ee))
|
|
assert.Equal(t, ee.Detail, ae.Detail)
|
|
assert.Equal(t, ee.Type, ae.Type)
|
|
assert.Equal(t, ee.Status, ae.Status)
|
|
}
|
|
assert.Nil(t, got)
|
|
return
|
|
}
|
|
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tc.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDB_CreateDpopToken(t *testing.T) {
|
|
type test struct {
|
|
db *DB
|
|
orderID string
|
|
dpop map[string]any
|
|
expectedErr error
|
|
}
|
|
var tests = map[string]func(t *testing.T) test{
|
|
"fail/db.Save": func(t *testing.T) test {
|
|
db := &certificatesdb.MockNoSQLDB{
|
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
assert.Equal(t, wireDpopTokenTable, bucket)
|
|
assert.Equal(t, []byte("orderID"), key)
|
|
return nil, false, errors.New("fail")
|
|
},
|
|
}
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
dpop: map[string]any{
|
|
"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
|
|
},
|
|
expectedErr: errors.New("failed saving dpop token: error saving acme dpop: fail"),
|
|
}
|
|
},
|
|
"ok": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
dpop: map[string]any{
|
|
"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com",
|
|
},
|
|
}
|
|
},
|
|
"ok/nil": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
dpop: nil,
|
|
}
|
|
},
|
|
}
|
|
for name, run := range tests {
|
|
tc := run(t)
|
|
t.Run(name, func(t *testing.T) {
|
|
err := tc.db.CreateDpopToken(context.Background(), tc.orderID, tc.dpop)
|
|
if tc.expectedErr != nil {
|
|
assert.EqualError(t, err, tc.expectedErr.Error())
|
|
return
|
|
}
|
|
|
|
assert.NoError(t, err)
|
|
|
|
dpop, err := tc.db.getDBDpopToken(context.Background(), tc.orderID)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, tc.orderID, dpop.ID)
|
|
var m map[string]any
|
|
err = json.Unmarshal(dpop.Content, &m)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, tc.dpop, m)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDB_GetOidcToken(t *testing.T) {
|
|
type test struct {
|
|
db *DB
|
|
orderID string
|
|
expected map[string]any
|
|
expectedErr error
|
|
}
|
|
var tests = map[string]func(t *testing.T) test{
|
|
"fail/acme-not-found": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
expectedErr: &acme.Error{
|
|
Type: "urn:ietf:params:acme:error:malformed",
|
|
Status: 400,
|
|
Detail: "The request message was malformed",
|
|
Err: errors.New(`oidc token "orderID" not found`),
|
|
},
|
|
}
|
|
},
|
|
"fail/unmarshal-error": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
token := dbOidcToken{
|
|
ID: "orderID",
|
|
Content: []byte("{}"),
|
|
CreatedAt: time.Now(),
|
|
}
|
|
b, err := json.Marshal(token)
|
|
require.NoError(t, err)
|
|
err = db.Set(wireOidcTokenTable, []byte("orderID"), b[1:]) // start at index 1; corrupt JSON data
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
expectedErr: errors.New(`error unmarshaling oidc token "orderID" into dbOidcToken: invalid character ':' after top-level value`),
|
|
}
|
|
},
|
|
"fail/db.Get": func(t *testing.T) test {
|
|
db := &certificatesdb.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equal(t, wireOidcTokenTable, bucket)
|
|
assert.Equal(t, []byte("orderID"), key)
|
|
return nil, errors.New("fail")
|
|
},
|
|
}
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
expectedErr: errors.New(`error loading oidc token "orderID": fail`),
|
|
}
|
|
},
|
|
"ok": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
token := dbOidcToken{
|
|
ID: "orderID",
|
|
Content: []byte(`{"name": "Alice Smith", "handle": "@alice.smith"}`),
|
|
CreatedAt: time.Now(),
|
|
}
|
|
b, err := json.Marshal(token)
|
|
require.NoError(t, err)
|
|
err = db.Set(wireOidcTokenTable, []byte("orderID"), b)
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
expected: map[string]any{
|
|
"name": "Alice Smith",
|
|
"handle": "@alice.smith",
|
|
},
|
|
}
|
|
},
|
|
}
|
|
for name, run := range tests {
|
|
tc := run(t)
|
|
t.Run(name, func(t *testing.T) {
|
|
got, err := tc.db.GetOidcToken(context.Background(), tc.orderID)
|
|
if tc.expectedErr != nil {
|
|
assert.EqualError(t, err, tc.expectedErr.Error())
|
|
ae := &acme.Error{}
|
|
if errors.As(err, &ae) {
|
|
ee := &acme.Error{}
|
|
require.True(t, errors.As(tc.expectedErr, &ee))
|
|
assert.Equal(t, ee.Detail, ae.Detail)
|
|
assert.Equal(t, ee.Type, ae.Type)
|
|
assert.Equal(t, ee.Status, ae.Status)
|
|
}
|
|
assert.Nil(t, got)
|
|
return
|
|
}
|
|
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tc.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDB_CreateOidcToken(t *testing.T) {
|
|
type test struct {
|
|
db *DB
|
|
orderID string
|
|
oidc map[string]any
|
|
expectedErr error
|
|
}
|
|
var tests = map[string]func(t *testing.T) test{
|
|
"fail/db.Save": func(t *testing.T) test {
|
|
db := &certificatesdb.MockNoSQLDB{
|
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
assert.Equal(t, wireOidcTokenTable, bucket)
|
|
assert.Equal(t, []byte("orderID"), key)
|
|
return nil, false, errors.New("fail")
|
|
},
|
|
}
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
oidc: map[string]any{
|
|
"name": "Alice Smith",
|
|
"handle": "@alice.smith",
|
|
},
|
|
expectedErr: errors.New("failed saving oidc token: error saving acme oidc: fail"),
|
|
}
|
|
},
|
|
"ok": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
oidc: map[string]any{
|
|
"name": "Alice Smith",
|
|
"handle": "@alice.smith",
|
|
},
|
|
}
|
|
},
|
|
"ok/nil": func(t *testing.T) test {
|
|
dir := t.TempDir()
|
|
db, err := nosql.New("badgerv2", dir)
|
|
require.NoError(t, err)
|
|
return test{
|
|
db: &DB{
|
|
db: db,
|
|
},
|
|
orderID: "orderID",
|
|
oidc: nil,
|
|
}
|
|
},
|
|
}
|
|
for name, run := range tests {
|
|
tc := run(t)
|
|
t.Run(name, func(t *testing.T) {
|
|
err := tc.db.CreateOidcToken(context.Background(), tc.orderID, tc.oidc)
|
|
if tc.expectedErr != nil {
|
|
assert.EqualError(t, err, tc.expectedErr.Error())
|
|
return
|
|
}
|
|
|
|
assert.NoError(t, err)
|
|
|
|
oidc, err := tc.db.getDBOidcToken(context.Background(), tc.orderID)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, tc.orderID, oidc.ID)
|
|
var m map[string]any
|
|
err = json.Unmarshal(oidc.Content, &m)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, tc.oidc, m)
|
|
})
|
|
}
|
|
}
|