From 599fc1058c65a684ddee2da318cc563e9dcc0afe Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 10 Jun 2019 13:21:06 -0700 Subject: [PATCH] loadOrStore -> cmpAndSwap --- Gopkg.lock | 4 ++-- db/db.go | 8 +++----- db/db_test.go | 32 +++++++++++++++++--------------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/Gopkg.lock b/Gopkg.lock index 8f6edc08..7628ffce 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -363,7 +363,7 @@ [[projects]] branch = "master" - digest = "1:5e778214d472b6d2ad4d544d293d1478d9b222db8ffc6079623fbe3e58e1841e" + digest = "1:9c1b7052fa8f2c918efd60ed5ae3c70ccbba08967c58ec71067535449a3ba220" name = "github.com/smallstep/nosql" packages = [ ".", @@ -373,7 +373,7 @@ "mysql", ] pruneopts = "UT" - revision = "b66b34823456721912ba037126e92414690c07d6" + revision = "a0934e12468769d8cbede3ed316c47a4b88de4ca" [[projects]] branch = "master" diff --git a/db/db.go b/db/db.go index 0494c1ac..3438046c 100644 --- a/db/db.go +++ b/db/db.go @@ -131,14 +131,12 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error { // UseToken returns true if we were able to successfully store the token for // for the first time, false otherwise. func (db *DB) UseToken(id, tok string) (bool, error) { - // If the error is `Not Found` then the certificate has not been revoked. - // Any other error should be propagated to the caller. - _, found, err := db.LoadOrStore(usedOTTTable, []byte(id), []byte(tok)) + _, swapped, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok)) switch { case err != nil: - return false, errors.Wrapf(err, "error LoadOrStore-ing token %s/%s", + return false, errors.Wrapf(err, "error storing used token %s/%s", string(usedOTTTable), id) - case found: + case !swapped: return false, nil default: return true, nil diff --git a/db/db_test.go b/db/db_test.go index c5dbe4fe..a486fd84 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -20,12 +20,12 @@ type MockNoSQLDB struct { del func(bucket, key []byte) error list func(bucket []byte) ([]*database.Entry, error) update func(tx *database.Tx) error - loadOrStore func(bucket, key, value []byte) ([]byte, bool, error) + cmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error) } -func (m *MockNoSQLDB) LoadOrStore(bucket, key, value []byte) ([]byte, bool, error) { - if m.get != nil { - return m.loadOrStore(bucket, key, value) +func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) { + if m.cmpAndSwap != nil { + return m.cmpAndSwap(bucket, key, old, newval) } if m.ret1 == nil { return nil, false, m.err @@ -210,37 +210,37 @@ func TestUseToken(t *testing.T) { db *DB want result }{ - "fail/force-LoadOrStore-error": { + "fail/force-CmpAndSwap-error": { id: "id", tok: "token", db: &DB{&MockNoSQLDB{ - loadOrStore: func(bucket, key, value []byte) ([]byte, bool, error) { + cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return nil, false, errors.New("force") }, }, true}, want: result{ ok: false, - err: errors.New("error LoadOrStore-ing token id/token"), + err: errors.New("error storing used token used_ott/id"), }, }, - "fail/LoadOrStore-found": { + "fail/CmpAndSwap-already-exists": { id: "id", tok: "token", db: &DB{&MockNoSQLDB{ - loadOrStore: func(bucket, key, value []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil + cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), false, nil }, }, true}, want: result{ ok: false, }, }, - "ok/LoadOrStore-not-found": { + "ok/cmpAndSwap-success": { id: "id", tok: "token", db: &DB{&MockNoSQLDB{ - loadOrStore: func(bucket, key, value []byte) ([]byte, bool, error) { - return nil, false, nil + cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("bar"), true, nil }, }, true}, want: result{ @@ -253,11 +253,13 @@ func TestUseToken(t *testing.T) { ok, err := tc.db.UseToken(tc.id, tc.tok) if err != nil { if assert.NotNil(t, tc.want.err) { - assert.HasPrefix(t, tc.want.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tc.want.err.Error()) } assert.False(t, ok) + } else if ok { + assert.True(t, tc.want.ok) } else { - assert.True(t, ok) + assert.False(t, tc.want.ok) } }) }