diff --git a/api/api_test.go b/api/api_test.go index 98d612ab..70ba6a89 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -565,7 +565,7 @@ type mockAuthority struct { getSSHRoots func() (*authority.SSHKeys, error) getSSHFederation func() (*authority.SSHKeys, error) getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error) - checkSSHHost func(principal string) (bool, error) + checkSSHHost func(ctx context.Context, principal, token string) (bool, error) getSSHBastion func(user string, hostname string) (*authority.Bastion, error) version func() authority.Version } @@ -715,9 +715,9 @@ func (m *mockAuthority) GetSSHConfig(typ string, data map[string]string) ([]temp return m.ret1.([]templates.Output), m.err } -func (m *mockAuthority) CheckSSHHost(principal string) (bool, error) { +func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { if m.checkSSHHost != nil { - return m.checkSSHHost(principal) + return m.checkSSHHost(ctx, principal, token) } return m.ret1.(bool), m.err } diff --git a/api/ssh_test.go b/api/ssh_test.go index b5ff7002..cb5c7904 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -539,7 +540,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - checkSSHHost: func(_ string) (bool, error) { + checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { return tt.exists, tt.err }, }).(*caHandler)