mirror of
https://github.com/lightninglabs/loop
synced 2024-11-11 13:11:12 +00:00
Merge pull request #145 from guggero/stream-interceptor
lsat: add stream interceptor
This commit is contained in:
commit
ba6b4e782e
@ -89,6 +89,15 @@ func NewInterceptor(lnd *lndclient.LndServices, store Store,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// interceptContext is a struct that contains all information about a call that
|
||||||
|
// is intercepted by the interceptor.
|
||||||
|
type interceptContext struct {
|
||||||
|
mainCtx context.Context
|
||||||
|
opts []grpc.CallOption
|
||||||
|
metadata *metadata.MD
|
||||||
|
token *Token
|
||||||
|
}
|
||||||
|
|
||||||
// UnaryInterceptor is an interceptor method that can be used directly by gRPC
|
// UnaryInterceptor is an interceptor method that can be used directly by gRPC
|
||||||
// for unary calls. If the store contains a token, it is attached as credentials
|
// for unary calls. If the store contains a token, it is attached as credentials
|
||||||
// to every call before patching it through. The response error is also
|
// to every call before patching it through. The response error is also
|
||||||
@ -105,21 +114,100 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string,
|
|||||||
i.lock.Lock()
|
i.lock.Lock()
|
||||||
defer i.lock.Unlock()
|
defer i.lock.Unlock()
|
||||||
|
|
||||||
addLsatCredentials := func(token *Token) error {
|
// Create the context that we'll use to initiate the real request. This
|
||||||
macaroon, err := token.PaidMacaroon()
|
// contains the means to extract response headers and possibly also an
|
||||||
|
// auth token, if we already have paid for one.
|
||||||
|
iCtx, err := i.newInterceptContext(ctx, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
opts = append(opts, grpc.PerRPCCredentials(
|
|
||||||
macaroons.NewMacaroonCredential(macaroon),
|
// Try executing the call now. If anything goes wrong, we only handle
|
||||||
))
|
// the LSAT error message that comes in the form of a gRPC status error.
|
||||||
return nil
|
rpcCtx, cancel := context.WithTimeout(ctx, i.callTimeout)
|
||||||
|
defer cancel()
|
||||||
|
err = invoker(rpcCtx, method, req, reply, cc, iCtx.opts...)
|
||||||
|
if !isPaymentRequired(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find out if we need to pay for a new token or perhaps resume
|
||||||
|
// a previously aborted payment.
|
||||||
|
err = i.handlePayment(iCtx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the same request again, now with the LSAT
|
||||||
|
// token added as an RPC credential.
|
||||||
|
rpcCtx2, cancel2 := context.WithTimeout(ctx, i.callTimeout)
|
||||||
|
defer cancel2()
|
||||||
|
return invoker(rpcCtx2, method, req, reply, cc, iCtx.opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamInterceptor is an interceptor method that can be used directly by gRPC
|
||||||
|
// for streaming calls. If the store contains a token, it is attached as
|
||||||
|
// credentials to every stream establishment call before patching it through.
|
||||||
|
// The response error is also intercepted for every initial stream initiation.
|
||||||
|
// If there is an error returned and it is indicating a payment challenge, a
|
||||||
|
// token is acquired and paid for automatically. The original request is then
|
||||||
|
// repeated back to the server, now with the new token attached.
|
||||||
|
func (i *Interceptor) StreamInterceptor(ctx context.Context,
|
||||||
|
desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||||
|
streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream,
|
||||||
|
error) {
|
||||||
|
|
||||||
|
// To avoid paying for a token twice if two parallel requests are
|
||||||
|
// happening, we require an exclusive lock here.
|
||||||
|
i.lock.Lock()
|
||||||
|
defer i.lock.Unlock()
|
||||||
|
|
||||||
|
// Create the context that we'll use to initiate the real request. This
|
||||||
|
// contains the means to extract response headers and possibly also an
|
||||||
|
// auth token, if we already have paid for one.
|
||||||
|
iCtx, err := i.newInterceptContext(ctx, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try establishing the stream now. If anything goes wrong, we only
|
||||||
|
// handle the LSAT error message that comes in the form of a gRPC status
|
||||||
|
// error. The context of a stream will be used for the whole lifetime of
|
||||||
|
// it, so we can't really clamp down on the initial call with a timeout.
|
||||||
|
stream, err := streamer(ctx, desc, cc, method, iCtx.opts...)
|
||||||
|
if !isPaymentRequired(err) {
|
||||||
|
return stream, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find out if we need to pay for a new token or perhaps resume
|
||||||
|
// a previously aborted payment.
|
||||||
|
err = i.handlePayment(iCtx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the same request again, now with the LSAT token added
|
||||||
|
// as an RPC credential.
|
||||||
|
return streamer(ctx, desc, cc, method, iCtx.opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newInterceptContext creates the initial intercept context that can capture
|
||||||
|
// metadata from the server and sends the local token to the server if one
|
||||||
|
// already exists.
|
||||||
|
func (i *Interceptor) newInterceptContext(ctx context.Context,
|
||||||
|
opts []grpc.CallOption) (*interceptContext, error) {
|
||||||
|
|
||||||
|
iCtx := &interceptContext{
|
||||||
|
mainCtx: ctx,
|
||||||
|
opts: opts,
|
||||||
|
metadata: &metadata.MD{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let's see if the store already contains a token and what state it
|
// Let's see if the store already contains a token and what state it
|
||||||
// might be in. If a previous call was aborted, we might have a pending
|
// might be in. If a previous call was aborted, we might have a pending
|
||||||
// token that needs to be handled separately.
|
// token that needs to be handled separately.
|
||||||
token, err := i.store.CurrentToken()
|
var err error
|
||||||
|
iCtx.token, err = i.store.CurrentToken()
|
||||||
switch {
|
switch {
|
||||||
// If there is no token yet, nothing to do at this point.
|
// If there is no token yet, nothing to do at this point.
|
||||||
case err == ErrNoToken:
|
case err == ErrNoToken:
|
||||||
@ -127,16 +215,18 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string,
|
|||||||
// Some other error happened that we have to surface.
|
// Some other error happened that we have to surface.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
log.Errorf("Failed to get token from store: %v", err)
|
log.Errorf("Failed to get token from store: %v", err)
|
||||||
return fmt.Errorf("getting token from store failed: %v", err)
|
return nil, fmt.Errorf("getting token from store failed: %v",
|
||||||
|
err)
|
||||||
|
|
||||||
// Only if we have a paid token append it. We don't resume a pending
|
// Only if we have a paid token append it. We don't resume a pending
|
||||||
// payment just yet, since we don't even know if a token is required for
|
// payment just yet, since we don't even know if a token is required for
|
||||||
// this call. We also never send a pending payment to the server since
|
// this call. We also never send a pending payment to the server since
|
||||||
// we know it's not valid.
|
// we know it's not valid.
|
||||||
case !token.isPending():
|
case !iCtx.token.isPending():
|
||||||
if err = addLsatCredentials(token); err != nil {
|
if err = i.addLsatCredentials(iCtx); err != nil {
|
||||||
log.Errorf("Adding macaroon to request failed: %v", err)
|
log.Errorf("Adding macaroon to request failed: %v", err)
|
||||||
return fmt.Errorf("adding macaroon failed: %v", err)
|
return nil, fmt.Errorf("adding macaroon failed: %v",
|
||||||
|
err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,60 +235,59 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string,
|
|||||||
// option. We execute the request and inspect the error. If it's the
|
// option. We execute the request and inspect the error. If it's the
|
||||||
// LSAT specific payment required error, we might execute the same
|
// LSAT specific payment required error, we might execute the same
|
||||||
// method again later with the paid LSAT token.
|
// method again later with the paid LSAT token.
|
||||||
trailerMetadata := &metadata.MD{}
|
iCtx.opts = append(iCtx.opts, grpc.Trailer(iCtx.metadata))
|
||||||
opts = append(opts, grpc.Trailer(trailerMetadata))
|
return iCtx, nil
|
||||||
rpcCtx, cancel := context.WithTimeout(ctx, i.callTimeout)
|
|
||||||
defer cancel()
|
|
||||||
err = invoker(rpcCtx, method, req, reply, cc, opts...)
|
|
||||||
|
|
||||||
// Only handle the LSAT error message that comes in the form of
|
|
||||||
// a gRPC status error.
|
|
||||||
if isPaymentRequired(err) {
|
|
||||||
paidToken, err := i.handlePayment(ctx, token, trailerMetadata)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = addLsatCredentials(paidToken); err != nil {
|
|
||||||
log.Errorf("Adding macaroon to request failed: %v", err)
|
|
||||||
return fmt.Errorf("adding macaroon failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute the same request again, now with the LSAT
|
|
||||||
// token added as an RPC credential.
|
|
||||||
rpcCtx2, cancel2 := context.WithTimeout(ctx, i.callTimeout)
|
|
||||||
defer cancel2()
|
|
||||||
return invoker(rpcCtx2, method, req, reply, cc, opts...)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handlePayment tries to obtain a valid token by either tracking the payment
|
// handlePayment tries to obtain a valid token by either tracking the payment
|
||||||
// status of a pending token or paying for a new one.
|
// status of a pending token or paying for a new one.
|
||||||
func (i *Interceptor) handlePayment(ctx context.Context, token *Token,
|
func (i *Interceptor) handlePayment(iCtx *interceptContext) error {
|
||||||
md *metadata.MD) (*Token, error) {
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
// Resume/track a pending payment if it was interrupted for some reason.
|
// Resume/track a pending payment if it was interrupted for some reason.
|
||||||
case token != nil && token.isPending():
|
case iCtx.token != nil && iCtx.token.isPending():
|
||||||
log.Infof("Payment of LSAT token is required, resuming/" +
|
log.Infof("Payment of LSAT token is required, resuming/" +
|
||||||
"tracking previous payment from pending LSAT token")
|
"tracking previous payment from pending LSAT token")
|
||||||
err := i.trackPayment(ctx, token)
|
err := i.trackPayment(iCtx.mainCtx, iCtx.token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
return token, nil
|
|
||||||
|
|
||||||
// We don't have a token yet, try to get a new one.
|
// We don't have a token yet, try to get a new one.
|
||||||
case token == nil:
|
case iCtx.token == nil:
|
||||||
// We don't have a token yet, get a new one.
|
// We don't have a token yet, get a new one.
|
||||||
log.Infof("Payment of LSAT token is required, paying invoice")
|
log.Infof("Payment of LSAT token is required, paying invoice")
|
||||||
return i.payLsatToken(ctx, md)
|
var err error
|
||||||
|
iCtx.token, err = i.payLsatToken(iCtx.mainCtx, iCtx.metadata)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// We have a token and it's valid, nothing more to do here.
|
// We have a token and it's valid, nothing more to do here.
|
||||||
default:
|
default:
|
||||||
log.Debugf("Found valid LSAT token to add to request")
|
log.Debugf("Found valid LSAT token to add to request")
|
||||||
return token, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := i.addLsatCredentials(iCtx); err != nil {
|
||||||
|
log.Errorf("Adding macaroon to request failed: %v", err)
|
||||||
|
return fmt.Errorf("adding macaroon failed: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addLsatCredentials adds an LSAT token to the given intercept context.
|
||||||
|
func (i *Interceptor) addLsatCredentials(iCtx *interceptContext) error {
|
||||||
|
if iCtx.token == nil {
|
||||||
|
return fmt.Errorf("cannot add nil token to context")
|
||||||
|
}
|
||||||
|
|
||||||
|
macaroon, err := iCtx.token.PaidMacaroon()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
iCtx.opts = append(iCtx.opts, grpc.PerRPCCredentials(
|
||||||
|
macaroons.NewMacaroonCredential(macaroon),
|
||||||
|
))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// payLsatToken reads the payment challenge from the response metadata and tries
|
// payLsatToken reads the payment challenge from the response metadata and tries
|
||||||
|
@ -19,6 +19,21 @@ import (
|
|||||||
"gopkg.in/macaroon.v2"
|
"gopkg.in/macaroon.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type interceptTestCase struct {
|
||||||
|
name string
|
||||||
|
initialPreimage *lntypes.Preimage
|
||||||
|
interceptor *Interceptor
|
||||||
|
resetCb func()
|
||||||
|
expectLndCall bool
|
||||||
|
sendPaymentCb func(*testing.T, test.PaymentChannelMessage)
|
||||||
|
trackPaymentCb func(*testing.T, test.TrackPaymentMessage)
|
||||||
|
expectToken bool
|
||||||
|
expectInterceptErr string
|
||||||
|
expectBackendCalls int
|
||||||
|
expectMacaroonCall1 bool
|
||||||
|
expectMacaroonCall2 bool
|
||||||
|
}
|
||||||
|
|
||||||
type mockStore struct {
|
type mockStore struct {
|
||||||
token *Token
|
token *Token
|
||||||
}
|
}
|
||||||
@ -39,12 +54,7 @@ func (s *mockStore) StoreToken(token *Token) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestInterceptor tests that the interceptor can handle LSAT protocol responses
|
var (
|
||||||
// and pay the token.
|
|
||||||
func TestInterceptor(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var (
|
|
||||||
lnd = test.NewMockLnd()
|
lnd = test.NewMockLnd()
|
||||||
store = &mockStore{}
|
store = &mockStore{}
|
||||||
testTimeout = 5 * time.Second
|
testTimeout = 5 * time.Second
|
||||||
@ -52,53 +62,21 @@ func TestInterceptor(t *testing.T) {
|
|||||||
&lnd.LndServices, store, testTimeout,
|
&lnd.LndServices, store, testTimeout,
|
||||||
DefaultMaxCostSats, DefaultMaxRoutingFeeSats,
|
DefaultMaxCostSats, DefaultMaxRoutingFeeSats,
|
||||||
)
|
)
|
||||||
testMac = makeMac(t)
|
testMac = makeMac()
|
||||||
testMacBytes = serializeMac(t, testMac)
|
testMacBytes = serializeMac(testMac)
|
||||||
testMacHex = hex.EncodeToString(testMacBytes)
|
testMacHex = hex.EncodeToString(testMacBytes)
|
||||||
paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5}
|
paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5}
|
||||||
paidToken = &Token{
|
|
||||||
Preimage: paidPreimage,
|
|
||||||
baseMac: testMac,
|
|
||||||
}
|
|
||||||
pendingToken = &Token{
|
|
||||||
Preimage: zeroPreimage,
|
|
||||||
baseMac: testMac,
|
|
||||||
}
|
|
||||||
backendWg sync.WaitGroup
|
|
||||||
backendErr error
|
backendErr error
|
||||||
backendAuth = ""
|
backendAuth = ""
|
||||||
callMD map[string]string
|
callMD map[string]string
|
||||||
numBackendCalls = 0
|
numBackendCalls = 0
|
||||||
)
|
overallWg sync.WaitGroup
|
||||||
|
backendWg sync.WaitGroup
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
testCases = []interceptTestCase{
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// resetBackend is used by the test cases to define the behaviour of the
|
|
||||||
// simulated backend and reset its starting conditions.
|
|
||||||
resetBackend := func(expectedErr error, expectedAuth string) {
|
|
||||||
backendErr = expectedErr
|
|
||||||
backendAuth = expectedAuth
|
|
||||||
callMD = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
initialToken *Token
|
|
||||||
interceptor *Interceptor
|
|
||||||
resetCb func()
|
|
||||||
expectLndCall bool
|
|
||||||
sendPaymentCb func(msg test.PaymentChannelMessage)
|
|
||||||
trackPaymentCb func(msg test.TrackPaymentMessage)
|
|
||||||
expectToken bool
|
|
||||||
expectInterceptErr string
|
|
||||||
expectBackendCalls int
|
|
||||||
expectMacaroonCall1 bool
|
|
||||||
expectMacaroonCall2 bool
|
|
||||||
}{
|
|
||||||
{
|
{
|
||||||
name: "no auth required happy path",
|
name: "no auth required happy path",
|
||||||
initialToken: nil,
|
initialPreimage: nil,
|
||||||
interceptor: interceptor,
|
interceptor: interceptor,
|
||||||
resetCb: func() { resetBackend(nil, "") },
|
resetCb: func() { resetBackend(nil, "") },
|
||||||
expectLndCall: false,
|
expectLndCall: false,
|
||||||
@ -109,7 +87,7 @@ func TestInterceptor(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "auth required, no token yet",
|
name: "auth required, no token yet",
|
||||||
initialToken: nil,
|
initialPreimage: nil,
|
||||||
interceptor: interceptor,
|
interceptor: interceptor,
|
||||||
resetCb: func() {
|
resetCb: func() {
|
||||||
resetBackend(
|
resetBackend(
|
||||||
@ -120,7 +98,9 @@ func TestInterceptor(t *testing.T) {
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
expectLndCall: true,
|
expectLndCall: true,
|
||||||
sendPaymentCb: func(msg test.PaymentChannelMessage) {
|
sendPaymentCb: func(t *testing.T,
|
||||||
|
msg test.PaymentChannelMessage) {
|
||||||
|
|
||||||
if len(callMD) != 0 {
|
if len(callMD) != 0 {
|
||||||
t.Fatalf("unexpected call metadata: "+
|
t.Fatalf("unexpected call metadata: "+
|
||||||
"%v", callMD)
|
"%v", callMD)
|
||||||
@ -134,7 +114,9 @@ func TestInterceptor(t *testing.T) {
|
|||||||
PaidFee: 345,
|
PaidFee: 345,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
trackPaymentCb: func(msg test.TrackPaymentMessage) {
|
trackPaymentCb: func(t *testing.T,
|
||||||
|
msg test.TrackPaymentMessage) {
|
||||||
|
|
||||||
t.Fatal("didn't expect call to trackPayment")
|
t.Fatal("didn't expect call to trackPayment")
|
||||||
},
|
},
|
||||||
expectToken: true,
|
expectToken: true,
|
||||||
@ -144,7 +126,7 @@ func TestInterceptor(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "auth required, has token",
|
name: "auth required, has token",
|
||||||
initialToken: paidToken,
|
initialPreimage: &paidPreimage,
|
||||||
interceptor: interceptor,
|
interceptor: interceptor,
|
||||||
resetCb: func() { resetBackend(nil, "") },
|
resetCb: func() { resetBackend(nil, "") },
|
||||||
expectLndCall: false,
|
expectLndCall: false,
|
||||||
@ -155,7 +137,7 @@ func TestInterceptor(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "auth required, has pending token",
|
name: "auth required, has pending token",
|
||||||
initialToken: pendingToken,
|
initialPreimage: &zeroPreimage,
|
||||||
interceptor: interceptor,
|
interceptor: interceptor,
|
||||||
resetCb: func() {
|
resetCb: func() {
|
||||||
resetBackend(
|
resetBackend(
|
||||||
@ -166,10 +148,14 @@ func TestInterceptor(t *testing.T) {
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
expectLndCall: true,
|
expectLndCall: true,
|
||||||
sendPaymentCb: func(msg test.PaymentChannelMessage) {
|
sendPaymentCb: func(t *testing.T,
|
||||||
|
msg test.PaymentChannelMessage) {
|
||||||
|
|
||||||
t.Fatal("didn't expect call to sendPayment")
|
t.Fatal("didn't expect call to sendPayment")
|
||||||
},
|
},
|
||||||
trackPaymentCb: func(msg test.TrackPaymentMessage) {
|
trackPaymentCb: func(t *testing.T,
|
||||||
|
msg test.TrackPaymentMessage) {
|
||||||
|
|
||||||
// The next call to the "backend" shouldn't
|
// The next call to the "backend" shouldn't
|
||||||
// return an error.
|
// return an error.
|
||||||
resetBackend(nil, "")
|
resetBackend(nil, "")
|
||||||
@ -186,7 +172,7 @@ func TestInterceptor(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "auth required, no token yet, cost limit",
|
name: "auth required, no token yet, cost limit",
|
||||||
initialToken: nil,
|
initialPreimage: nil,
|
||||||
interceptor: NewInterceptor(
|
interceptor: NewInterceptor(
|
||||||
&lnd.LndServices, store, testTimeout,
|
&lnd.LndServices, store, testTimeout,
|
||||||
100, DefaultMaxRoutingFeeSats,
|
100, DefaultMaxRoutingFeeSats,
|
||||||
@ -209,15 +195,20 @@ func TestInterceptor(t *testing.T) {
|
|||||||
expectMacaroonCall2: false,
|
expectMacaroonCall2: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
// The invoker is a simple function that simulates the actual call to
|
// resetBackend is used by the test cases to define the behaviour of the
|
||||||
// the server. We can track if it's been called and we can dictate what
|
// simulated backend and reset its starting conditions.
|
||||||
// error it should return.
|
func resetBackend(expectedErr error, expectedAuth string) {
|
||||||
invoker := func(_ context.Context, _ string, _ interface{},
|
backendErr = expectedErr
|
||||||
_ interface{}, _ *grpc.ClientConn,
|
backendAuth = expectedAuth
|
||||||
opts ...grpc.CallOption) error {
|
callMD = nil
|
||||||
|
}
|
||||||
|
|
||||||
defer backendWg.Done()
|
// The invoker is a simple function that simulates the actual call to
|
||||||
|
// the server. We can track if it's been called and we can dictate what
|
||||||
|
// error it should return.
|
||||||
|
func invoker(opts []grpc.CallOption) error {
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
// Extract the macaroon in case it was set in the
|
// Extract the macaroon in case it was set in the
|
||||||
// request call options.
|
// request call options.
|
||||||
@ -238,28 +229,84 @@ func TestInterceptor(t *testing.T) {
|
|||||||
}
|
}
|
||||||
numBackendCalls++
|
numBackendCalls++
|
||||||
return backendErr
|
return backendErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUnaryInterceptor tests that the interceptor can handle LSAT protocol
|
||||||
|
// responses for unary calls and pay the token.
|
||||||
|
func TestUnaryInterceptor(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
unaryInvoker := func(_ context.Context, _ string,
|
||||||
|
_ interface{}, _ interface{}, _ *grpc.ClientConn,
|
||||||
|
opts ...grpc.CallOption) error {
|
||||||
|
|
||||||
|
defer backendWg.Done()
|
||||||
|
return invoker(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run through the test cases.
|
// Run through the test cases.
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
intercept := func() error {
|
||||||
|
return tc.interceptor.UnaryInterceptor(
|
||||||
|
ctx, "", nil, nil, nil, unaryInvoker, nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testInterceptor(t, tc, intercept)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamInterceptor tests that the interceptor can handle LSAT protocol
|
||||||
|
// responses in streams and pay the token.
|
||||||
|
func TestStreamInterceptor(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
streamInvoker := func(_ context.Context,
|
||||||
|
_ *grpc.StreamDesc, _ *grpc.ClientConn,
|
||||||
|
_ string, opts ...grpc.CallOption) (
|
||||||
|
grpc.ClientStream, error) { // nolint: unparam
|
||||||
|
|
||||||
|
defer backendWg.Done()
|
||||||
|
return nil, invoker(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run through the test cases.
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
intercept := func() error {
|
||||||
|
_, err := tc.interceptor.StreamInterceptor(
|
||||||
|
ctx, nil, nil, "", streamInvoker,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testInterceptor(t, tc, intercept)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInterceptor(t *testing.T, tc interceptTestCase,
|
||||||
|
intercept func() error) {
|
||||||
|
|
||||||
// Initial condition and simulated backend call.
|
// Initial condition and simulated backend call.
|
||||||
store.token = tc.initialToken
|
store.token = makeToken(tc.initialPreimage)
|
||||||
tc.resetCb()
|
tc.resetCb()
|
||||||
numBackendCalls = 0
|
numBackendCalls = 0
|
||||||
var overallWg sync.WaitGroup
|
|
||||||
backendWg.Add(1)
|
backendWg.Add(1)
|
||||||
overallWg.Add(1)
|
overallWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
err := tc.interceptor.UnaryInterceptor(
|
defer overallWg.Done()
|
||||||
ctx, "", nil, nil, nil, invoker, nil,
|
err := intercept()
|
||||||
)
|
|
||||||
if err != nil && tc.expectInterceptErr != "" &&
|
if err != nil && tc.expectInterceptErr != "" &&
|
||||||
err.Error() != tc.expectInterceptErr {
|
err.Error() != tc.expectInterceptErr {
|
||||||
panic(fmt.Errorf("unexpected error '%s', "+
|
panic(fmt.Errorf("unexpected error '%s', "+
|
||||||
"expected '%s'", err.Error(),
|
"expected '%s'", err.Error(),
|
||||||
tc.expectInterceptErr))
|
tc.expectInterceptErr))
|
||||||
}
|
}
|
||||||
overallWg.Done()
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
backendWg.Wait()
|
backendWg.Wait()
|
||||||
@ -286,10 +333,10 @@ func TestInterceptor(t *testing.T) {
|
|||||||
if tc.expectLndCall {
|
if tc.expectLndCall {
|
||||||
select {
|
select {
|
||||||
case payment := <-lnd.SendPaymentChannel:
|
case payment := <-lnd.SendPaymentChannel:
|
||||||
tc.sendPaymentCb(payment)
|
tc.sendPaymentCb(t, payment)
|
||||||
|
|
||||||
case track := <-lnd.TrackPaymentChannel:
|
case track := <-lnd.TrackPaymentChannel:
|
||||||
tc.trackPaymentCb(track)
|
tc.trackPaymentCb(t, track)
|
||||||
|
|
||||||
case <-time.After(testTimeout):
|
case <-time.After(testTimeout):
|
||||||
t.Fatalf("[%s]: no payment request received",
|
t.Fatalf("[%s]: no payment request received",
|
||||||
@ -299,7 +346,6 @@ func TestInterceptor(t *testing.T) {
|
|||||||
backendWg.Wait()
|
backendWg.Wait()
|
||||||
overallWg.Wait()
|
overallWg.Wait()
|
||||||
|
|
||||||
// Interpret result/expectations.
|
|
||||||
if tc.expectToken {
|
if tc.expectToken {
|
||||||
if _, err := store.CurrentToken(); err != nil {
|
if _, err := store.CurrentToken(); err != nil {
|
||||||
t.Fatalf("[%s] expected store to contain token",
|
t.Fatalf("[%s] expected store to contain token",
|
||||||
@ -327,26 +373,33 @@ func TestInterceptor(t *testing.T) {
|
|||||||
"expected times", numBackendCalls,
|
"expected times", numBackendCalls,
|
||||||
tc.expectBackendCalls)
|
tc.expectBackendCalls)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeToken(preimage *lntypes.Preimage) *Token {
|
||||||
|
if preimage == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &Token{
|
||||||
|
Preimage: *preimage,
|
||||||
|
baseMac: testMac,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeMac(t *testing.T) *macaroon.Macaroon {
|
func makeMac() *macaroon.Macaroon {
|
||||||
dummyMac, err := macaroon.New(
|
dummyMac, err := macaroon.New(
|
||||||
[]byte("aabbccddeeff00112233445566778899"), []byte("AA=="),
|
[]byte("aabbccddeeff00112233445566778899"), []byte("AA=="),
|
||||||
"LSAT", macaroon.LatestVersion,
|
"LSAT", macaroon.LatestVersion,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create macaroon: %v", err)
|
panic(fmt.Errorf("unable to create macaroon: %v", err))
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return dummyMac
|
return dummyMac
|
||||||
}
|
}
|
||||||
|
|
||||||
func serializeMac(t *testing.T, mac *macaroon.Macaroon) []byte {
|
func serializeMac(mac *macaroon.Macaroon) []byte {
|
||||||
macBytes, err := mac.MarshalBinary()
|
macBytes, err := mac.MarshalBinary()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to serialize macaroon: %v", err)
|
panic(fmt.Errorf("unable to serialize macaroon: %v", err))
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return macBytes
|
return macBytes
|
||||||
}
|
}
|
||||||
|
@ -23,11 +23,11 @@ func TestFileStore(t *testing.T) {
|
|||||||
paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5}
|
paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5}
|
||||||
paidToken = &Token{
|
paidToken = &Token{
|
||||||
Preimage: paidPreimage,
|
Preimage: paidPreimage,
|
||||||
baseMac: makeMac(t),
|
baseMac: makeMac(),
|
||||||
}
|
}
|
||||||
pendingToken = &Token{
|
pendingToken = &Token{
|
||||||
Preimage: zeroPreimage,
|
Preimage: zeroPreimage,
|
||||||
baseMac: makeMac(t),
|
baseMac: makeMac(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user