2
0
mirror of https://github.com/lightninglabs/loop synced 2024-11-11 13:11:12 +00:00

Merge pull request #145 from guggero/stream-interceptor

lsat: add stream interceptor
This commit is contained in:
Olaoluwa Osuntokun 2020-02-17 16:00:29 -08:00 committed by GitHub
commit ba6b4e782e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 376 additions and 234 deletions

View File

@ -89,6 +89,15 @@ func NewInterceptor(lnd *lndclient.LndServices, store Store,
} }
} }
// interceptContext is a struct that contains all information about a call that
// is intercepted by the interceptor.
type interceptContext struct {
mainCtx context.Context
opts []grpc.CallOption
metadata *metadata.MD
token *Token
}
// UnaryInterceptor is an interceptor method that can be used directly by gRPC // UnaryInterceptor is an interceptor method that can be used directly by gRPC
// for unary calls. If the store contains a token, it is attached as credentials // for unary calls. If the store contains a token, it is attached as credentials
// to every call before patching it through. The response error is also // to every call before patching it through. The response error is also
@ -105,21 +114,100 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string,
i.lock.Lock() i.lock.Lock()
defer i.lock.Unlock() defer i.lock.Unlock()
addLsatCredentials := func(token *Token) error { // Create the context that we'll use to initiate the real request. This
macaroon, err := token.PaidMacaroon() // contains the means to extract response headers and possibly also an
// auth token, if we already have paid for one.
iCtx, err := i.newInterceptContext(ctx, opts)
if err != nil { if err != nil {
return err return err
} }
opts = append(opts, grpc.PerRPCCredentials(
macaroons.NewMacaroonCredential(macaroon), // Try executing the call now. If anything goes wrong, we only handle
)) // the LSAT error message that comes in the form of a gRPC status error.
return nil rpcCtx, cancel := context.WithTimeout(ctx, i.callTimeout)
defer cancel()
err = invoker(rpcCtx, method, req, reply, cc, iCtx.opts...)
if !isPaymentRequired(err) {
return err
}
// Find out if we need to pay for a new token or perhaps resume
// a previously aborted payment.
err = i.handlePayment(iCtx)
if err != nil {
return err
}
// Execute the same request again, now with the LSAT
// token added as an RPC credential.
rpcCtx2, cancel2 := context.WithTimeout(ctx, i.callTimeout)
defer cancel2()
return invoker(rpcCtx2, method, req, reply, cc, iCtx.opts...)
}
// StreamInterceptor is an interceptor method that can be used directly by gRPC
// for streaming calls. If the store contains a token, it is attached as
// credentials to every stream establishment call before patching it through.
// The response error is also intercepted for every initial stream initiation.
// If there is an error returned and it is indicating a payment challenge, a
// token is acquired and paid for automatically. The original request is then
// repeated back to the server, now with the new token attached.
func (i *Interceptor) StreamInterceptor(ctx context.Context,
desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream,
error) {
// To avoid paying for a token twice if two parallel requests are
// happening, we require an exclusive lock here.
i.lock.Lock()
defer i.lock.Unlock()
// Create the context that we'll use to initiate the real request. This
// contains the means to extract response headers and possibly also an
// auth token, if we already have paid for one.
iCtx, err := i.newInterceptContext(ctx, opts)
if err != nil {
return nil, err
}
// Try establishing the stream now. If anything goes wrong, we only
// handle the LSAT error message that comes in the form of a gRPC status
// error. The context of a stream will be used for the whole lifetime of
// it, so we can't really clamp down on the initial call with a timeout.
stream, err := streamer(ctx, desc, cc, method, iCtx.opts...)
if !isPaymentRequired(err) {
return stream, err
}
// Find out if we need to pay for a new token or perhaps resume
// a previously aborted payment.
err = i.handlePayment(iCtx)
if err != nil {
return nil, err
}
// Execute the same request again, now with the LSAT token added
// as an RPC credential.
return streamer(ctx, desc, cc, method, iCtx.opts...)
}
// newInterceptContext creates the initial intercept context that can capture
// metadata from the server and sends the local token to the server if one
// already exists.
func (i *Interceptor) newInterceptContext(ctx context.Context,
opts []grpc.CallOption) (*interceptContext, error) {
iCtx := &interceptContext{
mainCtx: ctx,
opts: opts,
metadata: &metadata.MD{},
} }
// Let's see if the store already contains a token and what state it // Let's see if the store already contains a token and what state it
// might be in. If a previous call was aborted, we might have a pending // might be in. If a previous call was aborted, we might have a pending
// token that needs to be handled separately. // token that needs to be handled separately.
token, err := i.store.CurrentToken() var err error
iCtx.token, err = i.store.CurrentToken()
switch { switch {
// If there is no token yet, nothing to do at this point. // If there is no token yet, nothing to do at this point.
case err == ErrNoToken: case err == ErrNoToken:
@ -127,16 +215,18 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string,
// Some other error happened that we have to surface. // Some other error happened that we have to surface.
case err != nil: case err != nil:
log.Errorf("Failed to get token from store: %v", err) log.Errorf("Failed to get token from store: %v", err)
return fmt.Errorf("getting token from store failed: %v", err) return nil, fmt.Errorf("getting token from store failed: %v",
err)
// Only if we have a paid token append it. We don't resume a pending // Only if we have a paid token append it. We don't resume a pending
// payment just yet, since we don't even know if a token is required for // payment just yet, since we don't even know if a token is required for
// this call. We also never send a pending payment to the server since // this call. We also never send a pending payment to the server since
// we know it's not valid. // we know it's not valid.
case !token.isPending(): case !iCtx.token.isPending():
if err = addLsatCredentials(token); err != nil { if err = i.addLsatCredentials(iCtx); err != nil {
log.Errorf("Adding macaroon to request failed: %v", err) log.Errorf("Adding macaroon to request failed: %v", err)
return fmt.Errorf("adding macaroon failed: %v", err) return nil, fmt.Errorf("adding macaroon failed: %v",
err)
} }
} }
@ -145,60 +235,59 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string,
// option. We execute the request and inspect the error. If it's the // option. We execute the request and inspect the error. If it's the
// LSAT specific payment required error, we might execute the same // LSAT specific payment required error, we might execute the same
// method again later with the paid LSAT token. // method again later with the paid LSAT token.
trailerMetadata := &metadata.MD{} iCtx.opts = append(iCtx.opts, grpc.Trailer(iCtx.metadata))
opts = append(opts, grpc.Trailer(trailerMetadata)) return iCtx, nil
rpcCtx, cancel := context.WithTimeout(ctx, i.callTimeout)
defer cancel()
err = invoker(rpcCtx, method, req, reply, cc, opts...)
// Only handle the LSAT error message that comes in the form of
// a gRPC status error.
if isPaymentRequired(err) {
paidToken, err := i.handlePayment(ctx, token, trailerMetadata)
if err != nil {
return err
}
if err = addLsatCredentials(paidToken); err != nil {
log.Errorf("Adding macaroon to request failed: %v", err)
return fmt.Errorf("adding macaroon failed: %v", err)
}
// Execute the same request again, now with the LSAT
// token added as an RPC credential.
rpcCtx2, cancel2 := context.WithTimeout(ctx, i.callTimeout)
defer cancel2()
return invoker(rpcCtx2, method, req, reply, cc, opts...)
}
return err
} }
// handlePayment tries to obtain a valid token by either tracking the payment // handlePayment tries to obtain a valid token by either tracking the payment
// status of a pending token or paying for a new one. // status of a pending token or paying for a new one.
func (i *Interceptor) handlePayment(ctx context.Context, token *Token, func (i *Interceptor) handlePayment(iCtx *interceptContext) error {
md *metadata.MD) (*Token, error) {
switch { switch {
// Resume/track a pending payment if it was interrupted for some reason. // Resume/track a pending payment if it was interrupted for some reason.
case token != nil && token.isPending(): case iCtx.token != nil && iCtx.token.isPending():
log.Infof("Payment of LSAT token is required, resuming/" + log.Infof("Payment of LSAT token is required, resuming/" +
"tracking previous payment from pending LSAT token") "tracking previous payment from pending LSAT token")
err := i.trackPayment(ctx, token) err := i.trackPayment(iCtx.mainCtx, iCtx.token)
if err != nil { if err != nil {
return nil, err return err
} }
return token, nil
// We don't have a token yet, try to get a new one. // We don't have a token yet, try to get a new one.
case token == nil: case iCtx.token == nil:
// We don't have a token yet, get a new one. // We don't have a token yet, get a new one.
log.Infof("Payment of LSAT token is required, paying invoice") log.Infof("Payment of LSAT token is required, paying invoice")
return i.payLsatToken(ctx, md) var err error
iCtx.token, err = i.payLsatToken(iCtx.mainCtx, iCtx.metadata)
if err != nil {
return err
}
// We have a token and it's valid, nothing more to do here. // We have a token and it's valid, nothing more to do here.
default: default:
log.Debugf("Found valid LSAT token to add to request") log.Debugf("Found valid LSAT token to add to request")
return token, nil
} }
if err := i.addLsatCredentials(iCtx); err != nil {
log.Errorf("Adding macaroon to request failed: %v", err)
return fmt.Errorf("adding macaroon failed: %v", err)
}
return nil
}
// addLsatCredentials adds an LSAT token to the given intercept context.
func (i *Interceptor) addLsatCredentials(iCtx *interceptContext) error {
if iCtx.token == nil {
return fmt.Errorf("cannot add nil token to context")
}
macaroon, err := iCtx.token.PaidMacaroon()
if err != nil {
return err
}
iCtx.opts = append(iCtx.opts, grpc.PerRPCCredentials(
macaroons.NewMacaroonCredential(macaroon),
))
return nil
} }
// payLsatToken reads the payment challenge from the response metadata and tries // payLsatToken reads the payment challenge from the response metadata and tries

View File

@ -19,6 +19,21 @@ import (
"gopkg.in/macaroon.v2" "gopkg.in/macaroon.v2"
) )
type interceptTestCase struct {
name string
initialPreimage *lntypes.Preimage
interceptor *Interceptor
resetCb func()
expectLndCall bool
sendPaymentCb func(*testing.T, test.PaymentChannelMessage)
trackPaymentCb func(*testing.T, test.TrackPaymentMessage)
expectToken bool
expectInterceptErr string
expectBackendCalls int
expectMacaroonCall1 bool
expectMacaroonCall2 bool
}
type mockStore struct { type mockStore struct {
token *Token token *Token
} }
@ -39,12 +54,7 @@ func (s *mockStore) StoreToken(token *Token) error {
return nil return nil
} }
// TestInterceptor tests that the interceptor can handle LSAT protocol responses var (
// and pay the token.
func TestInterceptor(t *testing.T) {
t.Parallel()
var (
lnd = test.NewMockLnd() lnd = test.NewMockLnd()
store = &mockStore{} store = &mockStore{}
testTimeout = 5 * time.Second testTimeout = 5 * time.Second
@ -52,53 +62,21 @@ func TestInterceptor(t *testing.T) {
&lnd.LndServices, store, testTimeout, &lnd.LndServices, store, testTimeout,
DefaultMaxCostSats, DefaultMaxRoutingFeeSats, DefaultMaxCostSats, DefaultMaxRoutingFeeSats,
) )
testMac = makeMac(t) testMac = makeMac()
testMacBytes = serializeMac(t, testMac) testMacBytes = serializeMac(testMac)
testMacHex = hex.EncodeToString(testMacBytes) testMacHex = hex.EncodeToString(testMacBytes)
paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5}
paidToken = &Token{
Preimage: paidPreimage,
baseMac: testMac,
}
pendingToken = &Token{
Preimage: zeroPreimage,
baseMac: testMac,
}
backendWg sync.WaitGroup
backendErr error backendErr error
backendAuth = "" backendAuth = ""
callMD map[string]string callMD map[string]string
numBackendCalls = 0 numBackendCalls = 0
) overallWg sync.WaitGroup
backendWg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) testCases = []interceptTestCase{
defer cancel()
// resetBackend is used by the test cases to define the behaviour of the
// simulated backend and reset its starting conditions.
resetBackend := func(expectedErr error, expectedAuth string) {
backendErr = expectedErr
backendAuth = expectedAuth
callMD = nil
}
testCases := []struct {
name string
initialToken *Token
interceptor *Interceptor
resetCb func()
expectLndCall bool
sendPaymentCb func(msg test.PaymentChannelMessage)
trackPaymentCb func(msg test.TrackPaymentMessage)
expectToken bool
expectInterceptErr string
expectBackendCalls int
expectMacaroonCall1 bool
expectMacaroonCall2 bool
}{
{ {
name: "no auth required happy path", name: "no auth required happy path",
initialToken: nil, initialPreimage: nil,
interceptor: interceptor, interceptor: interceptor,
resetCb: func() { resetBackend(nil, "") }, resetCb: func() { resetBackend(nil, "") },
expectLndCall: false, expectLndCall: false,
@ -109,7 +87,7 @@ func TestInterceptor(t *testing.T) {
}, },
{ {
name: "auth required, no token yet", name: "auth required, no token yet",
initialToken: nil, initialPreimage: nil,
interceptor: interceptor, interceptor: interceptor,
resetCb: func() { resetCb: func() {
resetBackend( resetBackend(
@ -120,7 +98,9 @@ func TestInterceptor(t *testing.T) {
) )
}, },
expectLndCall: true, expectLndCall: true,
sendPaymentCb: func(msg test.PaymentChannelMessage) { sendPaymentCb: func(t *testing.T,
msg test.PaymentChannelMessage) {
if len(callMD) != 0 { if len(callMD) != 0 {
t.Fatalf("unexpected call metadata: "+ t.Fatalf("unexpected call metadata: "+
"%v", callMD) "%v", callMD)
@ -134,7 +114,9 @@ func TestInterceptor(t *testing.T) {
PaidFee: 345, PaidFee: 345,
} }
}, },
trackPaymentCb: func(msg test.TrackPaymentMessage) { trackPaymentCb: func(t *testing.T,
msg test.TrackPaymentMessage) {
t.Fatal("didn't expect call to trackPayment") t.Fatal("didn't expect call to trackPayment")
}, },
expectToken: true, expectToken: true,
@ -144,7 +126,7 @@ func TestInterceptor(t *testing.T) {
}, },
{ {
name: "auth required, has token", name: "auth required, has token",
initialToken: paidToken, initialPreimage: &paidPreimage,
interceptor: interceptor, interceptor: interceptor,
resetCb: func() { resetBackend(nil, "") }, resetCb: func() { resetBackend(nil, "") },
expectLndCall: false, expectLndCall: false,
@ -155,7 +137,7 @@ func TestInterceptor(t *testing.T) {
}, },
{ {
name: "auth required, has pending token", name: "auth required, has pending token",
initialToken: pendingToken, initialPreimage: &zeroPreimage,
interceptor: interceptor, interceptor: interceptor,
resetCb: func() { resetCb: func() {
resetBackend( resetBackend(
@ -166,10 +148,14 @@ func TestInterceptor(t *testing.T) {
) )
}, },
expectLndCall: true, expectLndCall: true,
sendPaymentCb: func(msg test.PaymentChannelMessage) { sendPaymentCb: func(t *testing.T,
msg test.PaymentChannelMessage) {
t.Fatal("didn't expect call to sendPayment") t.Fatal("didn't expect call to sendPayment")
}, },
trackPaymentCb: func(msg test.TrackPaymentMessage) { trackPaymentCb: func(t *testing.T,
msg test.TrackPaymentMessage) {
// The next call to the "backend" shouldn't // The next call to the "backend" shouldn't
// return an error. // return an error.
resetBackend(nil, "") resetBackend(nil, "")
@ -186,7 +172,7 @@ func TestInterceptor(t *testing.T) {
}, },
{ {
name: "auth required, no token yet, cost limit", name: "auth required, no token yet, cost limit",
initialToken: nil, initialPreimage: nil,
interceptor: NewInterceptor( interceptor: NewInterceptor(
&lnd.LndServices, store, testTimeout, &lnd.LndServices, store, testTimeout,
100, DefaultMaxRoutingFeeSats, 100, DefaultMaxRoutingFeeSats,
@ -209,15 +195,20 @@ func TestInterceptor(t *testing.T) {
expectMacaroonCall2: false, expectMacaroonCall2: false,
}, },
} }
)
// The invoker is a simple function that simulates the actual call to // resetBackend is used by the test cases to define the behaviour of the
// the server. We can track if it's been called and we can dictate what // simulated backend and reset its starting conditions.
// error it should return. func resetBackend(expectedErr error, expectedAuth string) {
invoker := func(_ context.Context, _ string, _ interface{}, backendErr = expectedErr
_ interface{}, _ *grpc.ClientConn, backendAuth = expectedAuth
opts ...grpc.CallOption) error { callMD = nil
}
defer backendWg.Done() // The invoker is a simple function that simulates the actual call to
// the server. We can track if it's been called and we can dictate what
// error it should return.
func invoker(opts []grpc.CallOption) error {
for _, opt := range opts { for _, opt := range opts {
// Extract the macaroon in case it was set in the // Extract the macaroon in case it was set in the
// request call options. // request call options.
@ -238,28 +229,84 @@ func TestInterceptor(t *testing.T) {
} }
numBackendCalls++ numBackendCalls++
return backendErr return backendErr
}
// TestUnaryInterceptor tests that the interceptor can handle LSAT protocol
// responses for unary calls and pay the token.
func TestUnaryInterceptor(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
unaryInvoker := func(_ context.Context, _ string,
_ interface{}, _ interface{}, _ *grpc.ClientConn,
opts ...grpc.CallOption) error {
defer backendWg.Done()
return invoker(opts)
} }
// Run through the test cases. // Run through the test cases.
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
intercept := func() error {
return tc.interceptor.UnaryInterceptor(
ctx, "", nil, nil, nil, unaryInvoker, nil,
)
}
t.Run(tc.name, func(t *testing.T) {
testInterceptor(t, tc, intercept)
})
}
}
// TestStreamInterceptor tests that the interceptor can handle LSAT protocol
// responses in streams and pay the token.
func TestStreamInterceptor(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
streamInvoker := func(_ context.Context,
_ *grpc.StreamDesc, _ *grpc.ClientConn,
_ string, opts ...grpc.CallOption) (
grpc.ClientStream, error) { // nolint: unparam
defer backendWg.Done()
return nil, invoker(opts)
}
// Run through the test cases.
for _, tc := range testCases {
tc := tc
intercept := func() error {
_, err := tc.interceptor.StreamInterceptor(
ctx, nil, nil, "", streamInvoker,
)
return err
}
t.Run(tc.name, func(t *testing.T) {
testInterceptor(t, tc, intercept)
})
}
}
func testInterceptor(t *testing.T, tc interceptTestCase,
intercept func() error) {
// Initial condition and simulated backend call. // Initial condition and simulated backend call.
store.token = tc.initialToken store.token = makeToken(tc.initialPreimage)
tc.resetCb() tc.resetCb()
numBackendCalls = 0 numBackendCalls = 0
var overallWg sync.WaitGroup
backendWg.Add(1) backendWg.Add(1)
overallWg.Add(1) overallWg.Add(1)
go func() { go func() {
err := tc.interceptor.UnaryInterceptor( defer overallWg.Done()
ctx, "", nil, nil, nil, invoker, nil, err := intercept()
)
if err != nil && tc.expectInterceptErr != "" && if err != nil && tc.expectInterceptErr != "" &&
err.Error() != tc.expectInterceptErr { err.Error() != tc.expectInterceptErr {
panic(fmt.Errorf("unexpected error '%s', "+ panic(fmt.Errorf("unexpected error '%s', "+
"expected '%s'", err.Error(), "expected '%s'", err.Error(),
tc.expectInterceptErr)) tc.expectInterceptErr))
} }
overallWg.Done()
}() }()
backendWg.Wait() backendWg.Wait()
@ -286,10 +333,10 @@ func TestInterceptor(t *testing.T) {
if tc.expectLndCall { if tc.expectLndCall {
select { select {
case payment := <-lnd.SendPaymentChannel: case payment := <-lnd.SendPaymentChannel:
tc.sendPaymentCb(payment) tc.sendPaymentCb(t, payment)
case track := <-lnd.TrackPaymentChannel: case track := <-lnd.TrackPaymentChannel:
tc.trackPaymentCb(track) tc.trackPaymentCb(t, track)
case <-time.After(testTimeout): case <-time.After(testTimeout):
t.Fatalf("[%s]: no payment request received", t.Fatalf("[%s]: no payment request received",
@ -299,7 +346,6 @@ func TestInterceptor(t *testing.T) {
backendWg.Wait() backendWg.Wait()
overallWg.Wait() overallWg.Wait()
// Interpret result/expectations.
if tc.expectToken { if tc.expectToken {
if _, err := store.CurrentToken(); err != nil { if _, err := store.CurrentToken(); err != nil {
t.Fatalf("[%s] expected store to contain token", t.Fatalf("[%s] expected store to contain token",
@ -327,26 +373,33 @@ func TestInterceptor(t *testing.T) {
"expected times", numBackendCalls, "expected times", numBackendCalls,
tc.expectBackendCalls) tc.expectBackendCalls)
} }
}
func makeToken(preimage *lntypes.Preimage) *Token {
if preimage == nil {
return nil
}
return &Token{
Preimage: *preimage,
baseMac: testMac,
} }
} }
func makeMac(t *testing.T) *macaroon.Macaroon { func makeMac() *macaroon.Macaroon {
dummyMac, err := macaroon.New( dummyMac, err := macaroon.New(
[]byte("aabbccddeeff00112233445566778899"), []byte("AA=="), []byte("aabbccddeeff00112233445566778899"), []byte("AA=="),
"LSAT", macaroon.LatestVersion, "LSAT", macaroon.LatestVersion,
) )
if err != nil { if err != nil {
t.Fatalf("unable to create macaroon: %v", err) panic(fmt.Errorf("unable to create macaroon: %v", err))
return nil
} }
return dummyMac return dummyMac
} }
func serializeMac(t *testing.T, mac *macaroon.Macaroon) []byte { func serializeMac(mac *macaroon.Macaroon) []byte {
macBytes, err := mac.MarshalBinary() macBytes, err := mac.MarshalBinary()
if err != nil { if err != nil {
t.Fatalf("unable to serialize macaroon: %v", err) panic(fmt.Errorf("unable to serialize macaroon: %v", err))
return nil
} }
return macBytes return macBytes
} }

View File

@ -23,11 +23,11 @@ func TestFileStore(t *testing.T) {
paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5}
paidToken = &Token{ paidToken = &Token{
Preimage: paidPreimage, Preimage: paidPreimage,
baseMac: makeMac(t), baseMac: makeMac(),
} }
pendingToken = &Token{ pendingToken = &Token{
Preimage: zeroPreimage, Preimage: zeroPreimage,
baseMac: makeMac(t), baseMac: makeMac(),
} }
) )