From 4798b67dfeb179c74beb69a9f2ffe49b6c8de61f Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Sat, 19 Nov 2022 20:27:49 -0600 Subject: [PATCH] Rewrite to support custom authentication, handshakes for encryption/compression, and reconnecting (#146) --- .config/nextest.toml | 2 +- .github/workflows/ci.yml | 1 + Cargo.lock | 248 +- Cargo.toml | 10 +- distant-core/Cargo.toml | 9 +- distant-core/src/api.rs | 39 +- distant-core/src/api/local.rs | 151 +- distant-core/src/api/local/process/pty.rs | 5 +- .../src/api/local/process/simple/tasks.rs | 6 +- distant-core/src/api/local/state.rs | 6 +- distant-core/src/api/local/state/process.rs | 2 +- .../src/api/local/state/process/instance.rs | 16 +- distant-core/src/api/local/state/search.rs | 37 +- distant-core/src/api/local/state/watcher.rs | 29 +- .../src/api/local/state/watcher/path.rs | 8 +- distant-core/src/api/reply.rs | 2 +- distant-core/src/client.rs | 2 +- distant-core/src/client/ext.rs | 2 +- distant-core/src/client/lsp.rs | 96 +- distant-core/src/client/lsp/msg.rs | 3 +- distant-core/src/client/process.rs | 136 +- distant-core/src/client/searcher.rs | 55 +- distant-core/src/client/watcher.rs | 55 +- distant-core/src/constants.rs | 4 +- distant-core/src/credentials.rs | 8 +- distant-core/src/data.rs | 5 +- distant-core/src/data/search.rs | 1 + distant-core/src/lib.rs | 5 +- distant-core/src/manager/client.rs | 783 ------ distant-core/src/manager/client/config.rs | 85 - distant-core/src/manager/client/ext.rs | 14 - distant-core/src/manager/client/ext/tcp.rs | 50 - distant-core/src/manager/client/ext/unix.rs | 54 - .../src/manager/client/ext/windows.rs | 91 - distant-core/src/manager/data/id.rs | 5 - distant-core/src/manager/server.rs | 719 ------ distant-core/src/manager/server/connection.rs | 202 -- distant-core/src/manager/server/ext.rs | 14 - distant-core/src/manager/server/ext/tcp.rs | 30 - distant-core/src/manager/server/ext/unix.rs | 50 - .../src/manager/server/ext/windows.rs | 48 - 
distant-core/src/manager/server/handler.rs | 68 - distant-core/src/manager/server/ref.rs | 73 - distant-core/tests/manager_tests.rs | 96 - .../tests/stress/distant/large_file.rs | 3 +- distant-core/tests/stress/distant/watch.rs | 3 +- distant-core/tests/stress/fixtures.rs | 54 +- distant-core/tests/stress/mod.rs | 1 - distant-core/tests/stress/utils.rs | 23 - distant-net/Cargo.toml | 9 +- distant-net/src/auth.rs | 122 - distant-net/src/auth/client.rs | 817 ------ distant-net/src/auth/server.rs | 653 ----- distant-net/src/client.rs | 1123 ++++++++- distant-net/src/client/builder.rs | 142 ++ distant-net/src/client/builder/tcp.rs | 31 + distant-net/src/client/builder/unix.rs | 30 + distant-net/src/client/builder/windows.rs | 50 + distant-net/src/client/channel.rs | 524 +++- distant-net/src/client/channel/mailbox.rs | 204 +- distant-net/src/client/ext/tcp.rs | 49 - distant-net/src/client/ext/unix.rs | 54 - distant-net/src/client/ext/windows.rs | 86 - distant-net/src/client/reconnect.rs | 208 ++ distant-net/src/client/shutdown.rs | 36 + distant-net/src/codec.rs | 38 - distant-net/src/codec/plain.rs | 193 -- distant-net/src/codec/xchacha20poly1305.rs | 269 -- distant-net/src/common.rs | 20 + distant-net/src/{ => common}/any.rs | 0 distant-net/src/common/authentication.rs | 10 + .../common/authentication/authenticator.rs | 672 +++++ .../src/common/authentication/handler.rs | 343 +++ .../common/authentication/handler/methods.rs | 33 + .../authentication/handler/methods/prompt.rs | 88 + .../handler/methods/static_key.rs | 171 ++ .../src/common/authentication/keychain.rs | 156 ++ .../src/common/authentication/methods.rs | 376 +++ .../src/common/authentication/methods/none.rs | 32 + .../authentication/methods/static_key.rs | 129 + distant-net/src/common/authentication/msg.rs | 216 ++ distant-net/src/common/connection.rs | 1291 ++++++++++ .../src/common}/destination.rs | 19 +- .../src/common}/destination/host.rs | 4 +- .../src/common}/destination/parser.rs | 0 
distant-net/src/{ => common}/listener.rs | 0 .../src/{ => common}/listener/mapped.rs | 2 +- distant-net/src/{ => common}/listener/mpsc.rs | 2 +- .../src/{ => common}/listener/oneshot.rs | 7 +- distant-net/src/{ => common}/listener/tcp.rs | 25 +- distant-net/src/{ => common}/listener/unix.rs | 27 +- .../src/{ => common}/listener/windows.rs | 25 +- .../data => distant-net/src/common}/map.rs | 11 +- distant-net/src/common/packet.rs | 628 +++++ .../src/{ => common}/packet/request.rs | 47 +- .../src/{ => common}/packet/response.rs | 61 +- distant-net/src/{ => common}/port.rs | 0 distant-net/src/common/transport.rs | 629 +++++ distant-net/src/common/transport/framed.rs | 2237 +++++++++++++++++ .../src/common/transport/framed/backup.rs | 201 ++ .../src/common/transport/framed/codec.rs | 68 + .../common/transport/framed/codec/chain.rs | 160 ++ .../transport/framed/codec/compression.rs | 263 ++ .../transport/framed/codec/encryption.rs | 255 ++ .../transport/framed/codec/encryption/key.rs | 318 +++ .../common/transport/framed/codec/plain.rs | 22 + .../transport/framed/codec/predicate.rs | 180 ++ .../transport/framed/exchange.rs} | 24 +- .../transport/framed/exchange}/pkb.rs | 0 .../transport/framed/exchange}/salt.rs | 0 .../src/common/transport/framed/frame.rs | 343 +++ .../src/common/transport/framed/handshake.rs | 57 + distant-net/src/common/transport/inmemory.rs | 512 ++++ distant-net/src/common/transport/tcp.rs | 222 ++ distant-net/src/common/transport/test.rs | 48 + distant-net/src/common/transport/unix.rs | 216 ++ distant-net/src/common/transport/windows.rs | 186 ++ .../src/common/transport/windows/pipe.rs | 92 + distant-net/src/{ => common}/utils.rs | 82 +- distant-net/src/id.rs | 2 - distant-net/src/key.rs | 100 - distant-net/src/lib.rs | 29 +- {distant-core => distant-net}/src/manager.rs | 0 distant-net/src/manager/client.rs | 626 +++++ distant-net/src/manager/client/channel.rs | 174 ++ .../src/manager/data.rs | 9 +- .../src/manager/data/capabilities.rs | 0 
.../src/manager/data/info.rs | 3 +- .../src/manager/data/list.rs | 10 +- .../src/manager/data/request.rs | 46 +- .../src/manager/data/response.rs | 43 +- distant-net/src/manager/server.rs | 584 +++++ .../src/manager/server/authentication.rs | 103 + .../src/manager/server/config.rs | 7 +- distant-net/src/manager/server/connection.rs | 218 ++ distant-net/src/manager/server/handler.rs | 312 +++ distant-net/src/packet.rs | 254 -- distant-net/src/server.rs | 433 +++- .../src/{client/ext.rs => server/builder.rs} | 9 +- distant-net/src/server/builder/tcp.rs | 102 + distant-net/src/server/builder/unix.rs | 108 + distant-net/src/server/builder/windows.rs | 116 + distant-net/src/server/config.rs | 16 +- distant-net/src/server/connection.rs | 766 +++++- distant-net/src/server/context.rs | 20 +- distant-net/src/server/ext.rs | 440 ---- distant-net/src/server/ext/tcp.rs | 94 - distant-net/src/server/ext/unix.rs | 97 - distant-net/src/server/ext/windows.rs | 109 - distant-net/src/server/ref.rs | 18 +- distant-net/src/server/ref/tcp.rs | 6 +- distant-net/src/server/ref/unix.rs | 6 +- distant-net/src/server/ref/windows.rs | 6 +- distant-net/src/server/reply.rs | 2 +- distant-net/src/server/shutdown_timer.rs | 96 + distant-net/src/server/state.rs | 20 +- distant-net/src/transport.rs | 112 - distant-net/src/transport/framed.rs | 215 -- distant-net/src/transport/framed/read.rs | 115 - distant-net/src/transport/framed/test.rs | 12 - distant-net/src/transport/framed/write.rs | 72 - distant-net/src/transport/inmemory.rs | 225 -- distant-net/src/transport/inmemory/read.rs | 249 -- distant-net/src/transport/inmemory/write.rs | 147 -- distant-net/src/transport/mpsc.rs | 66 - distant-net/src/transport/mpsc/read.rs | 22 - distant-net/src/transport/mpsc/write.rs | 25 - distant-net/src/transport/router.rs | 370 --- distant-net/src/transport/tcp.rs | 196 -- distant-net/src/transport/unix.rs | 187 -- distant-net/src/transport/untyped.rs | 61 - distant-net/src/transport/windows.rs | 202 -- 
distant-net/src/transport/windows/pipe.rs | 101 - distant-net/tests/auth.rs | 169 -- distant-net/tests/lib.rs | 1 - distant-net/tests/manager_tests.rs | 125 + distant-net/tests/typed_tests.rs | 70 + distant-net/tests/untyped_tests.rs | 112 + distant-ssh2/Cargo.toml | 7 +- distant-ssh2/src/api.rs | 6 +- distant-ssh2/src/lib.rs | 74 +- distant-ssh2/src/process.rs | 4 +- distant-ssh2/tests/ssh2/client.rs | 113 +- distant-ssh2/tests/ssh2/launched.rs | 113 +- distant-ssh2/tests/ssh2/ssh.rs | 3 +- distant-ssh2/tests/sshd/mod.rs | 26 +- src/cli/cache.rs | 2 +- src/cli/client.rs | 357 ++- src/cli/commands/client.rs | 256 +- src/cli/commands/client/format.rs | 2 +- src/cli/commands/generate.rs | 2 +- src/cli/commands/manager.rs | 90 +- src/cli/commands/manager/handlers.rs | 150 +- src/cli/commands/server.rs | 58 +- src/cli/manager.rs | 48 +- src/config/client/connect.rs | 2 +- src/config/client/launch.rs | 2 +- src/config/manager.rs | 2 +- src/config/server/listen.rs | 7 +- tests/cli/action/capabilities.rs | 1 + tests/cli/action/copy.rs | 9 +- tests/cli/action/dir_create.rs | 9 +- tests/cli/action/dir_read.rs | 15 +- tests/cli/action/exists.rs | 6 +- tests/cli/action/file_append.rs | 6 +- tests/cli/action/file_append_text.rs | 6 +- tests/cli/action/file_read.rs | 6 +- tests/cli/action/file_read_text.rs | 6 +- tests/cli/action/file_write.rs | 6 +- tests/cli/action/file_write_text.rs | 6 +- tests/cli/action/metadata.rs | 15 +- tests/cli/action/proc_spawn.rs | 18 +- tests/cli/action/remove.rs | 12 +- tests/cli/action/rename.rs | 9 +- tests/cli/action/search.rs | 1 + tests/cli/action/system_info.rs | 1 + tests/cli/action/watch.rs | 9 +- tests/cli/fixtures.rs | 269 +- tests/cli/manager/capabilities.rs | 7 +- tests/cli/repl/capabilities.rs | 9 +- tests/cli/repl/copy.rs | 29 +- tests/cli/repl/dir_create.rs | 29 +- tests/cli/repl/dir_read.rs | 47 +- tests/cli/repl/exists.rs | 19 +- tests/cli/repl/file_append.rs | 20 +- tests/cli/repl/file_append_text.rs | 20 +- 
tests/cli/repl/file_read.rs | 20 +- tests/cli/repl/file_read_text.rs | 20 +- tests/cli/repl/file_write.rs | 20 +- tests/cli/repl/file_write_text.rs | 20 +- tests/cli/repl/metadata.rs | 72 +- tests/cli/repl/proc_spawn.rs | 79 +- tests/cli/repl/remove.rs | 38 +- tests/cli/repl/rename.rs | 29 +- tests/cli/repl/search.rs | 19 +- tests/cli/repl/system_info.rs | 10 +- tests/cli/repl/watch.rs | 78 +- 237 files changed, 18594 insertions(+), 10120 deletions(-) delete mode 100644 distant-core/src/manager/client.rs delete mode 100644 distant-core/src/manager/client/config.rs delete mode 100644 distant-core/src/manager/client/ext.rs delete mode 100644 distant-core/src/manager/client/ext/tcp.rs delete mode 100644 distant-core/src/manager/client/ext/unix.rs delete mode 100644 distant-core/src/manager/client/ext/windows.rs delete mode 100644 distant-core/src/manager/data/id.rs delete mode 100644 distant-core/src/manager/server.rs delete mode 100644 distant-core/src/manager/server/connection.rs delete mode 100644 distant-core/src/manager/server/ext.rs delete mode 100644 distant-core/src/manager/server/ext/tcp.rs delete mode 100644 distant-core/src/manager/server/ext/unix.rs delete mode 100644 distant-core/src/manager/server/ext/windows.rs delete mode 100644 distant-core/src/manager/server/handler.rs delete mode 100644 distant-core/src/manager/server/ref.rs delete mode 100644 distant-core/tests/manager_tests.rs delete mode 100644 distant-core/tests/stress/utils.rs delete mode 100644 distant-net/src/auth.rs delete mode 100644 distant-net/src/auth/client.rs delete mode 100644 distant-net/src/auth/server.rs create mode 100644 distant-net/src/client/builder.rs create mode 100644 distant-net/src/client/builder/tcp.rs create mode 100644 distant-net/src/client/builder/unix.rs create mode 100644 distant-net/src/client/builder/windows.rs delete mode 100644 distant-net/src/client/ext/tcp.rs delete mode 100644 distant-net/src/client/ext/unix.rs delete mode 100644 
distant-net/src/client/ext/windows.rs create mode 100644 distant-net/src/client/reconnect.rs create mode 100644 distant-net/src/client/shutdown.rs delete mode 100644 distant-net/src/codec.rs delete mode 100644 distant-net/src/codec/plain.rs delete mode 100644 distant-net/src/codec/xchacha20poly1305.rs create mode 100644 distant-net/src/common.rs rename distant-net/src/{ => common}/any.rs (100%) create mode 100644 distant-net/src/common/authentication.rs create mode 100644 distant-net/src/common/authentication/authenticator.rs create mode 100644 distant-net/src/common/authentication/handler.rs create mode 100644 distant-net/src/common/authentication/handler/methods.rs create mode 100644 distant-net/src/common/authentication/handler/methods/prompt.rs create mode 100644 distant-net/src/common/authentication/handler/methods/static_key.rs create mode 100644 distant-net/src/common/authentication/keychain.rs create mode 100644 distant-net/src/common/authentication/methods.rs create mode 100644 distant-net/src/common/authentication/methods/none.rs create mode 100644 distant-net/src/common/authentication/methods/static_key.rs create mode 100644 distant-net/src/common/authentication/msg.rs create mode 100644 distant-net/src/common/connection.rs rename {distant-core/src/manager/data => distant-net/src/common}/destination.rs (92%) rename {distant-core/src/manager/data => distant-net/src/common}/destination/host.rs (99%) rename {distant-core/src/manager/data => distant-net/src/common}/destination/parser.rs (100%) rename distant-net/src/{ => common}/listener.rs (100%) rename distant-net/src/{ => common}/listener/mapped.rs (97%) rename distant-net/src/{ => common}/listener/mpsc.rs (97%) rename distant-net/src/{ => common}/listener/oneshot.rs (96%) rename distant-net/src/{ => common}/listener/tcp.rs (90%) rename distant-net/src/{ => common}/listener/unix.rs (92%) rename distant-net/src/{ => common}/listener/windows.rs (89%) rename {distant-core/src/data => 
distant-net/src/common}/map.rs (97%) create mode 100644 distant-net/src/common/packet.rs rename distant-net/src/{ => common}/packet/request.rs (87%) rename distant-net/src/{ => common}/packet/response.rs (87%) rename distant-net/src/{ => common}/port.rs (100%) create mode 100644 distant-net/src/common/transport.rs create mode 100644 distant-net/src/common/transport/framed.rs create mode 100644 distant-net/src/common/transport/framed/backup.rs create mode 100644 distant-net/src/common/transport/framed/codec.rs create mode 100644 distant-net/src/common/transport/framed/codec/chain.rs create mode 100644 distant-net/src/common/transport/framed/codec/compression.rs create mode 100644 distant-net/src/common/transport/framed/codec/encryption.rs create mode 100644 distant-net/src/common/transport/framed/codec/encryption/key.rs create mode 100644 distant-net/src/common/transport/framed/codec/plain.rs create mode 100644 distant-net/src/common/transport/framed/codec/predicate.rs rename distant-net/src/{auth/handshake.rs => common/transport/framed/exchange.rs} (71%) rename distant-net/src/{auth/handshake => common/transport/framed/exchange}/pkb.rs (100%) rename distant-net/src/{auth/handshake => common/transport/framed/exchange}/salt.rs (100%) create mode 100644 distant-net/src/common/transport/framed/frame.rs create mode 100644 distant-net/src/common/transport/framed/handshake.rs create mode 100644 distant-net/src/common/transport/inmemory.rs create mode 100644 distant-net/src/common/transport/tcp.rs create mode 100644 distant-net/src/common/transport/test.rs create mode 100644 distant-net/src/common/transport/unix.rs create mode 100644 distant-net/src/common/transport/windows.rs create mode 100644 distant-net/src/common/transport/windows/pipe.rs rename distant-net/src/{ => common}/utils.rs (73%) delete mode 100644 distant-net/src/id.rs delete mode 100644 distant-net/src/key.rs rename {distant-core => distant-net}/src/manager.rs (100%) create mode 100644 
distant-net/src/manager/client.rs create mode 100644 distant-net/src/manager/client/channel.rs rename {distant-core => distant-net}/src/manager/data.rs (69%) rename {distant-core => distant-net}/src/manager/data/capabilities.rs (100%) rename {distant-core => distant-net}/src/manager/data/info.rs (86%) rename {distant-core => distant-net}/src/manager/data/list.rs (80%) rename {distant-core => distant-net}/src/manager/data/request.rs (67%) rename {distant-core => distant-net}/src/manager/data/response.rs (55%) create mode 100644 distant-net/src/manager/server.rs create mode 100644 distant-net/src/manager/server/authentication.rs rename {distant-core => distant-net}/src/manager/server/config.rs (88%) create mode 100644 distant-net/src/manager/server/connection.rs create mode 100644 distant-net/src/manager/server/handler.rs delete mode 100644 distant-net/src/packet.rs rename distant-net/src/{client/ext.rs => server/builder.rs} (99%) create mode 100644 distant-net/src/server/builder/tcp.rs create mode 100644 distant-net/src/server/builder/unix.rs create mode 100644 distant-net/src/server/builder/windows.rs delete mode 100644 distant-net/src/server/ext.rs delete mode 100644 distant-net/src/server/ext/tcp.rs delete mode 100644 distant-net/src/server/ext/unix.rs delete mode 100644 distant-net/src/server/ext/windows.rs create mode 100644 distant-net/src/server/shutdown_timer.rs delete mode 100644 distant-net/src/transport.rs delete mode 100644 distant-net/src/transport/framed.rs delete mode 100644 distant-net/src/transport/framed/read.rs delete mode 100644 distant-net/src/transport/framed/test.rs delete mode 100644 distant-net/src/transport/framed/write.rs delete mode 100644 distant-net/src/transport/inmemory.rs delete mode 100644 distant-net/src/transport/inmemory/read.rs delete mode 100644 distant-net/src/transport/inmemory/write.rs delete mode 100644 distant-net/src/transport/mpsc.rs delete mode 100644 distant-net/src/transport/mpsc/read.rs delete mode 100644 
distant-net/src/transport/mpsc/write.rs delete mode 100644 distant-net/src/transport/router.rs delete mode 100644 distant-net/src/transport/tcp.rs delete mode 100644 distant-net/src/transport/unix.rs delete mode 100644 distant-net/src/transport/untyped.rs delete mode 100644 distant-net/src/transport/windows.rs delete mode 100644 distant-net/src/transport/windows/pipe.rs delete mode 100644 distant-net/tests/auth.rs delete mode 100644 distant-net/tests/lib.rs create mode 100644 distant-net/tests/manager_tests.rs create mode 100644 distant-net/tests/typed_tests.rs create mode 100644 distant-net/tests/untyped_tests.rs diff --git a/.config/nextest.toml b/.config/nextest.toml index 400b353..55a0c2a 100644 --- a/.config/nextest.toml +++ b/.config/nextest.toml @@ -1,6 +1,6 @@ [profile.ci] fail-fast = false -retries = 2 +retries = 4 slow-timeout = { period = "60s", terminate-after = 3 } status-level = "fail" final-status-level = "fail" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad1159c..d57b657 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -168,6 +168,7 @@ jobs: run: cargo build --release - name: Run CLI tests (all features) run: cargo nextest run --profile ci --release --all-features + if: matrix.os != 'windows-latest' ssh-launch-tests: name: "Test ssh launch using Rust ${{ matrix.rust }} on ${{ matrix.os }}" runs-on: ${{ matrix.os }} diff --git a/Cargo.lock b/Cargo.lock index bdc2fe2..fc45825 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + [[package]] name = "aead" version = "0.5.0" @@ -21,6 +27,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "ansi_term" version = "0.12.1" @@ -371,6 +386,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "chrono" +version = "0.4.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" +dependencies = [ + "iana-time-zone", + "num-integer", + "num-traits", + "winapi", +] + [[package]] name = "cipher" version = "0.4.3" @@ -430,6 +457,16 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "combine" version = "4.6.4" @@ -505,6 +542,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if 1.0.0", +] + [[package]] name = "crossbeam-channel" version = "0.5.4" @@ -583,6 +629,50 @@ dependencies = [ "phf 0.11.1", ] +[[package]] +name = "cxx" +version = "1.0.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97abf9f0eca9e52b7f81b945524e76710e6cb2366aead23b7d4fbf72e281f888" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", 
+ "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cc32cc5fea1d894b77d269ddb9f192110069a8a9c1f1d441195fba90553dea3" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ca220e4794c934dc6b1207c3b42856ad4c302f2df1712e9f8d2eec5afaacf1f" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b846f081361125bfc8dc9d3940c84e1fd83ba54bbca7b17cd29483c828be0704" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "deltae" version = "0.3.0" @@ -709,7 +799,7 @@ dependencies = [ [[package]] name = "distant" -version = "0.19.0" +version = "0.20.0" dependencies = [ "anyhow", "assert_cmd", @@ -723,6 +813,7 @@ dependencies = [ "directories", "distant-core", "distant-ssh2", + "env_logger", "flexi_logger", "fork", "indoc", @@ -740,6 +831,7 @@ dependencies = [ "tabled", "terminal_size 0.2.1", "termwiz", + "test-log", "tokio", "toml_edit", "which", @@ -750,7 +842,7 @@ dependencies = [ [[package]] name = "distant-core" -version = "0.19.0" +version = "0.20.0" dependencies = [ "assert_fs", "async-trait", @@ -759,7 +851,7 @@ dependencies = [ "clap", "derive_more", "distant-net", - "flexi_logger", + "env_logger", "futures", "grep", "hex", @@ -780,6 +872,7 @@ dependencies = [ "serde_json", "shell-words", "strum", + "test-log", "tokio", "tokio-util", "walkdir", @@ -789,13 +882,15 @@ dependencies = [ [[package]] name = "distant-net" -version = "0.19.0" +version = "0.20.0" dependencies = [ "async-trait", "bytes", "chacha20poly1305", "derive_more", - "futures", + "dyn-clone", + "env_logger", + "flate2", "hex", "hkdf", "log", @@ -807,14 +902,15 @@ dependencies = [ "serde", 
"serde_bytes", "sha2 0.10.2", + "strum", "tempfile", + "test-log", "tokio", - "tokio-util", ] [[package]] name = "distant-ssh2" -version = "0.19.0" +version = "0.20.0" dependencies = [ "anyhow", "assert_fs", @@ -824,7 +920,7 @@ dependencies = [ "derive_more", "distant-core", "dunce", - "flexi_logger", + "env_logger", "futures", "hex", "indoc", @@ -837,6 +933,7 @@ dependencies = [ "serde", "shell-words", "smol", + "test-log", "tokio", "typed-path", "wezterm-ssh", @@ -927,6 +1024,19 @@ dependencies = [ "encoding_rs", ] +[[package]] +name = "env_logger" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c90bf5f19754d10198ccb95b70664fc925bd1fc090a0fd9a6ebc54acc8cd6272" +dependencies = [ + "atty", + "humantime", + "log", + "regex", + "termcolor", +] + [[package]] name = "err-derive" version = "0.3.1" @@ -1028,21 +1138,31 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "flate2" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f82b0f4c27ad9f8bfd1f3208d882da2b09c301bc1c828fd3a00d0216d2fbbff6" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "flexi_logger" -version = "0.23.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8790f70905b203171c21060222f18f1df5cba07317860215b7880b32aaef290" +checksum = "99659bcfd52cfece972bd00acb9dba7028094d47e699ea8b193b9aaebd5c362b" dependencies = [ "ansi_term", "atty", + "chrono", "glob", "lazy_static", "log", "regex", "rustversion", "thiserror", - "time", ] [[package]] @@ -1390,6 +1510,36 @@ dependencies = [ "digest 0.10.3", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + +[[package]] 
+name = "iana-time-zone" +version = "0.1.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64c122667b287044802d6ce17ee2ddf13207ed924c712de9a66a5814d5b64765" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "winapi", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca" +dependencies = [ + "cxx", + "cxx-build", +] + [[package]] name = "ignore" version = "0.4.18" @@ -1588,6 +1738,15 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "link-cplusplus" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9272ab7b96c9046fbc5bc56c06c117cb639fe2d509df0c421cad82d2915cf369" +dependencies = [ + "cc", +] + [[package]] name = "linux-raw-sys" version = "0.0.46" @@ -1649,6 +1808,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34" +dependencies = [ + "adler", +] + [[package]] name = "mio" version = "0.8.3" @@ -1701,9 +1869,9 @@ checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" [[package]] name = "notify" -version = "5.0.0-pre.15" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "553f9844ad0b0824605c20fb55a661679782680410abfb1a8144c2e7e437e7a7" +checksum = "ed2c66da08abae1c024c01d635253e402341b4060a12e99b31c7594063bf490a" dependencies = [ "bitflags", "crossbeam-channel", @@ -1738,6 +1906,16 @@ dependencies = [ "syn", ] +[[package]] +name = "num-integer" +version = "0.1.45" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.15" @@ -1757,15 +1935,6 @@ dependencies = [ "libc", ] -[[package]] -name = "num_threads" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44" -dependencies = [ - "libc", -] - [[package]] name = "once_cell" version = "1.13.0" @@ -2502,6 +2671,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "scratch" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8132065adcfd6e02db789d9285a0deb2f3fcb04002865ab67d5fb103533898" + [[package]] name = "sec1" version = "0.3.0" @@ -3019,6 +3194,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "test-log" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f0c854faeb68a048f0f2dc410c5ddae3bf83854ef0e4977d58306a5edef50e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "textwrap" version = "0.15.0" @@ -3054,24 +3240,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "time" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2702e08a7a860f005826c6815dcac101b19b5eb330c27fe4a5928fec1d20ddd" -dependencies = [ - "itoa", - "libc", - "num_threads", - "time-macros", -] - -[[package]] -name = "time-macros" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42657b1a6f4d817cda8e7a0ace261fe0cc946cf3a80314390b22cc61ae080792" - [[package]] name = "tokio" version = "1.20.1" diff --git a/Cargo.toml b/Cargo.toml index d6a7c76..f51f277 100644 --- 
a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "distant" description = "Operate on a remote computer through file and process manipulation" categories = ["command-line-utilities"] keywords = ["cli"] -version = "0.19.0" +version = "0.20.0" authors = ["Chip Senkbeil "] edition = "2021" homepage = "https://github.com/chipsenkbeil/distant" @@ -32,9 +32,9 @@ clap_complete = "3.2.3" config = { version = "0.13.2", default-features = false, features = ["toml"] } derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error", "is_variant"] } dialoguer = { version = "0.10.2", default-features = false } -distant-core = { version = "=0.19.0", path = "distant-core", features = ["clap", "schemars"] } +distant-core = { version = "=0.20.0", path = "distant-core", features = ["clap", "schemars"] } directories = "4.0.1" -flexi_logger = "0.23.0" +flexi_logger = "0.24.1" indoc = "1.0.7" log = "0.4.17" once_cell = "1.13.0" @@ -54,7 +54,7 @@ winsplit = "0.1.0" whoami = "1.2.1" # Optional native SSH functionality -distant-ssh2 = { version = "=0.19.0", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true } +distant-ssh2 = { version = "=0.20.0", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true } [target.'cfg(unix)'.dependencies] fork = "0.1.19" @@ -66,6 +66,8 @@ windows-service = "0.5.0" [dev-dependencies] assert_cmd = "2.0.4" assert_fs = "1.0.7" +env_logger = "0.9.1" indoc = "1.0.7" predicates = "2.1.1" rstest = "0.15.0" +test-log = "0.2.11" diff --git a/distant-core/Cargo.toml b/distant-core/Cargo.toml index 7b33729..a811650 100644 --- a/distant-core/Cargo.toml +++ b/distant-core/Cargo.toml @@ -3,7 +3,7 @@ name = "distant-core" description = "Core library for distant, enabling operation on a remote computer through file and process manipulation" categories = ["network-programming"] keywords = ["api", "async"] -version = "0.19.0" +version = "0.20.0" authors = ["Chip 
Senkbeil "] edition = "2021" homepage = "https://github.com/chipsenkbeil/distant" @@ -19,13 +19,13 @@ async-trait = "0.1.57" bitflags = "1.3.2" bytes = "1.2.1" derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] } -distant-net = { version = "=0.19.0", path = "../distant-net" } +distant-net = { version = "=0.20.0", path = "../distant-net" } futures = "0.3.21" grep = "0.2.10" hex = "0.4.3" ignore = "0.4.18" log = "0.4.17" -notify = { version = "=5.0.0-pre.15", features = ["serde"] } +notify = { version = "5.0.0", features = ["serde"] } num_cpus = "1.13.1" once_cell = "1.13.0" portable-pty = "0.7.0" @@ -48,7 +48,8 @@ schemars = { version = "0.8.10", optional = true } [dev-dependencies] assert_fs = "1.0.7" -flexi_logger = "0.23.0" +env_logger = "0.9.1" indoc = "1.0.7" predicates = "2.1.1" rstest = "0.15.0" +test-log = "0.2.11" diff --git a/distant-core/src/api.rs b/distant-core/src/api.rs index f7d9b61..8b76c36 100644 --- a/distant-core/src/api.rs +++ b/distant-core/src/api.rs @@ -3,10 +3,11 @@ use crate::{ Capabilities, ChangeKind, DirEntry, Environment, Error, Metadata, ProcessId, PtySize, SearchId, SearchQuery, SystemInfo, }, - ConnectionId, DistantMsg, DistantRequestData, DistantResponseData, + DistantMsg, DistantRequestData, DistantResponseData, }; use async_trait::async_trait; -use distant_net::{Reply, Server, ServerConfig, ServerCtx}; +use distant_net::common::ConnectionId; +use distant_net::server::{ConnectionCtx, Reply, ServerCtx, ServerHandler}; use log::*; use std::{io, path::PathBuf, sync::Arc}; @@ -23,15 +24,15 @@ pub struct DistantCtx { pub local_data: Arc, } -/// Represents a server that leverages an API compliant with `distant` -pub struct DistantApiServer +/// Represents a [`ServerHandler`] that leverages an API compliant with `distant` +pub struct DistantApiServerHandler where T: DistantApi, { api: T, } -impl 
DistantApiServer +impl DistantApiServerHandler where T: DistantApi, { @@ -40,11 +41,11 @@ where } } -impl DistantApiServer::LocalData> { +impl DistantApiServerHandler::LocalData> { /// Creates a new server using the [`LocalDistantApi`] implementation - pub fn local(config: ServerConfig) -> io::Result { + pub fn local() -> io::Result { Ok(Self { - api: LocalDistantApi::initialize(config)?, + api: LocalDistantApi::initialize()?, }) } } @@ -63,15 +64,12 @@ fn unsupported(label: &str) -> io::Result { pub trait DistantApi { type LocalData: Send + Sync; - /// Returns config associated with API server - fn config(&self) -> ServerConfig { - ServerConfig::default() - } - /// Invoked whenever a new connection is established, providing a mutable reference to the /// newly-created local data. This is a way to support modifying local data before it is used. #[allow(unused_variables)] - async fn on_accept(&self, local_data: &mut Self::LocalData) {} + async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { + Ok(()) + } /// Retrieves information about the server's capabilities. 
/// @@ -420,7 +418,7 @@ pub trait DistantApi { } #[async_trait] -impl Server for DistantApiServer +impl ServerHandler for DistantApiServerHandler where T: DistantApi + Send + Sync, D: Send + Sync, @@ -429,14 +427,9 @@ where type Response = DistantMsg; type LocalData = D; - /// Overridden to leverage [`DistantApi`] implementation of `config` - fn config(&self) -> ServerConfig { - T::config(&self.api) - } - /// Overridden to leverage [`DistantApi`] implementation of `on_accept` - async fn on_accept(&self, local_data: &mut Self::LocalData) { - T::on_accept(&self.api, local_data).await + async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { + T::on_accept(&self.api, ctx).await } async fn on_request(&self, ctx: ServerCtx) { @@ -518,7 +511,7 @@ where /// Processes an incoming request async fn handle_request( - server: &DistantApiServer, + server: &DistantApiServerHandler, ctx: DistantCtx, request: DistantRequestData, ) -> DistantResponseData diff --git a/distant-core/src/api/local.rs b/distant-core/src/api/local.rs index c744330..411bb28 100644 --- a/distant-core/src/api/local.rs +++ b/distant-core/src/api/local.rs @@ -6,7 +6,7 @@ use crate::{ DistantApi, DistantCtx, }; use async_trait::async_trait; -use distant_net::ServerConfig; +use distant_net::server::ConnectionCtx; use log::*; use std::{ io, @@ -26,15 +26,13 @@ use state::*; /// impementation of the API instead of a proxy to another machine as seen with /// implementations on top of SSH and other protocol pub struct LocalDistantApi { - config: ServerConfig, state: GlobalState, } impl LocalDistantApi { /// Initialize the api instance - pub fn initialize(config: ServerConfig) -> io::Result { + pub fn initialize() -> io::Result { Ok(Self { - config, state: GlobalState::initialize()?, }) } @@ -44,14 +42,11 @@ impl LocalDistantApi { impl DistantApi for LocalDistantApi { type LocalData = ConnectionState; - fn config(&self) -> ServerConfig { - self.config.clone() - } - /// Injects the 
global channels into the local connection - async fn on_accept(&self, local_data: &mut Self::LocalData) { - local_data.process_channel = self.state.process.clone_channel(); - local_data.watcher_channel = self.state.watcher.clone_channel(); + async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { + ctx.local_data.process_channel = self.state.process.clone_channel(); + ctx.local_data.watcher_channel = self.state.watcher.clone_channel(); + Ok(()) } async fn capabilities(&self, ctx: DistantCtx) -> io::Result { @@ -511,10 +506,11 @@ mod tests { use super::*; use crate::data::DistantResponseData; use assert_fs::prelude::*; - use distant_net::Reply; + use distant_net::server::Reply; use once_cell::sync::Lazy; use predicates::prelude::*; use std::{sync::Arc, time::Duration}; + use test_log::test; use tokio::sync::mpsc; static TEMP_SCRIPT_DIR: Lazy = @@ -583,12 +579,21 @@ mod tests { DistantCtx, mpsc::Receiver, ) { - let api = LocalDistantApi::initialize(Default::default()).unwrap(); + let api = LocalDistantApi::initialize().unwrap(); let (reply, rx) = make_reply(buffer); + let connection_id = rand::random(); let mut local_data = ConnectionState::default(); - DistantApi::on_accept(&api, &mut local_data).await; + DistantApi::on_accept( + &api, + ConnectionCtx { + connection_id, + local_data: &mut local_data, + }, + ) + .await + .unwrap(); let ctx = DistantCtx { - connection_id: rand::random(), + connection_id, reply, local_data: Arc::new(local_data), }; @@ -605,7 +610,7 @@ mod tests { (Box::new(tx), rx) } - #[tokio::test] + #[test(tokio::test)] async fn read_file_should_fail_if_file_missing() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -614,7 +619,7 @@ mod tests { let _ = api.read_file(ctx, path).await.unwrap_err(); } - #[tokio::test] + #[test(tokio::test)] async fn read_file_should_send_blob_with_file_contents() { let (api, ctx, _rx) = setup(1).await; @@ -626,7 +631,7 @@ mod tests { 
assert_eq!(bytes, b"some file contents"); } - #[tokio::test] + #[test(tokio::test)] async fn read_file_text_should_send_error_if_fails_to_read_file() { let (api, ctx, _rx) = setup(1).await; @@ -636,7 +641,7 @@ mod tests { let _ = api.read_file_text(ctx, path).await.unwrap_err(); } - #[tokio::test] + #[test(tokio::test)] async fn read_file_text_should_send_text_with_file_contents() { let (api, ctx, _rx) = setup(1).await; @@ -651,7 +656,7 @@ mod tests { assert_eq!(text, "some file contents"); } - #[tokio::test] + #[test(tokio::test)] async fn write_file_should_send_error_if_fails_to_write_file() { let (api, ctx, _rx) = setup(1).await; @@ -669,7 +674,7 @@ mod tests { file.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn write_file_should_send_ok_when_successful() { let (api, ctx, _rx) = setup(1).await; @@ -687,7 +692,7 @@ mod tests { file.assert("some text"); } - #[tokio::test] + #[test(tokio::test)] async fn write_file_text_should_send_error_if_fails_to_write_file() { let (api, ctx, _rx) = setup(1).await; @@ -704,7 +709,7 @@ mod tests { file.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn write_file_text_should_send_ok_when_successful() { let (api, ctx, _rx) = setup(1).await; @@ -722,7 +727,7 @@ mod tests { file.assert("some text"); } - #[tokio::test] + #[test(tokio::test)] async fn append_file_should_send_error_if_fails_to_create_file() { let (api, ctx, _rx) = setup(1).await; @@ -743,7 +748,7 @@ mod tests { file.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn append_file_should_create_file_if_missing() { let (api, ctx, _rx) = setup(1).await; @@ -767,7 +772,7 @@ mod tests { file.assert("some extra contents"); } - #[tokio::test] + #[test(tokio::test)] async fn append_file_should_send_ok_when_successful() { let (api, ctx, _rx) = setup(1).await; @@ -791,7 +796,7 @@ mod tests { file.assert("some file contentssome extra contents"); } - #[tokio::test] + 
#[test(tokio::test)] async fn append_file_text_should_send_error_if_fails_to_create_file() { let (api, ctx, _rx) = setup(1).await; @@ -813,7 +818,7 @@ mod tests { file.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn append_file_text_should_create_file_if_missing() { let (api, ctx, _rx) = setup(1).await; @@ -837,7 +842,7 @@ mod tests { file.assert("some extra contents"); } - #[tokio::test] + #[test(tokio::test)] async fn append_file_text_should_send_ok_when_successful() { let (api, ctx, _rx) = setup(1).await; @@ -861,7 +866,7 @@ mod tests { file.assert("some file contentssome extra contents"); } - #[tokio::test] + #[test(tokio::test)] async fn dir_read_should_send_error_if_directory_does_not_exist() { let (api, ctx, _rx) = setup(1).await; @@ -902,7 +907,7 @@ mod tests { root_dir } - #[tokio::test] + #[test(tokio::test)] async fn dir_read_should_support_depth_limits() { let (api, ctx, _rx) = setup(1).await; @@ -936,7 +941,7 @@ mod tests { assert_eq!(entries[2].depth, 1); } - #[tokio::test] + #[test(tokio::test)] async fn dir_read_should_support_unlimited_depth_using_zero() { let (api, ctx, _rx) = setup(1).await; @@ -974,7 +979,7 @@ mod tests { assert_eq!(entries[3].depth, 2); } - #[tokio::test] + #[test(tokio::test)] async fn dir_read_should_support_including_directory_in_returned_entries() { let (api, ctx, _rx) = setup(1).await; @@ -1013,7 +1018,7 @@ mod tests { assert_eq!(entries[3].depth, 1); } - #[tokio::test] + #[test(tokio::test)] async fn dir_read_should_support_returning_absolute_paths() { let (api, ctx, _rx) = setup(1).await; @@ -1048,7 +1053,7 @@ mod tests { assert_eq!(entries[2].depth, 1); } - #[tokio::test] + #[test(tokio::test)] async fn dir_read_should_support_returning_canonicalized_paths() { let (api, ctx, _rx) = setup(1).await; @@ -1083,7 +1088,7 @@ mod tests { assert_eq!(entries[2].depth, 1); } - #[tokio::test] + #[test(tokio::test)] async fn create_dir_should_send_error_if_fails() { let (api, ctx, _rx) = 
setup(1).await; @@ -1101,7 +1106,7 @@ mod tests { assert!(!path.exists(), "Path unexpectedly exists"); } - #[tokio::test] + #[test(tokio::test)] async fn create_dir_should_send_ok_when_successful() { let (api, ctx, _rx) = setup(1).await; let root_dir = setup_dir().await; @@ -1115,7 +1120,7 @@ mod tests { assert!(path.exists(), "Directory not created"); } - #[tokio::test] + #[test(tokio::test)] async fn create_dir_should_support_creating_multiple_dir_components() { let (api, ctx, _rx) = setup(1).await; let root_dir = setup_dir().await; @@ -1129,7 +1134,7 @@ mod tests { assert!(path.exists(), "Directory not created"); } - #[tokio::test] + #[test(tokio::test)] async fn remove_should_send_error_on_failure() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1144,7 +1149,7 @@ mod tests { file.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn remove_should_support_deleting_a_directory() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1159,7 +1164,7 @@ mod tests { dir.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn remove_should_delete_nonempty_directory_if_force_is_true() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1175,7 +1180,7 @@ mod tests { dir.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn remove_should_support_deleting_a_single_file() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1190,7 +1195,7 @@ mod tests { file.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn copy_should_send_error_on_failure() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1206,7 +1211,7 @@ mod tests { dst.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn 
copy_should_support_copying_an_entire_directory() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1230,7 +1235,7 @@ mod tests { dst_file.assert(predicate::path::eq_file(src_file.path())); } - #[tokio::test] + #[test(tokio::test)] async fn copy_should_support_copying_an_empty_directory() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1247,7 +1252,7 @@ mod tests { dst.assert(predicate::path::is_dir()); } - #[tokio::test] + #[test(tokio::test)] async fn copy_should_support_copying_a_directory_that_only_contains_directories() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1271,7 +1276,7 @@ mod tests { dst_dir.assert(predicate::path::is_dir().name("dst/dir")); } - #[tokio::test] + #[test(tokio::test)] async fn copy_should_support_copying_a_single_file() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1288,7 +1293,7 @@ mod tests { dst.assert(predicate::path::eq_file(src.path())); } - #[tokio::test] + #[test(tokio::test)] async fn rename_should_fail_if_path_missing() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1304,7 +1309,7 @@ mod tests { dst.assert(predicate::path::missing()); } - #[tokio::test] + #[test(tokio::test)] async fn rename_should_support_renaming_an_entire_directory() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1328,7 +1333,7 @@ mod tests { dst_file.assert("some contents"); } - #[tokio::test] + #[test(tokio::test)] async fn rename_should_support_renaming_a_single_file() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1375,7 +1380,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn watch_should_support_watching_a_single_file() { // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. 
let (api, ctx, mut rx) = setup(100).await; @@ -1408,7 +1413,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn watch_should_support_watching_a_directory_recursively() { // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. let (api, ctx, mut rx) = setup(100).await; @@ -1485,7 +1490,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn watch_should_report_changes_using_the_ctx_replies() { // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. let (api, ctx_1, mut rx_1) = setup(100).await; @@ -1558,7 +1563,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn exists_should_send_true_if_path_exists() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1569,7 +1574,7 @@ mod tests { assert!(exists, "Expected exists to be true, but was false"); } - #[tokio::test] + #[test(tokio::test)] async fn exists_should_send_false_if_path_does_not_exist() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1579,7 +1584,7 @@ mod tests { assert!(!exists, "Expected exists to be false, but was true"); } - #[tokio::test] + #[test(tokio::test)] async fn metadata_should_send_error_on_failure() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1596,7 +1601,7 @@ mod tests { .unwrap_err(); } - #[tokio::test] + #[test(tokio::test)] async fn metadata_should_send_back_metadata_on_file_if_exists() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1630,7 +1635,7 @@ mod tests { } #[cfg(unix)] - #[tokio::test] + #[test(tokio::test)] async fn metadata_should_include_unix_specific_metadata_on_unix_platform() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1660,7 +1665,7 @@ mod tests { } #[cfg(windows)] - #[tokio::test] + #[test(tokio::test)] async fn 
metadata_should_include_windows_specific_metadata_on_windows_platform() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1689,7 +1694,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn metadata_should_send_back_metadata_on_dir_if_exists() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1721,7 +1726,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn metadata_should_send_back_metadata_on_symlink_if_exists() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1756,7 +1761,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn metadata_should_include_canonicalized_path_if_flag_specified() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1791,7 +1796,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified() { let (api, ctx, _rx) = setup(1).await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1826,7 +1831,7 @@ mod tests { // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ - #[tokio::test] + #[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_send_error_on_failure() { let (api, ctx, _rx) = setup(1).await; @@ -1846,7 +1851,7 @@ mod tests { // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ - #[tokio::test] + #[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_return_id_of_spawned_process() { let (api, ctx, _rx) = setup(1).await; @@ -1872,7 +1877,7 @@ mod tests { // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ - #[tokio::test] + #[test(tokio::test)] #[cfg_attr(windows, ignore)] async 
fn proc_spawn_should_send_back_stdout_periodically_when_available() { let (api, ctx, mut rx) = setup(1).await; @@ -1937,7 +1942,7 @@ mod tests { // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ - #[tokio::test] + #[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_send_back_stderr_periodically_when_available() { let (api, ctx, mut rx) = setup(1).await; @@ -2002,7 +2007,7 @@ mod tests { // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ - #[tokio::test] + #[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_send_done_signal_when_completed() { let (api, ctx, mut rx) = setup(1).await; @@ -2033,7 +2038,7 @@ mod tests { // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ - #[tokio::test] + #[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_clear_process_from_state_when_killed() { let (api, ctx_1, mut rx) = setup(1).await; @@ -2074,7 +2079,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn proc_kill_should_fail_if_given_non_existent_process() { let (api, ctx, _rx) = setup(1).await; @@ -2082,7 +2087,7 @@ mod tests { let _ = api.proc_kill(ctx, 0xDEADBEEF).await.unwrap_err(); } - #[tokio::test] + #[test(tokio::test)] async fn proc_stdin_should_fail_if_given_non_existent_process() { let (api, ctx, _rx) = setup(1).await; @@ -2095,7 +2100,7 @@ mod tests { // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ - #[tokio::test] + #[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_stdin_should_send_stdin_to_process() { let (api, ctx_1, mut rx) = setup(1).await; @@ -2141,7 +2146,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn 
system_info_should_return_system_info_based_on_binary() { let (api, ctx, _rx) = setup(1).await; diff --git a/distant-core/src/api/local/process/pty.rs b/distant-core/src/api/local/process/pty.rs index 6320752..73bf942 100644 --- a/distant-core/src/api/local/process/pty.rs +++ b/distant-core/src/api/local/process/pty.rs @@ -3,7 +3,7 @@ use super::{ ProcessPty, PtySize, WaitRx, }; use crate::{ - constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS}, + constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_DURATION}, data::Environment, }; use log::*; @@ -150,8 +150,7 @@ impl PtyProcess { break; } _ => { - tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS)) - .await; + tokio::time::sleep(READ_PAUSE_DURATION).await; continue; } } diff --git a/distant-core/src/api/local/process/simple/tasks.rs b/distant-core/src/api/local/process/simple/tasks.rs index 7d8f8ae..6aa06ad 100644 --- a/distant-core/src/api/local/process/simple/tasks.rs +++ b/distant-core/src/api/local/process/simple/tasks.rs @@ -1,4 +1,4 @@ -use crate::constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS}; +use crate::constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_DURATION}; use std::io; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, @@ -34,13 +34,13 @@ where // Pause to allow buffer to fill up a little bit, avoiding // spamming with a lot of smaller responses - tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS)).await; + tokio::time::sleep(READ_PAUSE_DURATION).await; } Ok(_) => return Ok(()), Err(x) if x.kind() == io::ErrorKind::WouldBlock => { // Pause to allow buffer to fill up a little bit, avoiding // spamming with a lot of smaller responses - tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS)).await; + tokio::time::sleep(READ_PAUSE_DURATION).await; } Err(x) => return Err(x), } diff --git a/distant-core/src/api/local/state.rs b/distant-core/src/api/local/state.rs index 98484e2..4bccae5 100644 --- a/distant-core/src/api/local/state.rs 
+++ b/distant-core/src/api/local/state.rs @@ -1,7 +1,5 @@ -use crate::{ - data::{ProcessId, SearchId}, - ConnectionId, -}; +use crate::data::{ProcessId, SearchId}; +use distant_net::common::ConnectionId; use std::{io, path::PathBuf}; mod process; diff --git a/distant-core/src/api/local/state/process.rs b/distant-core/src/api/local/state/process.rs index 1fe3319..3638177 100644 --- a/distant-core/src/api/local/state/process.rs +++ b/distant-core/src/api/local/state/process.rs @@ -1,5 +1,5 @@ use crate::data::{DistantResponseData, Environment, ProcessId, PtySize}; -use distant_net::Reply; +use distant_net::server::Reply; use std::{collections::HashMap, io, ops::Deref, path::PathBuf}; use tokio::{ sync::{mpsc, oneshot}, diff --git a/distant-core/src/api/local/state/process/instance.rs b/distant-core/src/api/local/state/process/instance.rs index 4b558cf..38b7aaf 100644 --- a/distant-core/src/api/local/state/process/instance.rs +++ b/distant-core/src/api/local/state/process/instance.rs @@ -4,7 +4,7 @@ use crate::{ }, data::{DistantResponseData, Environment, ProcessId, PtySize}, }; -use distant_net::Reply; +use distant_net::server::Reply; use log::*; use std::{future::Future, io, path::PathBuf}; use tokio::task::JoinHandle; @@ -174,12 +174,9 @@ async fn stdout_task( loop { match stdout.recv().await { Ok(Some(data)) => { - if let Err(x) = reply + reply .send(DistantResponseData::ProcStdout { id, data }) - .await - { - return Err(x); - } + .await?; } Ok(None) => return Ok(()), Err(x) => return Err(x), @@ -195,12 +192,9 @@ async fn stderr_task( loop { match stderr.recv().await { Ok(Some(data)) => { - if let Err(x) = reply + reply .send(DistantResponseData::ProcStderr { id, data }) - .await - { - return Err(x); - } + .await?; } Ok(None) => return Ok(()), Err(x) => return Err(x), diff --git a/distant-core/src/api/local/state/search.rs b/distant-core/src/api/local/state/search.rs index fbf3939..005a917 100644 --- a/distant-core/src/api/local/state/search.rs +++ 
b/distant-core/src/api/local/state/search.rs @@ -3,7 +3,7 @@ use crate::data::{ SearchQueryMatchData, SearchQueryOptions, SearchQueryPathMatch, SearchQuerySubmatch, SearchQueryTarget, }; -use distant_net::Reply; +use distant_net::server::Reply; use grep::{ matcher::Matcher, regex::{RegexMatcher, RegexMatcherBuilder}, @@ -764,6 +764,7 @@ mod tests { use crate::data::{FileType, SearchQueryCondition, SearchQueryMatchData}; use assert_fs::prelude::*; use std::path::PathBuf; + use test_log::test; fn make_path(path: &str) -> PathBuf { use std::path::MAIN_SEPARATOR; @@ -791,7 +792,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn should_send_event_when_query_finished() { let root = setup_dir(Vec::new()); @@ -816,7 +817,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_send_all_matches_at_once_by_default() { let root = setup_dir(vec![ ("path/to/file1.txt", ""), @@ -893,7 +894,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_support_targeting_paths() { let root = setup_dir(vec![ ("path/to/file1.txt", ""), @@ -971,7 +972,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_support_targeting_contents() { let root = setup_dir(vec![ ("path/to/file1.txt", "some\nlines of text in\na\nfile"), @@ -1047,7 +1048,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_support_multiple_submatches() { let root = setup_dir(vec![("path/to/file.txt", "aa ab ac\nba bb bc\nca cb cc")]); @@ -1139,7 +1140,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_send_paginated_results_if_specified() { let root = setup_dir(vec![ ("path/to/file1.txt", "some\nlines of text in\na\nfile"), @@ -1235,7 +1236,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] 
async fn should_send_maximum_of_limit_results_if_specified() { let root = setup_dir(vec![ ("path/to/file1.txt", "some\nlines of text in\na\nfile"), @@ -1272,7 +1273,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_send_maximum_of_limit_results_with_pagination_if_specified() { let root = setup_dir(vec![ ("path/to/file1.txt", "some\nlines of text in\na\nfile"), @@ -1313,7 +1314,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_traverse_no_deeper_than_max_depth_if_specified() { let root = setup_dir(vec![ ("path/to/file1.txt", ""), @@ -1409,7 +1410,7 @@ mod tests { .await; } - #[tokio::test] + #[test(tokio::test)] async fn should_filter_searched_paths_to_only_those_that_match_include_regex() { let root = setup_dir(vec![ ("path/to/file1.txt", "some\nlines of text in\na\nfile"), @@ -1464,7 +1465,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_filter_searched_paths_to_only_those_that_do_not_match_exclude_regex() { let root = setup_dir(vec![ ("path/to/file1.txt", "some\nlines of text in\na\nfile"), @@ -1532,7 +1533,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_return_binary_match_data_if_match_is_not_utf8_but_path_is_explicit() { let root = assert_fs::TempDir::new().unwrap(); let bin_file = root.child(make_path("file.bin")); @@ -1587,7 +1588,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_not_return_binary_match_data_if_match_is_not_utf8_and_not_explicit_path() { let root = assert_fs::TempDir::new().unwrap(); let bin_file = root.child(make_path("file.bin")); @@ -1621,7 +1622,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_filter_searched_paths_to_only_those_are_an_allowed_file_type() { let root = 
assert_fs::TempDir::new().unwrap(); let file = root.child(make_path("file")); @@ -1708,7 +1709,7 @@ mod tests { .await; } - #[tokio::test] + #[test(tokio::test)] async fn should_follow_not_symbolic_links_if_specified_in_options() { let root = assert_fs::TempDir::new().unwrap(); @@ -1766,7 +1767,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_follow_symbolic_links_if_specified_in_options() { let root = assert_fs::TempDir::new().unwrap(); @@ -1825,7 +1826,7 @@ mod tests { assert_eq!(rx.recv().await, None); } - #[tokio::test] + #[test(tokio::test)] async fn should_support_being_supplied_more_than_one_path() { let root = setup_dir(vec![ ("path/to/file1.txt", "some\nlines of text in\na\nfile"), diff --git a/distant-core/src/api/local/state/watcher.rs b/distant-core/src/api/local/state/watcher.rs index 6bc6d37..79a3287 100644 --- a/distant-core/src/api/local/state/watcher.rs +++ b/distant-core/src/api/local/state/watcher.rs @@ -1,4 +1,5 @@ -use crate::{constants::SERVER_WATCHER_CAPACITY, data::ChangeKind, ConnectionId}; +use crate::{constants::SERVER_WATCHER_CAPACITY, data::ChangeKind}; +use distant_net::common::ConnectionId; use log::*; use notify::{ Config as WatcherConfig, Error as WatcherError, ErrorKind as WatcherErrorKind, @@ -41,26 +42,12 @@ impl WatcherState { // with a large volume of watch requests let (tx, rx) = mpsc::channel(SERVER_WATCHER_CAPACITY); - macro_rules! configure_and_spawn { + macro_rules! 
spawn_watcher { ($watcher:ident) => {{ - // Attempt to configure watcher, but don't fail if these configurations fail - match $watcher.configure(WatcherConfig::PreciseEvents(true)) { - Ok(true) => debug!("Watcher configured for precise events"), - Ok(false) => debug!("Watcher not configured for precise events",), - Err(x) => error!("Watcher configuration for precise events failed: {}", x), - } - - // Attempt to configure watcher, but don't fail if these configurations fail - match $watcher.configure(WatcherConfig::NoticeEvents(true)) { - Ok(true) => debug!("Watcher configured for notice events"), - Ok(false) => debug!("Watcher not configured for notice events",), - Err(x) => error!("Watcher configuration for notice events failed: {}", x), - } - - Ok(Self { + Self { channel: WatcherChannel { tx }, task: tokio::spawn(watcher_task($watcher, rx)), - }) + } }}; } @@ -91,7 +78,7 @@ impl WatcherState { }; match result { - Ok(mut watcher) => configure_and_spawn!(watcher), + Ok(watcher) => Ok(spawn_watcher!(watcher)), Err(x) => match x.kind { // notify-rs has a bug on Mac M1 with Docker and Linux, so we detect that error // and fall back to the poll watcher if this occurs @@ -99,9 +86,9 @@ impl WatcherState { // https://github.com/notify-rs/notify/issues/423 WatcherErrorKind::Io(x) if x.raw_os_error() == Some(38) => { warn!("Recommended watcher is unsupported! 
Falling back to polling watcher!"); - let mut watcher = PollWatcher::new(event_handler!(tx)) + let watcher = PollWatcher::new(event_handler!(tx), WatcherConfig::default()) .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; - configure_and_spawn!(watcher) + Ok(spawn_watcher!(watcher)) } _ => Err(io::Error::new(io::ErrorKind::Other, x)), }, diff --git a/distant-core/src/api/local/state/watcher/path.rs b/distant-core/src/api/local/state/watcher/path.rs index ab6b6ad..cada531 100644 --- a/distant-core/src/api/local/state/watcher/path.rs +++ b/distant-core/src/api/local/state/watcher/path.rs @@ -1,8 +1,6 @@ -use crate::{ - data::{Change, ChangeKind, ChangeKindSet, DistantResponseData, Error}, - ConnectionId, -}; -use distant_net::Reply; +use crate::data::{Change, ChangeKind, ChangeKindSet, DistantResponseData, Error}; +use distant_net::common::ConnectionId; +use distant_net::server::Reply; use std::{ fmt, hash::{Hash, Hasher}, diff --git a/distant-core/src/api/reply.rs b/distant-core/src/api/reply.rs index c693c56..0dd9349 100644 --- a/distant-core/src/api/reply.rs +++ b/distant-core/src/api/reply.rs @@ -1,5 +1,5 @@ use crate::{api::DistantMsg, data::DistantResponseData}; -use distant_net::Reply; +use distant_net::server::Reply; use std::{future::Future, io, pin::Pin}; /// Wrapper around a reply that can be batch or single, converting diff --git a/distant-core/src/client.rs b/distant-core/src/client.rs index 1371a34..3354c1b 100644 --- a/distant-core/src/client.rs +++ b/distant-core/src/client.rs @@ -1,5 +1,5 @@ use crate::{DistantMsg, DistantRequestData, DistantResponseData}; -use distant_net::{Channel, Client}; +use distant_net::{client::Channel, Client}; mod ext; mod lsp; diff --git a/distant-core/src/client/ext.rs b/distant-core/src/client/ext.rs index 95955cf..2053e1a 100644 --- a/distant-core/src/client/ext.rs +++ b/distant-core/src/client/ext.rs @@ -9,7 +9,7 @@ use crate::{ }, DistantMsg, }; -use distant_net::{Channel, Request}; +use 
distant_net::{client::Channel, common::Request}; use std::{future::Future, io, path::PathBuf, pin::Pin}; pub type AsyncReturn<'a, T, E = io::Error> = diff --git a/distant-core/src/client/lsp.rs b/distant-core/src/client/lsp.rs index 244d0d9..2240c09 100644 --- a/distant-core/src/client/lsp.rs +++ b/distant-core/src/client/lsp.rs @@ -411,33 +411,33 @@ mod tests { use super::*; use crate::data::{DistantRequestData, DistantResponseData}; use distant_net::{ - Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Request, Response, - TypedAsyncRead, TypedAsyncWrite, + common::{FramedTransport, InmemoryTransport, Request, Response}, + Client, ReconnectStrategy, }; use std::{future::Future, time::Duration}; + use test_log::test; /// Timeout used with timeout function const TIMEOUT: Duration = Duration::from_millis(50); // Configures an lsp process with a means to send & receive data from outside - async fn spawn_lsp_process() -> ( - FramedTransport, - RemoteLspProcess, - ) { + async fn spawn_lsp_process() -> (FramedTransport, RemoteLspProcess) { let (mut t1, t2) = FramedTransport::pair(100); - let (writer, reader) = t2.into_split(); - let session = Client::new(writer, reader).unwrap(); - let spawn_task = tokio::spawn(async move { - RemoteLspCommand::new() - .spawn(session.clone_channel(), String::from("cmd arg")) - .await + let client = Client::spawn_inmemory(t2, ReconnectStrategy::Fail); + let spawn_task = tokio::spawn({ + let channel = client.clone_channel(); + async move { + RemoteLspCommand::new() + .spawn(channel, String::from("cmd arg")) + .await + } }); // Wait until we get the request from the session - let req: Request = t1.read().await.unwrap().unwrap(); + let req: Request = t1.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session - t1.write(Response::new( + t1.write_frame_for(&Response::new( req.id, DistantResponseData::ProcSpawned { id: rand::random() }, )) @@ -471,7 +471,7 @@ mod tests { } } - #[tokio::test] + 
#[test(tokio::test)] async fn stdin_write_should_only_send_out_complete_lsp_messages() { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -486,7 +486,7 @@ mod tests { .unwrap(); // Validate that the outgoing req is a complete LSP message - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); match req.payload { DistantRequestData::ProcStdin { data, .. } => { assert_eq!( @@ -501,7 +501,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn stdin_write_should_support_buffering_output_until_a_complete_lsp_message_is_composed() { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -520,7 +520,7 @@ mod tests { tokio::task::yield_now().await; let result = timeout( TIMEOUT, - TypedAsyncRead::>::read(&mut transport), + transport.read_frame_as::>(), ) .await; assert!(result.is_err(), "Unexpectedly got data: {:?}", result); @@ -529,7 +529,7 @@ mod tests { proc.stdin.as_mut().unwrap().write(msg_b).await.unwrap(); // Validate that the outgoing req is a complete LSP message - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); match req.payload { DistantRequestData::ProcStdin { data, .. } => { assert_eq!( @@ -544,7 +544,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn stdin_write_should_only_consume_a_complete_lsp_message_even_if_more_is_written() { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -564,7 +564,7 @@ mod tests { .unwrap(); // Validate that the outgoing req is a complete LSP message - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); match req.payload { DistantRequestData::ProcStdin { data, .. 
} => { assert_eq!( @@ -586,7 +586,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stdin_write_should_support_sending_out_multiple_lsp_messages_if_all_received_at_once() { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -613,7 +613,7 @@ mod tests { .unwrap(); // Validate that the first outgoing req is a complete LSP message matching first - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); match req.payload { DistantRequestData::ProcStdin { data, .. } => { assert_eq!( @@ -628,7 +628,7 @@ mod tests { } // Validate that the second outgoing req is a complete LSP message matching second - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); match req.payload { DistantRequestData::ProcStdin { data, .. } => { assert_eq!( @@ -643,7 +643,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn stdin_write_should_convert_content_with_distant_scheme_to_file_scheme() { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -658,7 +658,7 @@ mod tests { .unwrap(); // Validate that the outgoing req is a complete LSP message - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); match req.payload { DistantRequestData::ProcStdin { data, .. 
} => { // Verify the contents AND headers are as expected; in this case, @@ -676,13 +676,13 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn stdout_read_should_yield_lsp_messages_as_strings() { let (mut transport, mut proc) = spawn_lsp_process().await; // Send complete LSP message as stdout to process transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStdout { id: proc.id(), @@ -706,7 +706,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stdout_read_should_only_yield_complete_lsp_messages() { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -718,7 +718,7 @@ mod tests { // Send half of LSP message over stdout transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStdout { id: proc.id(), @@ -736,7 +736,7 @@ mod tests { // Send other half of LSP message over stdout transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStdout { id: proc.id(), @@ -757,7 +757,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stdout_read_should_only_consume_a_complete_lsp_message_even_if_more_output_is_available( ) { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -770,7 +770,7 @@ mod tests { // Send complete LSP message as stdout to process transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStdout { id: proc.id(), @@ -798,7 +798,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stdout_read_should_support_yielding_multiple_lsp_messages_if_all_received_at_once() { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -813,7 +813,7 @@ mod tests { // Send complete LSP message as stdout to process transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), 
DistantResponseData::ProcStdout { id: proc.id(), @@ -849,13 +849,13 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stdout_read_should_convert_content_with_file_scheme_to_distant_scheme() { let (mut transport, mut proc) = spawn_lsp_process().await; // Send complete LSP message as stdout to process transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStdout { id: proc.id(), @@ -879,13 +879,13 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stderr_read_should_yield_lsp_messages_as_strings() { let (mut transport, mut proc) = spawn_lsp_process().await; // Send complete LSP message as stderr to process transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStderr { id: proc.id(), @@ -909,7 +909,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stderr_read_should_only_yield_complete_lsp_messages() { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -921,7 +921,7 @@ mod tests { // Send half of LSP message over stderr transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStderr { id: proc.id(), @@ -939,7 +939,7 @@ mod tests { // Send other half of LSP message over stderr transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStderr { id: proc.id(), @@ -960,7 +960,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stderr_read_should_only_consume_a_complete_lsp_message_even_if_more_errput_is_available( ) { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -973,7 +973,7 @@ mod tests { // Send complete LSP message as stderr to process transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStderr { id: proc.id(), @@ -1001,7 +1001,7 
@@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stderr_read_should_support_yielding_multiple_lsp_messages_if_all_received_at_once() { let (mut transport, mut proc) = spawn_lsp_process().await; @@ -1016,7 +1016,7 @@ mod tests { // Send complete LSP message as stderr to process transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStderr { id: proc.id(), @@ -1052,13 +1052,13 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stderr_read_should_convert_content_with_file_scheme_to_distant_scheme() { let (mut transport, mut proc) = spawn_lsp_process().await; // Send complete LSP message as stderr to process transport - .write(Response::new( + .write_frame_for(&Response::new( proc.origin_id().to_string(), DistantResponseData::ProcStderr { id: proc.id(), diff --git a/distant-core/src/client/lsp/msg.rs b/distant-core/src/client/lsp/msg.rs index 0156d31..b297c12 100644 --- a/distant-core/src/client/lsp/msg.rs +++ b/distant-core/src/client/lsp/msg.rs @@ -310,7 +310,7 @@ fn swap_prefix(obj: &mut Map, old: &str, new: &str) { let check = |s: &String| s.starts_with(old); let mut mutate = |s: &mut String| { if let Some(pos) = s.find(old) { - s.replace_range(pos..old.len(), new); + s.replace_range(pos..pos + old.len(), new); } }; @@ -396,6 +396,7 @@ impl FromStr for LspContent { #[cfg(test)] mod tests { use super::*; + use test_log::test; macro_rules! 
make_obj { ($($tail:tt)*) => { diff --git a/distant-core/src/client/process.rs b/distant-core/src/client/process.rs index 95c7e35..e502b04 100644 --- a/distant-core/src/client/process.rs +++ b/distant-core/src/client/process.rs @@ -4,7 +4,10 @@ use crate::{ data::{Cmd, DistantRequestData, DistantResponseData, Environment, ProcessId, PtySize}, DistantMsg, }; -use distant_net::{Mailbox, Request, Response}; +use distant_net::{ + client::Mailbox, + common::{Request, Response}, +}; use log::*; use std::{path::PathBuf, sync::Arc}; use tokio::{ @@ -609,21 +612,18 @@ mod tests { data::{Error, ErrorKind}, }; use distant_net::{ - Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response, - TypedAsyncRead, TypedAsyncWrite, + common::{FramedTransport, InmemoryTransport, Response}, + Client, ReconnectStrategy, }; use std::time::Duration; + use test_log::test; - fn make_session() -> ( - FramedTransport, - DistantClient, - ) { + fn make_session() -> (FramedTransport, DistantClient) { let (t1, t2) = FramedTransport::pair(100); - let (writer, reader) = t2.into_split(); - (t1, Client::new(writer, reader).unwrap()) + (t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail)) } - #[tokio::test] + #[test(tokio::test)] async fn spawn_should_return_invalid_data_if_received_batch_response() { let (mut transport, session) = make_session(); @@ -636,11 +636,12 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Batch(vec![DistantResponseData::ProcSpawned { id: 1 }]), )) @@ -654,7 +655,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn spawn_should_return_invalid_data_if_did_not_get_a_indicator_that_process_started() { let (mut transport, session) = make_session(); @@ 
-667,11 +668,12 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::Error(Error { kind: ErrorKind::BrokenPipe, @@ -688,7 +690,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn kill_should_return_error_if_internal_tasks_already_completed() { let (mut transport, session) = make_session(); @@ -701,12 +703,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -726,7 +729,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn kill_should_send_proc_kill_request_and_then_cause_stdin_forwarding_to_close() { let (mut transport, session) = make_session(); @@ -739,12 +742,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -756,7 +760,8 @@ mod tests { assert!(proc.kill().await.is_ok(), "Failed to send kill request"); // Verify the kill request was sent - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); match req.payload { 
DistantMsg::Single(DistantRequestData::ProcKill { id: proc_id }) => { assert_eq!(proc_id, id) @@ -777,7 +782,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn stdin_should_be_forwarded_from_receiver_field() { let (mut transport, session) = make_session(); @@ -790,12 +795,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -812,7 +818,8 @@ mod tests { .unwrap(); // Verify that a request is made through the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); match req.payload { DistantMsg::Single(DistantRequestData::ProcStdin { id, data }) => { assert_eq!(id, 12345); @@ -822,7 +829,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn stdout_should_be_forwarded_to_receiver_field() { let (mut transport, session) = make_session(); @@ -835,12 +842,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -851,7 +859,7 @@ mod tests { let mut proc = spawn_task.await.unwrap().unwrap(); transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcStdout { id, @@ -865,7 +873,7 @@ mod tests { assert_eq!(out, b"some out"); } - #[tokio::test] + #[test(tokio::test)] async fn 
stderr_should_be_forwarded_to_receiver_field() { let (mut transport, session) = make_session(); @@ -878,12 +886,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -894,7 +903,7 @@ mod tests { let mut proc = spawn_task.await.unwrap().unwrap(); transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcStderr { id, @@ -908,7 +917,7 @@ mod tests { assert_eq!(out, b"some err"); } - #[tokio::test] + #[test(tokio::test)] async fn status_should_return_none_if_not_done() { let (mut transport, session) = make_session(); @@ -921,12 +930,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -940,7 +950,7 @@ mod tests { assert_eq!(result, None, "Unexpectedly got proc status: {:?}", result); } - #[tokio::test] + #[test(tokio::test)] async fn status_should_return_false_for_success_if_internal_tasks_fail() { let (mut transport, session) = make_session(); @@ -953,12 +963,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + 
.write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -986,7 +997,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn status_should_return_process_status_when_done() { let (mut transport, session) = make_session(); @@ -999,12 +1010,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -1016,7 +1028,7 @@ mod tests { // Send a process completion response to pass along exit status and conclude wait transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcDone { id, @@ -1040,7 +1052,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn wait_should_return_error_if_internal_tasks_fail() { let (mut transport, session) = make_session(); @@ -1053,12 +1065,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -1075,7 +1088,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn wait_should_return_error_if_connection_terminates_before_receiving_done_response() { let (mut transport, session) = make_session(); @@ -1088,12 +1101,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + 
transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -1117,7 +1131,7 @@ mod tests { } } - #[tokio::test] + #[test(tokio::test)] async fn receiving_done_response_should_result_in_wait_returning_exit_information() { let (mut transport, session) = make_session(); @@ -1130,12 +1144,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -1148,7 +1163,7 @@ mod tests { // Send a process completion response to pass along exit status and conclude wait transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcDone { id, @@ -1169,7 +1184,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn receiving_done_response_should_result_in_output_returning_exit_information() { let (mut transport, session) = make_session(); @@ -1182,12 +1197,13 @@ mod tests { }); // Wait until we get the request from the session - let req: Request> = transport.read().await.unwrap().unwrap(); + let req: Request> = + transport.read_frame_as().await.unwrap().unwrap(); // Send back a response through the session let id = 12345; transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantMsg::Single(DistantResponseData::ProcSpawned { id }), )) @@ -1200,7 +1216,7 @@ mod tests { // Send some stdout transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantMsg::Single(DistantResponseData::ProcStdout { id, @@ 
-1212,7 +1228,7 @@ mod tests { // Send some stderr transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantMsg::Single(DistantResponseData::ProcStderr { id, @@ -1224,7 +1240,7 @@ mod tests { // Send a process completion response to pass along exit status and conclude wait transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantMsg::Single(DistantResponseData::ProcDone { id, diff --git a/distant-core/src/client/searcher.rs b/distant-core/src/client/searcher.rs index 4635331..d3fb364 100644 --- a/distant-core/src/client/searcher.rs +++ b/distant-core/src/client/searcher.rs @@ -4,7 +4,7 @@ use crate::{ data::{DistantRequestData, DistantResponseData, SearchId, SearchQuery, SearchQueryMatch}, DistantMsg, }; -use distant_net::Request; +use distant_net::common::Request; use log::*; use std::{fmt, io}; use tokio::{sync::mpsc, task::JoinHandle}; @@ -197,22 +197,19 @@ mod tests { }; use crate::DistantClient; use distant_net::{ - Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response, - TypedAsyncRead, TypedAsyncWrite, + common::{FramedTransport, InmemoryTransport, Response}, + Client, ReconnectStrategy, }; use std::{path::PathBuf, sync::Arc}; + use test_log::test; use tokio::sync::Mutex; - fn make_session() -> ( - FramedTransport, - DistantClient, - ) { + fn make_session() -> (FramedTransport, DistantClient) { let (t1, t2) = FramedTransport::pair(100); - let (writer, reader) = t2.into_split(); - (t1, Client::new(writer, reader).unwrap()) + (t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail)) } - #[tokio::test] + #[test(tokio::test)] async fn searcher_should_have_query_reflect_ongoing_query() { let (mut transport, session) = make_session(); let test_query = SearchQuery { @@ -232,11 +229,11 @@ mod tests { }; // Wait until we get the request from the session - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = 
transport.read_frame_as().await.unwrap().unwrap(); // Send back an acknowledgement that a search was started transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantResponseData::SearchStarted { id: rand::random() }, )) @@ -248,7 +245,7 @@ mod tests { assert_eq!(searcher.query(), &test_query); } - #[tokio::test] + #[test(tokio::test)] async fn searcher_should_support_getting_next_match() { let (mut transport, session) = make_session(); let test_query = SearchQuery { @@ -268,12 +265,12 @@ mod tests { ); // Wait until we get the request from the session - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); // Send back an acknowledgement that a searcher was created let id = rand::random::(); transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantResponseData::SearchStarted { id }, )) @@ -285,7 +282,7 @@ mod tests { // Send some matches related to the file transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, vec![ DistantResponseData::SearchResults { @@ -366,7 +363,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn searcher_should_distinguish_match_events_and_only_receive_matches_for_itself() { let (mut transport, session) = make_session(); @@ -387,12 +384,12 @@ mod tests { ); // Wait until we get the request from the session - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); // Send back an acknowledgement that a searcher was created let id = rand::random(); transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantResponseData::SearchStarted { id }, )) @@ -404,7 +401,7 @@ mod tests { // Send a match from the appropriate origin transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantResponseData::SearchResults { id, @@ -423,7 
+420,7 @@ mod tests { // Send a chanmatchge from a different origin transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone() + "1", DistantResponseData::SearchResults { id, @@ -442,7 +439,7 @@ mod tests { // Send a chanmatchge from the appropriate origin transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantResponseData::SearchResults { id, @@ -487,7 +484,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn searcher_should_stop_receiving_events_if_cancelled() { let (mut transport, session) = make_session(); @@ -508,12 +505,12 @@ mod tests { ); // Wait until we get the request from the session - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); // Send back an acknowledgement that a watcher was created let id = rand::random::(); transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantResponseData::SearchStarted { id }, )) @@ -522,7 +519,7 @@ mod tests { // Send some matches from the appropriate origin transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantResponseData::SearchResults { id, @@ -579,10 +576,10 @@ mod tests { let searcher_2 = Arc::clone(&searcher); let cancel_task = tokio::spawn(async move { searcher_2.lock().await.cancel().await }); - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); transport - .write(Response::new(req.id.clone(), DistantResponseData::Ok)) + .write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok)) .await .unwrap(); @@ -591,7 +588,7 @@ mod tests { // Send a match that will get ignored transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantResponseData::SearchResults { id, diff --git a/distant-core/src/client/watcher.rs b/distant-core/src/client/watcher.rs index 072d6d0..a1f98a3 100644 --- 
a/distant-core/src/client/watcher.rs +++ b/distant-core/src/client/watcher.rs @@ -4,7 +4,7 @@ use crate::{ data::{Change, ChangeKindSet, DistantRequestData, DistantResponseData}, DistantMsg, }; -use distant_net::Request; +use distant_net::common::Request; use log::*; use std::{ fmt, io, @@ -185,22 +185,19 @@ mod tests { use crate::data::ChangeKind; use crate::DistantClient; use distant_net::{ - Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response, - TypedAsyncRead, TypedAsyncWrite, + common::{FramedTransport, InmemoryTransport, Response}, + Client, ReconnectStrategy, }; use std::sync::Arc; + use test_log::test; use tokio::sync::Mutex; - fn make_session() -> ( - FramedTransport, - DistantClient, - ) { + fn make_session() -> (FramedTransport, DistantClient) { let (t1, t2) = FramedTransport::pair(100); - let (writer, reader) = t2.into_split(); - (t1, Client::new(writer, reader).unwrap()) + (t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail)) } - #[tokio::test] + #[test(tokio::test)] async fn watcher_should_have_path_reflect_watched_path() { let (mut transport, session) = make_session(); let test_path = Path::new("/some/test/path"); @@ -219,11 +216,11 @@ mod tests { }); // Wait until we get the request from the session - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); // Send back an acknowledgement that a watcher was created transport - .write(Response::new(req.id, DistantResponseData::Ok)) + .write_frame_for(&Response::new(req.id, DistantResponseData::Ok)) .await .unwrap(); @@ -232,7 +229,7 @@ mod tests { assert_eq!(watcher.path(), test_path); } - #[tokio::test] + #[test(tokio::test)] async fn watcher_should_support_getting_next_change() { let (mut transport, session) = make_session(); let test_path = Path::new("/some/test/path"); @@ -251,11 +248,11 @@ mod tests { }); // Wait until we get the request from the session - let req: Request = 
transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); // Send back an acknowledgement that a watcher was created transport - .write(Response::new(req.id.clone(), DistantResponseData::Ok)) + .write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok)) .await .unwrap(); @@ -264,7 +261,7 @@ mod tests { // Send some changes related to the file transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, vec![ DistantResponseData::Changed(Change { @@ -300,7 +297,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn watcher_should_distinguish_change_events_and_only_receive_changes_for_itself() { let (mut transport, session) = make_session(); let test_path = Path::new("/some/test/path"); @@ -319,11 +316,11 @@ mod tests { }); // Wait until we get the request from the session - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); // Send back an acknowledgement that a watcher was created transport - .write(Response::new(req.id.clone(), DistantResponseData::Ok)) + .write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok)) .await .unwrap(); @@ -332,7 +329,7 @@ mod tests { // Send a change from the appropriate origin transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone(), DistantResponseData::Changed(Change { kind: ChangeKind::Access, @@ -344,7 +341,7 @@ mod tests { // Send a change from a different origin transport - .write(Response::new( + .write_frame_for(&Response::new( req.id.clone() + "1", DistantResponseData::Changed(Change { kind: ChangeKind::Content, @@ -356,7 +353,7 @@ mod tests { // Send a change from the appropriate origin transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantResponseData::Changed(Change { kind: ChangeKind::Remove, @@ -386,7 +383,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] 
async fn watcher_should_stop_receiving_events_if_unwatched() { let (mut transport, session) = make_session(); let test_path = Path::new("/some/test/path"); @@ -405,17 +402,17 @@ mod tests { }); // Wait until we get the request from the session - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); // Send back an acknowledgement that a watcher was created transport - .write(Response::new(req.id.clone(), DistantResponseData::Ok)) + .write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok)) .await .unwrap(); // Send some changes from the appropriate origin transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, vec![ DistantResponseData::Changed(Change { @@ -461,10 +458,10 @@ mod tests { let watcher_2 = Arc::clone(&watcher); let unwatch_task = tokio::spawn(async move { watcher_2.lock().await.unwatch().await }); - let req: Request = transport.read().await.unwrap().unwrap(); + let req: Request = transport.read_frame_as().await.unwrap().unwrap(); transport - .write(Response::new(req.id.clone(), DistantResponseData::Ok)) + .write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok)) .await .unwrap(); @@ -472,7 +469,7 @@ mod tests { unwatch_task.await.unwrap().unwrap(); transport - .write(Response::new( + .write_frame_for(&Response::new( req.id, DistantResponseData::Changed(Change { kind: ChangeKind::Unknown, diff --git a/distant-core/src/constants.rs b/distant-core/src/constants.rs index 2fc79be..52349e7 100644 --- a/distant-core/src/constants.rs +++ b/distant-core/src/constants.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + /// Capacity associated stdin, stdout, and stderr pipes receiving data from remote server pub const CLIENT_PIPE_CAPACITY: usize = 10000; @@ -18,4 +20,4 @@ pub const MAX_PIPE_CHUNK_SIZE: usize = 16384; /// Duration in milliseconds to sleep between reading stdout/stderr chunks /// to avoid sending many small messages to 
clients -pub const READ_PAUSE_MILLIS: u64 = 50; +pub const READ_PAUSE_DURATION: Duration = Duration::from_millis(1); diff --git a/distant-core/src/credentials.rs b/distant-core/src/credentials.rs index 89dfed7..7f88dbc 100644 --- a/distant-core/src/credentials.rs +++ b/distant-core/src/credentials.rs @@ -1,8 +1,5 @@ -use crate::{ - serde_str::{deserialize_from_str, serialize_to_str}, - Destination, Host, -}; -use distant_net::SecretKey32; +use crate::serde_str::{deserialize_from_str, serialize_to_str}; +use distant_net::common::{Destination, Host, SecretKey32}; use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; use std::{convert::TryFrom, fmt, io, str::FromStr}; @@ -154,6 +151,7 @@ mod tests { use super::*; use once_cell::sync::Lazy; use std::net::{Ipv4Addr, Ipv6Addr}; + use test_log::test; const HOST: &str = "testhost"; const PORT: u16 = 12345; diff --git a/distant-core/src/data.rs b/distant-core/src/data.rs index fbab853..37c2325 100644 --- a/distant-core/src/data.rs +++ b/distant-core/src/data.rs @@ -24,9 +24,6 @@ pub use error::*; mod filesystem; pub use filesystem::*; -mod map; -pub use map::Map; - mod metadata; pub use metadata::*; @@ -46,7 +43,7 @@ pub(crate) use utils::*; pub type ProcessId = u32; /// Mapping of environment variables -pub type Environment = Map; +pub type Environment = distant_net::common::Map; /// Type alias for a vec of bytes /// diff --git a/distant-core/src/data/search.rs b/distant-core/src/data/search.rs index cdf393c..a3836df 100644 --- a/distant-core/src/data/search.rs +++ b/distant-core/src/data/search.rs @@ -391,6 +391,7 @@ mod tests { mod search_query_condition { use super::*; + use test_log::test; #[test] fn to_regex_string_should_convert_to_appropriate_regex_and_escape_as_needed() { diff --git a/distant-core/src/lib.rs b/distant-core/src/lib.rs index a457712..7189a9a 100644 --- a/distant-core/src/lib.rs +++ b/distant-core/src/lib.rs @@ -8,10 +8,7 @@ mod credentials; pub use credentials::*; pub mod data; -pub 
use data::{DistantMsg, DistantRequestData, DistantResponseData, Map}; - -mod manager; -pub use manager::*; +pub use data::{DistantMsg, DistantRequestData, DistantResponseData}; mod constants; mod serde_str; diff --git a/distant-core/src/manager/client.rs b/distant-core/src/manager/client.rs deleted file mode 100644 index e9c1fa8..0000000 --- a/distant-core/src/manager/client.rs +++ /dev/null @@ -1,783 +0,0 @@ -use super::data::{ - ConnectionId, ConnectionInfo, ConnectionList, Destination, ManagerCapabilities, ManagerRequest, - ManagerResponse, -}; -use crate::{ - DistantChannel, DistantClient, DistantMsg, DistantRequestData, DistantResponseData, Map, -}; -use distant_net::{ - router, Auth, AuthServer, Client, IntoSplit, MpscTransport, OneshotListener, Request, Response, - ServerExt, ServerRef, UntypedTransportRead, UntypedTransportWrite, -}; -use log::*; -use std::{ - collections::HashMap, - io, - ops::{Deref, DerefMut}, -}; -use tokio::task::JoinHandle; - -mod config; -pub use config::*; - -mod ext; -pub use ext::*; - -router!(DistantManagerClientRouter { - auth_transport: Request => Response, - manager_transport: Response => Request, -}); - -/// Represents a client that can connect to a remote distant manager -pub struct DistantManagerClient { - auth: Box, - client: Client, - distant_clients: HashMap, -} - -impl Drop for DistantManagerClient { - fn drop(&mut self) { - self.auth.abort(); - self.client.abort(); - } -} - -/// Represents a raw channel between a manager client and some remote server -pub struct RawDistantChannel { - pub transport: MpscTransport< - Request>, - Response>, - >, - forward_task: JoinHandle<()>, - mailbox_task: JoinHandle<()>, -} - -impl RawDistantChannel { - pub fn abort(&self) { - self.forward_task.abort(); - self.mailbox_task.abort(); - } -} - -impl Deref for RawDistantChannel { - type Target = MpscTransport< - Request>, - Response>, - >; - - fn deref(&self) -> &Self::Target { - &self.transport - } -} - -impl DerefMut for 
RawDistantChannel { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.transport - } -} - -struct ClientHandle { - client: DistantClient, - forward_task: JoinHandle<()>, - mailbox_task: JoinHandle<()>, -} - -impl Drop for ClientHandle { - fn drop(&mut self) { - self.forward_task.abort(); - self.mailbox_task.abort(); - } -} - -impl DistantManagerClient { - /// Initializes a client using the provided [`UntypedTransport`] - pub fn new(config: DistantManagerClientConfig, transport: T) -> io::Result - where - T: IntoSplit + 'static, - T::Read: UntypedTransportRead + 'static, - T::Write: UntypedTransportWrite + 'static, - { - let DistantManagerClientRouter { - auth_transport, - manager_transport, - .. - } = DistantManagerClientRouter::new(transport); - - // Initialize our client with manager request/response transport - let (writer, reader) = manager_transport.into_split(); - let client = Client::new(writer, reader)?; - - // Initialize our auth handler with auth/auth transport - let auth = AuthServer { - on_challenge: config.on_challenge, - on_verify: config.on_verify, - on_info: config.on_info, - on_error: config.on_error, - } - .start(OneshotListener::from_value(auth_transport.into_split()))?; - - Ok(Self { - auth, - client, - distant_clients: HashMap::new(), - }) - } - - /// Request that the manager launches a new server at the given `destination` - /// with `options` being passed for destination-specific details, returning the new - /// `destination` of the spawned server to connect to - pub async fn launch( - &mut self, - destination: impl Into, - options: impl Into, - ) -> io::Result { - let destination = Box::new(destination.into()); - let options = options.into(); - trace!("launch({}, {})", destination, options); - - let res = self - .client - .send(ManagerRequest::Launch { - destination, - options, - }) - .await?; - match res.payload { - ManagerResponse::Launched { destination } => Ok(destination), - ManagerResponse::Error(x) => Err(x.into()), - x => 
Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Got unexpected response: {:?}", x), - )), - } - } - - /// Request that the manager establishes a new connection at the given `destination` - /// with `options` being passed for destination-specific details - pub async fn connect( - &mut self, - destination: impl Into, - options: impl Into, - ) -> io::Result { - let destination = Box::new(destination.into()); - let options = options.into(); - trace!("connect({}, {})", destination, options); - - let res = self - .client - .send(ManagerRequest::Connect { - destination, - options, - }) - .await?; - match res.payload { - ManagerResponse::Connected { id } => Ok(id), - ManagerResponse::Error(x) => Err(x.into()), - x => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Got unexpected response: {:?}", x), - )), - } - } - - /// Establishes a channel with the server represented by the `connection_id`, - /// returning a [`DistantChannel`] acting as the connection - /// - /// ### Note - /// - /// Multiple calls to open a channel against the same connection will result in - /// clones of the same [`DistantChannel`] rather than establishing a duplicate - /// remote connection to the same server - pub async fn open_channel( - &mut self, - connection_id: ConnectionId, - ) -> io::Result { - trace!("open_channel({})", connection_id); - if let Some(handle) = self.distant_clients.get(&connection_id) { - Ok(handle.client.clone_channel()) - } else { - let RawDistantChannel { - transport, - forward_task, - mailbox_task, - } = self.open_raw_channel(connection_id).await?; - let (writer, reader) = transport.into_split(); - let client = DistantClient::new(writer, reader)?; - let channel = client.clone_channel(); - self.distant_clients.insert( - connection_id, - ClientHandle { - client, - forward_task, - mailbox_task, - }, - ); - Ok(channel) - } - } - - /// Establishes a channel with the server represented by the `connection_id`, - /// returning a [`Transport`] acting as 
the connection - /// - /// ### Note - /// - /// Multiple calls to open a channel against the same connection will result in establishing a - /// duplicate remote connections to the same server, so take care when using this method - pub async fn open_raw_channel( - &mut self, - connection_id: ConnectionId, - ) -> io::Result { - trace!("open_raw_channel({})", connection_id); - let mut mailbox = self - .client - .mail(ManagerRequest::OpenChannel { id: connection_id }) - .await?; - - // Wait for the first response, which should be channel confirmation - let channel_id = match mailbox.next().await { - Some(response) => match response.payload { - ManagerResponse::ChannelOpened { id } => Ok(id), - ManagerResponse::Error(x) => Err(x.into()), - x => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Got unexpected response: {:?}", x), - )), - }, - None => Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "open_channel mailbox aborted", - )), - }?; - - // Spawn reader and writer tasks to forward requests and replies - // using our opened channel - let (t1, t2) = MpscTransport::pair(1); - let (mut writer, mut reader) = t1.into_split(); - let mailbox_task = tokio::spawn(async move { - use distant_net::TypedAsyncWrite; - while let Some(response) = mailbox.next().await { - match response.payload { - ManagerResponse::Channel { response, .. } => { - if let Err(x) = writer.write(response).await { - error!("[Conn {}] {}", connection_id, x); - } - } - ManagerResponse::ChannelClosed { .. 
} => break, - _ => continue, - } - } - }); - - let mut manager_channel = self.client.clone_channel(); - let forward_task = tokio::spawn(async move { - use distant_net::TypedAsyncRead; - loop { - match reader.read().await { - Ok(Some(request)) => { - // NOTE: In this situation, we do not expect a response to this - // request (even if the server sends something back) - if let Err(x) = manager_channel - .fire(ManagerRequest::Channel { - id: channel_id, - request, - }) - .await - { - error!("[Conn {}] {}", connection_id, x); - } - } - Ok(None) => break, - Err(x) => { - error!("[Conn {}] {}", connection_id, x); - continue; - } - } - } - }); - - Ok(RawDistantChannel { - transport: t2, - forward_task, - mailbox_task, - }) - } - - /// Retrieves a list of supported capabilities - pub async fn capabilities(&mut self) -> io::Result { - trace!("capabilities()"); - let res = self.client.send(ManagerRequest::Capabilities).await?; - match res.payload { - ManagerResponse::Capabilities { supported } => Ok(supported), - ManagerResponse::Error(x) => Err(x.into()), - x => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Got unexpected response: {:?}", x), - )), - } - } - - /// Retrieves information about a specific connection - pub async fn info(&mut self, id: ConnectionId) -> io::Result { - trace!("info({})", id); - let res = self.client.send(ManagerRequest::Info { id }).await?; - match res.payload { - ManagerResponse::Info(info) => Ok(info), - ManagerResponse::Error(x) => Err(x.into()), - x => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Got unexpected response: {:?}", x), - )), - } - } - - /// Kills the specified connection - pub async fn kill(&mut self, id: ConnectionId) -> io::Result<()> { - trace!("kill({})", id); - let res = self.client.send(ManagerRequest::Kill { id }).await?; - match res.payload { - ManagerResponse::Killed => Ok(()), - ManagerResponse::Error(x) => Err(x.into()), - x => Err(io::Error::new( - io::ErrorKind::InvalidData, - 
format!("Got unexpected response: {:?}", x), - )), - } - } - - /// Retrieves a list of active connections - pub async fn list(&mut self) -> io::Result { - trace!("list()"); - let res = self.client.send(ManagerRequest::List).await?; - match res.payload { - ManagerResponse::List(list) => Ok(list), - ManagerResponse::Error(x) => Err(x.into()), - x => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Got unexpected response: {:?}", x), - )), - } - } - - /// Requests that the manager shuts down - pub async fn shutdown(&mut self) -> io::Result<()> { - trace!("shutdown()"); - let res = self.client.send(ManagerRequest::Shutdown).await?; - match res.payload { - ManagerResponse::Shutdown => Ok(()), - ManagerResponse::Error(x) => Err(x.into()), - x => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Got unexpected response: {:?}", x), - )), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::data::{Error, ErrorKind}; - use distant_net::{ - FramedTransport, InmemoryTransport, PlainCodec, UntypedTransportRead, UntypedTransportWrite, - }; - - fn setup() -> ( - DistantManagerClient, - FramedTransport, - ) { - let (t1, t2) = FramedTransport::pair(100); - let client = - DistantManagerClient::new(DistantManagerClientConfig::with_empty_prompts(), t1) - .unwrap(); - (client, t2) - } - - #[inline] - fn test_error() -> Error { - Error { - kind: ErrorKind::Interrupted, - description: "test error".to_string(), - } - } - - #[inline] - fn test_io_error() -> io::Error { - test_error().into() - } - - #[tokio::test] - async fn connect_should_report_error_if_receives_error_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new( - request.id, - ManagerResponse::Error(test_error()), - )) - .await - .unwrap(); - }); - - let err = client - .connect( - "scheme://host".parse::().unwrap(), - 
"key=value".parse::().unwrap(), - ) - .await - .unwrap_err(); - assert_eq!(err.kind(), test_io_error().kind()); - assert_eq!(err.to_string(), test_io_error().to_string()); - } - - #[tokio::test] - async fn connect_should_report_error_if_receives_unexpected_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new(request.id, ManagerResponse::Shutdown)) - .await - .unwrap(); - }); - - let err = client - .connect( - "scheme://host".parse::().unwrap(), - "key=value".parse::().unwrap(), - ) - .await - .unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::InvalidData); - } - - #[tokio::test] - async fn connect_should_return_id_from_successful_response() { - let (mut client, mut transport) = setup(); - - let expected_id = 999; - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new( - request.id, - ManagerResponse::Connected { id: expected_id }, - )) - .await - .unwrap(); - }); - - let id = client - .connect( - "scheme://host".parse::().unwrap(), - "key=value".parse::().unwrap(), - ) - .await - .unwrap(); - assert_eq!(id, expected_id); - } - - #[tokio::test] - async fn info_should_report_error_if_receives_error_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new( - request.id, - ManagerResponse::Error(test_error()), - )) - .await - .unwrap(); - }); - - let err = client.info(123).await.unwrap_err(); - assert_eq!(err.kind(), test_io_error().kind()); - assert_eq!(err.to_string(), test_io_error().to_string()); - } - - #[tokio::test] - async fn info_should_report_error_if_receives_unexpected_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - 
.read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new(request.id, ManagerResponse::Shutdown)) - .await - .unwrap(); - }); - - let err = client.info(123).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::InvalidData); - } - - #[tokio::test] - async fn info_should_return_connection_info_from_successful_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - let info = ConnectionInfo { - id: 123, - destination: "scheme://host".parse::().unwrap(), - options: "key=value".parse::().unwrap(), - }; - - transport - .write(Response::new(request.id, ManagerResponse::Info(info))) - .await - .unwrap(); - }); - - let info = client.info(123).await.unwrap(); - assert_eq!(info.id, 123); - assert_eq!( - info.destination, - "scheme://host".parse::().unwrap() - ); - assert_eq!(info.options, "key=value".parse::().unwrap()); - } - - #[tokio::test] - async fn list_should_report_error_if_receives_error_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new( - request.id, - ManagerResponse::Error(test_error()), - )) - .await - .unwrap(); - }); - - let err = client.list().await.unwrap_err(); - assert_eq!(err.kind(), test_io_error().kind()); - assert_eq!(err.to_string(), test_io_error().to_string()); - } - - #[tokio::test] - async fn list_should_report_error_if_receives_unexpected_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new(request.id, ManagerResponse::Shutdown)) - .await - .unwrap(); - }); - - let err = client.list().await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::InvalidData); - } - - #[tokio::test] - async fn 
list_should_return_connection_list_from_successful_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - let mut list = ConnectionList::new(); - list.insert(123, "scheme://host".parse::().unwrap()); - - transport - .write(Response::new(request.id, ManagerResponse::List(list))) - .await - .unwrap(); - }); - - let list = client.list().await.unwrap(); - assert_eq!(list.len(), 1); - assert_eq!( - list.get(&123).expect("Connection list missing item"), - &"scheme://host".parse::().unwrap() - ); - } - - #[tokio::test] - async fn kill_should_report_error_if_receives_error_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new( - request.id, - ManagerResponse::Error(test_error()), - )) - .await - .unwrap(); - }); - - let err = client.kill(123).await.unwrap_err(); - assert_eq!(err.kind(), test_io_error().kind()); - assert_eq!(err.to_string(), test_io_error().to_string()); - } - - #[tokio::test] - async fn kill_should_report_error_if_receives_unexpected_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new(request.id, ManagerResponse::Shutdown)) - .await - .unwrap(); - }); - - let err = client.kill(123).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::InvalidData); - } - - #[tokio::test] - async fn kill_should_return_success_from_successful_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new(request.id, ManagerResponse::Killed)) - .await - .unwrap(); - }); - - client.kill(123).await.unwrap(); - } - - #[tokio::test] - async fn 
shutdown_should_report_error_if_receives_error_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new( - request.id, - ManagerResponse::Connected { id: 0 }, - )) - .await - .unwrap(); - }); - - let err = client.shutdown().await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::InvalidData); - } - - #[tokio::test] - async fn shutdown_should_report_error_if_receives_unexpected_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new( - request.id, - ManagerResponse::Error(test_error()), - )) - .await - .unwrap(); - }); - - let err = client.shutdown().await.unwrap_err(); - assert_eq!(err.kind(), test_io_error().kind()); - assert_eq!(err.to_string(), test_io_error().to_string()); - } - - #[tokio::test] - async fn shutdown_should_return_success_from_successful_response() { - let (mut client, mut transport) = setup(); - - tokio::spawn(async move { - let request = transport - .read::>() - .await - .unwrap() - .unwrap(); - - transport - .write(Response::new(request.id, ManagerResponse::Shutdown)) - .await - .unwrap(); - }); - - client.shutdown().await.unwrap(); - } -} diff --git a/distant-core/src/manager/client/config.rs b/distant-core/src/manager/client/config.rs deleted file mode 100644 index cc0e03e..0000000 --- a/distant-core/src/manager/client/config.rs +++ /dev/null @@ -1,85 +0,0 @@ -use distant_net::{AuthChallengeFn, AuthErrorFn, AuthInfoFn, AuthVerifyFn, AuthVerifyKind}; -use log::*; -use std::io; - -/// Configuration to use when creating a new [`DistantManagerClient`](super::DistantManagerClient) -pub struct DistantManagerClientConfig { - pub on_challenge: Box, - pub on_verify: Box, - pub on_info: Box, - pub on_error: Box, -} - -impl DistantManagerClientConfig { - /// Creates a 
new config with prompts that return empty strings - pub fn with_empty_prompts() -> Self { - Self::with_prompts(|_| Ok("".to_string()), |_| Ok("".to_string())) - } - - /// Creates a new config with two prompts - /// - /// * `password_prompt` - used for prompting for a secret, and should not display what is typed - /// * `text_prompt` - used for general text, and is okay to display what is typed - pub fn with_prompts(password_prompt: PP, text_prompt: PT) -> Self - where - PP: Fn(&str) -> io::Result + Send + Sync + 'static, - PT: Fn(&str) -> io::Result + Send + Sync + 'static, - { - Self { - on_challenge: Box::new(move |questions, _extra| { - trace!("[manager client] on_challenge({questions:?}, {_extra:?})"); - let mut answers = Vec::new(); - for question in questions.iter() { - // Contains all prompt lines including same line - let mut lines = question.text.split('\n').collect::>(); - - // Line that is prompt on same line as answer - let line = lines.pop().unwrap(); - - // Go ahead and display all other lines - for line in lines.into_iter() { - eprintln!("{}", line); - } - - // Get an answer from user input, or use a blank string as an answer - // if we fail to get input from the user - let answer = password_prompt(line).unwrap_or_default(); - - answers.push(answer); - } - answers - }), - on_verify: Box::new(move |kind, text| { - trace!("[manager client] on_verify({kind}, {text})"); - match kind { - AuthVerifyKind::Host => { - eprintln!("{}", text); - - match text_prompt("Enter [y/N]> ") { - Ok(answer) => { - trace!("Verify? 
Answer = '{answer}'"); - matches!(answer.trim(), "y" | "Y" | "yes" | "YES") - } - Err(x) => { - error!("Failed verification: {x}"); - false - } - } - } - x => { - error!("Unsupported verify kind: {x}"); - false - } - } - }), - on_info: Box::new(|text| { - trace!("[manager client] on_info({text})"); - println!("{}", text); - }), - on_error: Box::new(|kind, text| { - trace!("[manager client] on_error({kind}, {text})"); - eprintln!("{}: {}", kind, text); - }), - } - } -} diff --git a/distant-core/src/manager/client/ext.rs b/distant-core/src/manager/client/ext.rs deleted file mode 100644 index d23a3d2..0000000 --- a/distant-core/src/manager/client/ext.rs +++ /dev/null @@ -1,14 +0,0 @@ -mod tcp; -pub use tcp::*; - -#[cfg(unix)] -mod unix; - -#[cfg(unix)] -pub use unix::*; - -#[cfg(windows)] -mod windows; - -#[cfg(windows)] -pub use windows::*; diff --git a/distant-core/src/manager/client/ext/tcp.rs b/distant-core/src/manager/client/ext/tcp.rs deleted file mode 100644 index e31ffc1..0000000 --- a/distant-core/src/manager/client/ext/tcp.rs +++ /dev/null @@ -1,50 +0,0 @@ -use crate::{DistantManagerClient, DistantManagerClientConfig}; -use async_trait::async_trait; -use distant_net::{Codec, FramedTransport, TcpTransport}; -use std::{convert, net::SocketAddr}; -use tokio::{io, time::Duration}; - -#[async_trait] -pub trait TcpDistantManagerClientExt { - /// Connect to a remote TCP server using the provided information - async fn connect( - config: DistantManagerClientConfig, - addr: SocketAddr, - codec: C, - ) -> io::Result - where - C: Codec + Send + 'static; - - /// Connect to a remote TCP server, timing out after duration has passed - async fn connect_timeout( - config: DistantManagerClientConfig, - addr: SocketAddr, - codec: C, - duration: Duration, - ) -> io::Result - where - C: Codec + Send + 'static, - { - tokio::time::timeout(duration, Self::connect(config, addr, codec)) - .await - .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) - .and_then(convert::identity) 
- } -} - -#[async_trait] -impl TcpDistantManagerClientExt for DistantManagerClient { - /// Connect to a remote TCP server using the provided information - async fn connect( - config: DistantManagerClientConfig, - addr: SocketAddr, - codec: C, - ) -> io::Result - where - C: Codec + Send + 'static, - { - let transport = TcpTransport::connect(addr).await?; - let transport = FramedTransport::new(transport, codec); - Self::new(config, transport) - } -} diff --git a/distant-core/src/manager/client/ext/unix.rs b/distant-core/src/manager/client/ext/unix.rs deleted file mode 100644 index 18df8c8..0000000 --- a/distant-core/src/manager/client/ext/unix.rs +++ /dev/null @@ -1,54 +0,0 @@ -use crate::{DistantManagerClient, DistantManagerClientConfig}; -use async_trait::async_trait; -use distant_net::{Codec, FramedTransport, UnixSocketTransport}; -use std::{convert, path::Path}; -use tokio::{io, time::Duration}; - -#[async_trait] -pub trait UnixSocketDistantManagerClientExt { - /// Connect to a proxy unix socket - async fn connect( - config: DistantManagerClientConfig, - path: P, - codec: C, - ) -> io::Result - where - P: AsRef + Send, - C: Codec + Send + 'static; - - /// Connect to a proxy unix socket, timing out after duration has passed - async fn connect_timeout( - config: DistantManagerClientConfig, - path: P, - codec: C, - duration: Duration, - ) -> io::Result - where - P: AsRef + Send, - C: Codec + Send + 'static, - { - tokio::time::timeout(duration, Self::connect(config, path, codec)) - .await - .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) - .and_then(convert::identity) - } -} - -#[async_trait] -impl UnixSocketDistantManagerClientExt for DistantManagerClient { - /// Connect to a proxy unix socket - async fn connect( - config: DistantManagerClientConfig, - path: P, - codec: C, - ) -> io::Result - where - P: AsRef + Send, - C: Codec + Send + 'static, - { - let p = path.as_ref(); - let transport = UnixSocketTransport::connect(p).await?; - let transport = 
FramedTransport::new(transport, codec); - Ok(DistantManagerClient::new(config, transport)?) - } -} diff --git a/distant-core/src/manager/client/ext/windows.rs b/distant-core/src/manager/client/ext/windows.rs deleted file mode 100644 index c13f10a..0000000 --- a/distant-core/src/manager/client/ext/windows.rs +++ /dev/null @@ -1,91 +0,0 @@ -use crate::{DistantManagerClient, DistantManagerClientConfig}; -use async_trait::async_trait; -use distant_net::{Codec, FramedTransport, WindowsPipeTransport}; -use std::{ - convert, - ffi::{OsStr, OsString}, -}; -use tokio::{io, time::Duration}; - -#[async_trait] -pub trait WindowsPipeDistantManagerClientExt { - /// Connect to a server listening on a Windows pipe at the specified address - /// using the given codec - async fn connect( - config: DistantManagerClientConfig, - addr: A, - codec: C, - ) -> io::Result - where - A: AsRef + Send, - C: Codec + Send + 'static; - - /// Connect to a server listening on a Windows pipe at the specified address - /// via `\\.\pipe\{name}` using the given codec - async fn connect_local( - config: DistantManagerClientConfig, - name: N, - codec: C, - ) -> io::Result - where - N: AsRef + Send, - C: Codec + Send + 'static, - { - let mut addr = OsString::from(r"\\.\pipe\"); - addr.push(name.as_ref()); - Self::connect(config, addr, codec).await - } - - /// Connect to a server listening on a Windows pipe at the specified address - /// using the given codec, timing out after duration has passed - async fn connect_timeout( - config: DistantManagerClientConfig, - addr: A, - codec: C, - duration: Duration, - ) -> io::Result - where - A: AsRef + Send, - C: Codec + Send + 'static, - { - tokio::time::timeout(duration, Self::connect(config, addr, codec)) - .await - .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) - .and_then(convert::identity) - } - - /// Connect to a server listening on a Windows pipe at the specified address - /// via `\\.\pipe\{name}` using the given codec, timing out after duration 
has passed - async fn connect_local_timeout( - config: DistantManagerClientConfig, - name: N, - codec: C, - duration: Duration, - ) -> io::Result - where - N: AsRef + Send, - C: Codec + Send + 'static, - { - let mut addr = OsString::from(r"\\.\pipe\"); - addr.push(name.as_ref()); - Self::connect_timeout(config, addr, codec, duration).await - } -} - -#[async_trait] -impl WindowsPipeDistantManagerClientExt for DistantManagerClient { - async fn connect( - config: DistantManagerClientConfig, - addr: A, - codec: C, - ) -> io::Result - where - A: AsRef + Send, - C: Codec + Send + 'static, - { - let a = addr.as_ref(); - let transport = WindowsPipeTransport::connect(a).await?; - let transport = FramedTransport::new(transport, codec); - Ok(DistantManagerClient::new(config, transport)?) - } -} diff --git a/distant-core/src/manager/data/id.rs b/distant-core/src/manager/data/id.rs deleted file mode 100644 index 34abc0d..0000000 --- a/distant-core/src/manager/data/id.rs +++ /dev/null @@ -1,5 +0,0 @@ -/// Id associated with an active connection -pub type ConnectionId = u64; - -/// Id associated with an open channel -pub type ChannelId = u64; diff --git a/distant-core/src/manager/server.rs b/distant-core/src/manager/server.rs deleted file mode 100644 index 979a2da..0000000 --- a/distant-core/src/manager/server.rs +++ /dev/null @@ -1,719 +0,0 @@ -use crate::{ - ChannelId, ConnectionId, ConnectionInfo, ConnectionList, Destination, ManagerCapabilities, - ManagerRequest, ManagerResponse, Map, -}; -use async_trait::async_trait; -use distant_net::{ - router, Auth, AuthClient, Client, IntoSplit, Listener, MpscListener, Request, Response, Server, - ServerCtx, ServerExt, UntypedTransportRead, UntypedTransportWrite, -}; -use log::*; -use std::{collections::HashMap, io, sync::Arc}; -use tokio::{ - sync::{mpsc, Mutex, RwLock}, - task::JoinHandle, -}; - -mod config; -pub use config::*; - -mod connection; -pub use connection::*; - -mod ext; -pub use ext::*; - -mod handler; -pub use handler::*; 
- -mod r#ref; -pub use r#ref::*; - -router!(DistantManagerRouter { - auth_transport: Response => Request, - manager_transport: Request => Response, -}); - -/// Represents a manager of multiple distant server connections -pub struct DistantManager { - /// Receives authentication clients to feed into local data of server - auth_client_rx: Mutex>, - - /// Configuration settings for the server - config: DistantManagerConfig, - - /// Mapping of connection id -> connection - connections: RwLock>, - - /// Handlers for launch requests - launch_handlers: Arc>>, - - /// Handlers for connect requests - connect_handlers: Arc>>, - - /// Primary task of server - task: JoinHandle<()>, -} - -impl DistantManager { - /// Initializes a new instance of [`DistantManagerServer`] using the provided [`UntypedTransport`] - pub fn start( - mut config: DistantManagerConfig, - mut listener: L, - ) -> io::Result - where - L: Listener + 'static, - T: IntoSplit + Send + 'static, - T::Read: UntypedTransportRead + 'static, - T::Write: UntypedTransportWrite + 'static, - { - let (conn_tx, mpsc_listener) = MpscListener::channel(config.connection_buffer_size); - let (auth_client_tx, auth_client_rx) = mpsc::channel(1); - - // Spawn task that uses our input listener to get both auth and manager events, - // forwarding manager events to the internal mpsc listener - let task = tokio::spawn(async move { - while let Ok(transport) = listener.accept().await { - let DistantManagerRouter { - auth_transport, - manager_transport, - .. 
- } = DistantManagerRouter::new(transport); - - let (writer, reader) = auth_transport.into_split(); - let client = match Client::new(writer, reader) { - Ok(client) => client, - Err(x) => { - error!("Creating auth client failed: {}", x); - continue; - } - }; - let auth_client = AuthClient::from(client); - - // Forward auth client for new connection in server - if auth_client_tx.send(auth_client).await.is_err() { - break; - } - - // Forward connected and routed transport to server - if conn_tx.send(manager_transport.into_split()).await.is_err() { - break; - } - } - }); - - let launch_handlers = Arc::new(RwLock::new(config.launch_handlers.drain().collect())); - let weak_launch_handlers = Arc::downgrade(&launch_handlers); - let connect_handlers = Arc::new(RwLock::new(config.connect_handlers.drain().collect())); - let weak_connect_handlers = Arc::downgrade(&connect_handlers); - let server_ref = Self { - auth_client_rx: Mutex::new(auth_client_rx), - config, - launch_handlers, - connect_handlers, - connections: RwLock::new(HashMap::new()), - task, - } - .start(mpsc_listener)?; - - Ok(DistantManagerRef { - launch_handlers: weak_launch_handlers, - connect_handlers: weak_connect_handlers, - inner: server_ref, - }) - } - - /// Launches a new server at the specified `destination` using the given `options` information - /// and authentication client (if needed) to retrieve additional information needed to - /// enter the destination prior to starting the server, returning the destination of the - /// launched server - async fn launch( - &self, - destination: Destination, - options: Map, - auth: Option<&mut AuthClient>, - ) -> io::Result { - let auth = auth.ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - "Authentication client not initialized", - ) - })?; - - let scheme = match destination.scheme.as_deref() { - Some(scheme) => { - trace!("Using scheme {}", scheme); - scheme - } - None => { - trace!( - "Using fallback scheme of {}", - 
self.config.launch_fallback_scheme.as_str() - ); - self.config.launch_fallback_scheme.as_str() - } - } - .to_lowercase(); - - let credentials = { - let lock = self.launch_handlers.read().await; - let handler = lock.get(&scheme).ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("No launch handler registered for {}", scheme), - ) - })?; - handler.launch(&destination, &options, auth).await? - }; - - Ok(credentials) - } - - /// Connects to a new server at the specified `destination` using the given `options` information - /// and authentication client (if needed) to retrieve additional information needed to - /// establish the connection to the server - async fn connect( - &self, - destination: Destination, - options: Map, - auth: Option<&mut AuthClient>, - ) -> io::Result { - let auth = auth.ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - "Authentication client not initialized", - ) - })?; - - let scheme = match destination.scheme.as_deref() { - Some(scheme) => { - trace!("Using scheme {}", scheme); - scheme - } - None => { - trace!( - "Using fallback scheme of {}", - self.config.connect_fallback_scheme.as_str() - ); - self.config.connect_fallback_scheme.as_str() - } - } - .to_lowercase(); - - let (writer, reader) = { - let lock = self.connect_handlers.read().await; - let handler = lock.get(&scheme).ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("No connect handler registered for {}", scheme), - ) - })?; - handler.connect(&destination, &options, auth).await? 
- }; - - let connection = DistantManagerConnection::new(destination, options, writer, reader); - let id = connection.id; - self.connections.write().await.insert(id, connection); - Ok(id) - } - - /// Retrieves the list of supported capabilities for this manager - async fn capabilities(&self) -> io::Result { - Ok(ManagerCapabilities::all()) - } - - /// Retrieves information about the connection to the server with the specified `id` - async fn info(&self, id: ConnectionId) -> io::Result { - match self.connections.read().await.get(&id) { - Some(connection) => Ok(ConnectionInfo { - id: connection.id, - destination: connection.destination.clone(), - options: connection.options.clone(), - }), - None => Err(io::Error::new( - io::ErrorKind::NotConnected, - "No connection found", - )), - } - } - - /// Retrieves a list of connections to servers - async fn list(&self) -> io::Result { - Ok(ConnectionList( - self.connections - .read() - .await - .values() - .map(|conn| (conn.id, conn.destination.clone())) - .collect(), - )) - } - - /// Kills the connection to the server with the specified `id` - async fn kill(&self, id: ConnectionId) -> io::Result<()> { - match self.connections.write().await.remove(&id) { - Some(_) => Ok(()), - None => Err(io::Error::new( - io::ErrorKind::NotConnected, - "No connection found", - )), - } - } -} - -#[derive(Default)] -pub struct DistantManagerServerConnection { - /// Authentication client that manager can use when establishing a new connection - /// and needing to get authentication details from the client to move forward - auth_client: Option>, - - /// Holds on to open channels feeding data back from a server to some connected client, - /// enabling us to cancel the tasks on demand - channels: RwLock>, -} - -#[async_trait] -impl Server for DistantManager { - type Request = ManagerRequest; - type Response = ManagerResponse; - type LocalData = DistantManagerServerConnection; - - async fn on_accept(&self, local_data: &mut Self::LocalData) { - 
local_data.auth_client = self - .auth_client_rx - .lock() - .await - .recv() - .await - .map(Mutex::new); - - // Enable jit handshake - if let Some(auth_client) = local_data.auth_client.as_ref() { - auth_client.lock().await.set_jit_handshake(true); - } - } - - async fn on_request(&self, ctx: ServerCtx) { - let ServerCtx { - connection_id, - request, - reply, - local_data, - } = ctx; - - let response = match request.payload { - ManagerRequest::Capabilities {} => match self.capabilities().await { - Ok(supported) => ManagerResponse::Capabilities { supported }, - Err(x) => ManagerResponse::Error(x.into()), - }, - ManagerRequest::Launch { - destination, - options, - } => { - let mut auth = match local_data.auth_client.as_ref() { - Some(client) => Some(client.lock().await), - None => None, - }; - - match self - .launch(*destination, options, auth.as_deref_mut()) - .await - { - Ok(destination) => ManagerResponse::Launched { destination }, - Err(x) => ManagerResponse::Error(x.into()), - } - } - ManagerRequest::Connect { - destination, - options, - } => { - let mut auth = match local_data.auth_client.as_ref() { - Some(client) => Some(client.lock().await), - None => None, - }; - - match self - .connect(*destination, options, auth.as_deref_mut()) - .await - { - Ok(id) => ManagerResponse::Connected { id }, - Err(x) => ManagerResponse::Error(x.into()), - } - } - ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) { - Some(connection) => match connection.open_channel(reply.clone()).await { - Ok(channel) => { - let id = channel.id(); - local_data.channels.write().await.insert(id, channel); - ManagerResponse::ChannelOpened { id } - } - Err(x) => ManagerResponse::Error(x.into()), - }, - None => ManagerResponse::Error( - io::Error::new(io::ErrorKind::NotConnected, "Connection does not exist").into(), - ), - }, - ManagerRequest::Channel { id, request } => { - match local_data.channels.read().await.get(&id) { - // TODO: For now, we are NOT sending back a 
response to acknowledge - // a successful channel send. We could do this in order for - // the client to listen for a complete send, but is it worth it? - Some(channel) => match channel.send(request).await { - Ok(_) => return, - Err(x) => ManagerResponse::Error(x.into()), - }, - None => ManagerResponse::Error( - io::Error::new( - io::ErrorKind::NotConnected, - "Channel is not open or does not exist", - ) - .into(), - ), - } - } - ManagerRequest::CloseChannel { id } => { - match local_data.channels.write().await.remove(&id) { - Some(channel) => match channel.close().await { - Ok(_) => ManagerResponse::ChannelClosed { id }, - Err(x) => ManagerResponse::Error(x.into()), - }, - None => ManagerResponse::Error( - io::Error::new( - io::ErrorKind::NotConnected, - "Channel is not open or does not exist", - ) - .into(), - ), - } - } - ManagerRequest::Info { id } => match self.info(id).await { - Ok(info) => ManagerResponse::Info(info), - Err(x) => ManagerResponse::Error(x.into()), - }, - ManagerRequest::List => match self.list().await { - Ok(list) => ManagerResponse::List(list), - Err(x) => ManagerResponse::Error(x.into()), - }, - ManagerRequest::Kill { id } => match self.kill(id).await { - Ok(()) => ManagerResponse::Killed, - Err(x) => ManagerResponse::Error(x.into()), - }, - ManagerRequest::Shutdown => { - if let Err(x) = reply.send(ManagerResponse::Shutdown).await { - error!("[Conn {}] {}", connection_id, x); - } - - // Clear out handler state in order to trigger drops - self.launch_handlers.write().await.clear(); - self.connect_handlers.write().await.clear(); - - // Shutdown the primary server task - self.task.abort(); - - // TODO: Perform a graceful shutdown instead of this? 
- // Review https://tokio.rs/tokio/topics/shutdown - std::process::exit(0); - } - }; - - if let Err(x) = reply.send(response).await { - error!("[Conn {}] {}", connection_id, x); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use distant_net::{ - AuthClient, FramedTransport, HeapAuthServer, InmemoryTransport, IntoSplit, MappedListener, - OneshotListener, PlainCodec, ServerExt, ServerRef, - }; - - /// Create a new server, bypassing the start loop - fn setup() -> DistantManager { - let (_, rx) = mpsc::channel(1); - DistantManager { - auth_client_rx: Mutex::new(rx), - config: Default::default(), - connections: RwLock::new(HashMap::new()), - launch_handlers: Arc::new(RwLock::new(HashMap::new())), - connect_handlers: Arc::new(RwLock::new(HashMap::new())), - task: tokio::spawn(async move {}), - } - } - - /// Creates a connected [`AuthClient`] with a launched auth server that blindly responds - fn auth_client_server() -> (AuthClient, Box) { - let (t1, t2) = FramedTransport::pair(1); - let client = AuthClient::from(Client::from_framed_transport(t1).unwrap()); - - // Create a server that does nothing, but will support - let server = HeapAuthServer { - on_challenge: Box::new(|_, _| Vec::new()), - on_verify: Box::new(|_, _| false), - on_info: Box::new(|_| ()), - on_error: Box::new(|_, _| ()), - } - .start(MappedListener::new(OneshotListener::from_value(t2), |t| { - t.into_split() - })) - .unwrap(); - - (client, server) - } - - fn dummy_distant_writer_reader() -> (BoxedDistantWriter, BoxedDistantReader) { - setup_distant_writer_reader().0 - } - - /// Creates a writer & reader with a connected transport - fn setup_distant_writer_reader() -> ( - (BoxedDistantWriter, BoxedDistantReader), - FramedTransport, - ) { - let (t1, t2) = FramedTransport::pair(1); - let (writer, reader) = t1.into_split(); - ((Box::new(writer), Box::new(reader)), t2) - } - - #[tokio::test] - async fn launch_should_fail_if_destination_scheme_is_unsupported() { - let server = setup(); - - let 
destination = "scheme://host".parse::().unwrap(); - let options = "".parse::().unwrap(); - let (mut auth, _auth_server) = auth_client_server(); - let err = server - .launch(destination, options, Some(&mut auth)) - .await - .unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err); - } - - #[tokio::test] - async fn launch_should_fail_if_handler_tied_to_scheme_fails() { - let server = setup(); - - let handler: Box = Box::new(|_: &_, _: &_, _: &mut _| async { - Err(io::Error::new(io::ErrorKind::Other, "test failure")) - }); - - server - .launch_handlers - .write() - .await - .insert("scheme".to_string(), handler); - - let destination = "scheme://host".parse::().unwrap(); - let options = "".parse::().unwrap(); - let (mut auth, _auth_server) = auth_client_server(); - let err = server - .launch(destination, options, Some(&mut auth)) - .await - .unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::Other); - assert_eq!(err.to_string(), "test failure"); - } - - #[tokio::test] - async fn launch_should_return_new_destination_on_success() { - let server = setup(); - - let handler: Box = { - Box::new(|_: &_, _: &_, _: &mut _| async { - Ok("scheme2://host2".parse::().unwrap()) - }) - }; - - server - .launch_handlers - .write() - .await - .insert("scheme".to_string(), handler); - - let destination = "scheme://host".parse::().unwrap(); - let options = "key=value".parse::().unwrap(); - let (mut auth, _auth_server) = auth_client_server(); - let destination = server - .launch(destination, options, Some(&mut auth)) - .await - .unwrap(); - - assert_eq!( - destination, - "scheme2://host2".parse::().unwrap() - ); - } - - #[tokio::test] - async fn connect_should_fail_if_destination_scheme_is_unsupported() { - let server = setup(); - - let destination = "scheme://host".parse::().unwrap(); - let options = "".parse::().unwrap(); - let (mut auth, _auth_server) = auth_client_server(); - let err = server - .connect(destination, options, Some(&mut auth)) - .await - 
.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err); - } - - #[tokio::test] - async fn connect_should_fail_if_handler_tied_to_scheme_fails() { - let server = setup(); - - let handler: Box = Box::new(|_: &_, _: &_, _: &mut _| async { - Err(io::Error::new(io::ErrorKind::Other, "test failure")) - }); - - server - .connect_handlers - .write() - .await - .insert("scheme".to_string(), handler); - - let destination = "scheme://host".parse::().unwrap(); - let options = "".parse::().unwrap(); - let (mut auth, _auth_server) = auth_client_server(); - let err = server - .connect(destination, options, Some(&mut auth)) - .await - .unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::Other); - assert_eq!(err.to_string(), "test failure"); - } - - #[tokio::test] - async fn connect_should_return_id_of_new_connection_on_success() { - let server = setup(); - - let handler: Box = - Box::new(|_: &_, _: &_, _: &mut _| async { Ok(dummy_distant_writer_reader()) }); - - server - .connect_handlers - .write() - .await - .insert("scheme".to_string(), handler); - - let destination = "scheme://host".parse::().unwrap(); - let options = "key=value".parse::().unwrap(); - let (mut auth, _auth_server) = auth_client_server(); - let id = server - .connect(destination, options, Some(&mut auth)) - .await - .unwrap(); - - let lock = server.connections.read().await; - let connection = lock.get(&id).unwrap(); - assert_eq!(connection.id, id); - assert_eq!(connection.destination, "scheme://host"); - assert_eq!(connection.options, "key=value".parse().unwrap()); - } - - #[tokio::test] - async fn info_should_fail_if_no_connection_found_for_specified_id() { - let server = setup(); - - let err = server.info(999).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err); - } - - #[tokio::test] - async fn info_should_return_information_about_established_connection() { - let server = setup(); - - let (writer, reader) = dummy_distant_writer_reader(); - let 
connection = DistantManagerConnection::new( - "scheme://host".parse().unwrap(), - "key=value".parse().unwrap(), - writer, - reader, - ); - let id = connection.id; - server.connections.write().await.insert(id, connection); - - let info = server.info(id).await.unwrap(); - assert_eq!( - info, - ConnectionInfo { - id, - destination: "scheme://host".parse().unwrap(), - options: "key=value".parse().unwrap(), - } - ); - } - - #[tokio::test] - async fn list_should_return_empty_connection_list_if_no_established_connections() { - let server = setup(); - - let list = server.list().await.unwrap(); - assert_eq!(list, ConnectionList(HashMap::new())); - } - - #[tokio::test] - async fn list_should_return_a_list_of_established_connections() { - let server = setup(); - - let (writer, reader) = dummy_distant_writer_reader(); - let connection = DistantManagerConnection::new( - "scheme://host".parse().unwrap(), - "key=value".parse().unwrap(), - writer, - reader, - ); - let id_1 = connection.id; - server.connections.write().await.insert(id_1, connection); - - let (writer, reader) = dummy_distant_writer_reader(); - let connection = DistantManagerConnection::new( - "other://host2".parse().unwrap(), - "key=value".parse().unwrap(), - writer, - reader, - ); - let id_2 = connection.id; - server.connections.write().await.insert(id_2, connection); - - let list = server.list().await.unwrap(); - assert_eq!( - list.get(&id_1).unwrap(), - &"scheme://host".parse::().unwrap() - ); - assert_eq!( - list.get(&id_2).unwrap(), - &"other://host2".parse::().unwrap() - ); - } - - #[tokio::test] - async fn kill_should_fail_if_no_connection_found_for_specified_id() { - let server = setup(); - - let err = server.kill(999).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err); - } - - #[tokio::test] - async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() { - let server = setup(); - - let (writer, reader) = dummy_distant_writer_reader(); - let 
connection = DistantManagerConnection::new( - "scheme://host".parse().unwrap(), - "key=value".parse().unwrap(), - writer, - reader, - ); - let id = connection.id; - server.connections.write().await.insert(id, connection); - - server.kill(id).await.unwrap(); - - let lock = server.connections.read().await; - assert!(!lock.contains_key(&id), "Connection still exists"); - } -} diff --git a/distant-core/src/manager/server/connection.rs b/distant-core/src/manager/server/connection.rs deleted file mode 100644 index 843597b..0000000 --- a/distant-core/src/manager/server/connection.rs +++ /dev/null @@ -1,202 +0,0 @@ -use crate::{ - data::Map, - manager::{ - data::{ChannelId, ConnectionId, Destination}, - BoxedDistantReader, BoxedDistantWriter, - }, - DistantMsg, DistantRequestData, DistantResponseData, ManagerResponse, -}; -use distant_net::{Request, Response, ServerReply}; -use log::*; -use std::{collections::HashMap, io}; -use tokio::{sync::mpsc, task::JoinHandle}; - -/// Represents a connection a distant manager has with some distant-compatible server -pub struct DistantManagerConnection { - pub id: ConnectionId, - pub destination: Destination, - pub options: Map, - tx: mpsc::Sender, - reader_task: JoinHandle<()>, - writer_task: JoinHandle<()>, -} - -#[derive(Clone)] -pub struct DistantManagerChannel { - channel_id: ChannelId, - tx: mpsc::Sender, -} - -impl DistantManagerChannel { - pub fn id(&self) -> ChannelId { - self.channel_id - } - - pub async fn send(&self, request: Request>) -> io::Result<()> { - let channel_id = self.channel_id; - self.tx - .send(StateMachine::Write { - id: channel_id, - request, - }) - .await - .map_err(|x| { - io::Error::new( - io::ErrorKind::BrokenPipe, - format!("channel {} send failed: {}", channel_id, x), - ) - }) - } - - pub async fn close(&self) -> io::Result<()> { - let channel_id = self.channel_id; - self.tx - .send(StateMachine::Unregister { id: channel_id }) - .await - .map_err(|x| { - io::Error::new( - io::ErrorKind::BrokenPipe, - 
format!("channel {} close failed: {}", channel_id, x), - ) - }) - } -} - -enum StateMachine { - Register { - id: ChannelId, - reply: ServerReply, - }, - - Unregister { - id: ChannelId, - }, - - Read { - response: Response>, - }, - - Write { - id: ChannelId, - request: Request>, - }, -} - -impl DistantManagerConnection { - pub fn new( - destination: Destination, - options: Map, - mut writer: BoxedDistantWriter, - mut reader: BoxedDistantReader, - ) -> Self { - let connection_id = rand::random(); - let (tx, mut rx) = mpsc::channel(1); - let reader_task = { - let tx = tx.clone(); - tokio::spawn(async move { - loop { - match reader.read().await { - Ok(Some(response)) => { - if tx.send(StateMachine::Read { response }).await.is_err() { - break; - } - } - Ok(None) => break, - Err(x) => { - error!("[Conn {}] {}", connection_id, x); - continue; - } - } - } - }) - }; - let writer_task = tokio::spawn(async move { - let mut registered = HashMap::new(); - while let Some(state_machine) = rx.recv().await { - match state_machine { - StateMachine::Register { id, reply } => { - registered.insert(id, reply); - } - StateMachine::Unregister { id } => { - registered.remove(&id); - } - StateMachine::Read { mut response } => { - // Split {channel id}_{request id} back into pieces and - // update the origin id to match the request id only - let channel_id = match response.origin_id.split_once('_') { - Some((cid_str, oid_str)) => { - if let Ok(cid) = cid_str.parse::() { - response.origin_id = oid_str.to_string(); - cid - } else { - continue; - } - } - None => continue, - }; - - if let Some(reply) = registered.get(&channel_id) { - let response = ManagerResponse::Channel { - id: channel_id, - response, - }; - if let Err(x) = reply.send(response).await { - error!("[Conn {}] {}", connection_id, x); - } - } - } - StateMachine::Write { id, request } => { - // Combine channel id with request id so we can properly forward - // the response containing this in the origin id - let request = Request { 
- id: format!("{}_{}", id, request.id), - payload: request.payload, - }; - if let Err(x) = writer.write(request).await { - error!("[Conn {}] {}", connection_id, x); - } - } - } - } - }); - - Self { - id: connection_id, - destination, - options, - tx, - reader_task, - writer_task, - } - } - - pub async fn open_channel( - &self, - reply: ServerReply, - ) -> io::Result { - let channel_id = rand::random(); - self.tx - .send(StateMachine::Register { - id: channel_id, - reply, - }) - .await - .map_err(|x| { - io::Error::new( - io::ErrorKind::BrokenPipe, - format!("open_channel failed: {}", x), - ) - })?; - Ok(DistantManagerChannel { - channel_id, - tx: self.tx.clone(), - }) - } -} - -impl Drop for DistantManagerConnection { - fn drop(&mut self) { - self.reader_task.abort(); - self.writer_task.abort(); - } -} diff --git a/distant-core/src/manager/server/ext.rs b/distant-core/src/manager/server/ext.rs deleted file mode 100644 index d23a3d2..0000000 --- a/distant-core/src/manager/server/ext.rs +++ /dev/null @@ -1,14 +0,0 @@ -mod tcp; -pub use tcp::*; - -#[cfg(unix)] -mod unix; - -#[cfg(unix)] -pub use unix::*; - -#[cfg(windows)] -mod windows; - -#[cfg(windows)] -pub use windows::*; diff --git a/distant-core/src/manager/server/ext/tcp.rs b/distant-core/src/manager/server/ext/tcp.rs deleted file mode 100644 index f9a2f6d..0000000 --- a/distant-core/src/manager/server/ext/tcp.rs +++ /dev/null @@ -1,30 +0,0 @@ -use crate::{DistantManager, DistantManagerConfig}; -use distant_net::{ - Codec, FramedTransport, IntoSplit, MappedListener, PortRange, TcpListener, TcpServerRef, -}; -use std::{io, net::IpAddr}; - -impl DistantManager { - /// Start a new server by binding to the given IP address and one of the ports in the - /// specified range, mapping all connections to use the given codec - pub async fn start_tcp( - config: DistantManagerConfig, - addr: IpAddr, - port: P, - codec: C, - ) -> io::Result - where - P: Into + Send, - C: Codec + Send + Sync + 'static, - { - let listener = 
TcpListener::bind(addr, port).await?; - let port = listener.port(); - - let listener = MappedListener::new(listener, move |transport| { - let transport = FramedTransport::new(transport, codec.clone()); - transport.into_split() - }); - let inner = DistantManager::start(config, listener)?; - Ok(TcpServerRef::new(addr, port, Box::new(inner))) - } -} diff --git a/distant-core/src/manager/server/ext/unix.rs b/distant-core/src/manager/server/ext/unix.rs deleted file mode 100644 index fec9743..0000000 --- a/distant-core/src/manager/server/ext/unix.rs +++ /dev/null @@ -1,50 +0,0 @@ -use crate::{DistantManager, DistantManagerConfig}; -use distant_net::{ - Codec, FramedTransport, IntoSplit, MappedListener, UnixSocketListener, UnixSocketServerRef, -}; -use std::{io, path::Path}; - -impl DistantManager { - /// Start a new server using the specified path as a unix socket using default unix socket file - /// permissions - pub async fn start_unix_socket( - config: DistantManagerConfig, - path: P, - codec: C, - ) -> io::Result - where - P: AsRef + Send, - C: Codec + Send + Sync + 'static, - { - Self::start_unix_socket_with_permissions( - config, - path, - codec, - UnixSocketListener::default_unix_socket_file_permissions(), - ) - .await - } - - /// Start a new server using the specified path as a unix socket and `mode` as the unix socket - /// file permissions - pub async fn start_unix_socket_with_permissions( - config: DistantManagerConfig, - path: P, - codec: C, - mode: u32, - ) -> io::Result - where - P: AsRef + Send, - C: Codec + Send + Sync + 'static, - { - let listener = UnixSocketListener::bind_with_permissions(path, mode).await?; - let path = listener.path().to_path_buf(); - - let listener = MappedListener::new(listener, move |transport| { - let transport = FramedTransport::new(transport, codec.clone()); - transport.into_split() - }); - let inner = DistantManager::start(config, listener)?; - Ok(UnixSocketServerRef::new(path, Box::new(inner))) - } -} diff --git 
a/distant-core/src/manager/server/ext/windows.rs b/distant-core/src/manager/server/ext/windows.rs deleted file mode 100644 index 537bbfe..0000000 --- a/distant-core/src/manager/server/ext/windows.rs +++ /dev/null @@ -1,48 +0,0 @@ -use crate::{DistantManager, DistantManagerConfig}; -use distant_net::{ - Codec, FramedTransport, IntoSplit, MappedListener, WindowsPipeListener, WindowsPipeServerRef, -}; -use std::{ - ffi::{OsStr, OsString}, - io, -}; - -impl DistantManager { - /// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec - pub async fn start_local_named_pipe( - config: DistantManagerConfig, - name: N, - codec: C, - ) -> io::Result - where - Self: Sized, - N: AsRef + Send, - C: Codec + Send + Sync + 'static, - { - let mut addr = OsString::from(r"\\.\pipe\"); - addr.push(name.as_ref()); - Self::start_named_pipe(config, addr, codec).await - } - - /// Start a new server at the specified pipe address using the given codec - pub async fn start_named_pipe( - config: DistantManagerConfig, - addr: A, - codec: C, - ) -> io::Result - where - A: AsRef + Send, - C: Codec + Send + Sync + 'static, - { - let a = addr.as_ref(); - let listener = WindowsPipeListener::bind(a)?; - let addr = listener.addr().to_os_string(); - - let listener = MappedListener::new(listener, move |transport| { - let transport = FramedTransport::new(transport, codec.clone()); - transport.into_split() - }); - let inner = DistantManager::start(config, listener)?; - Ok(WindowsPipeServerRef::new(addr, Box::new(inner))) - } -} diff --git a/distant-core/src/manager/server/handler.rs b/distant-core/src/manager/server/handler.rs deleted file mode 100644 index f0fa6be..0000000 --- a/distant-core/src/manager/server/handler.rs +++ /dev/null @@ -1,68 +0,0 @@ -use crate::{ - data::Map, manager::data::Destination, DistantMsg, DistantRequestData, DistantResponseData, -}; -use async_trait::async_trait; -use distant_net::{AuthClient, Request, Response, TypedAsyncRead, 
TypedAsyncWrite}; -use std::{future::Future, io}; - -pub type BoxedDistantWriter = - Box>> + Send>; -pub type BoxedDistantReader = - Box>> + Send>; -pub type BoxedDistantWriterReader = (BoxedDistantWriter, BoxedDistantReader); -pub type BoxedLaunchHandler = Box; -pub type BoxedConnectHandler = Box; - -/// Used to launch a server at the specified destination, returning some result as a vec of bytes -#[async_trait] -pub trait LaunchHandler: Send + Sync { - async fn launch( - &self, - destination: &Destination, - options: &Map, - auth_client: &mut AuthClient, - ) -> io::Result; -} - -#[async_trait] -impl LaunchHandler for F -where - F: for<'a> Fn(&'a Destination, &'a Map, &'a mut AuthClient) -> R + Send + Sync + 'static, - R: Future> + Send + 'static, -{ - async fn launch( - &self, - destination: &Destination, - options: &Map, - auth_client: &mut AuthClient, - ) -> io::Result { - self(destination, options, auth_client).await - } -} - -/// Used to connect to a destination, returning a connected reader and writer pair -#[async_trait] -pub trait ConnectHandler: Send + Sync { - async fn connect( - &self, - destination: &Destination, - options: &Map, - auth_client: &mut AuthClient, - ) -> io::Result; -} - -#[async_trait] -impl ConnectHandler for F -where - F: for<'a> Fn(&'a Destination, &'a Map, &'a mut AuthClient) -> R + Send + Sync + 'static, - R: Future> + Send + 'static, -{ - async fn connect( - &self, - destination: &Destination, - options: &Map, - auth_client: &mut AuthClient, - ) -> io::Result { - self(destination, options, auth_client).await - } -} diff --git a/distant-core/src/manager/server/ref.rs b/distant-core/src/manager/server/ref.rs deleted file mode 100644 index 360a00f..0000000 --- a/distant-core/src/manager/server/ref.rs +++ /dev/null @@ -1,73 +0,0 @@ -use super::{BoxedConnectHandler, BoxedLaunchHandler, ConnectHandler, LaunchHandler}; -use distant_net::{ServerRef, ServerState}; -use std::{collections::HashMap, io, sync::Weak}; -use tokio::sync::RwLock; - 
-/// Reference to a distant manager's server instance -pub struct DistantManagerRef { - /// Mapping of "scheme" -> handler - pub(crate) launch_handlers: Weak>>, - - /// Mapping of "scheme" -> handler - pub(crate) connect_handlers: Weak>>, - - pub(crate) inner: Box, -} - -impl DistantManagerRef { - /// Registers a new [`LaunchHandler`] for the specified scheme (e.g. "distant" or "ssh") - pub async fn register_launch_handler( - &self, - scheme: impl Into, - handler: impl LaunchHandler + 'static, - ) -> io::Result<()> { - let handlers = Weak::upgrade(&self.launch_handlers).ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - "Handler reference is no longer available", - ) - })?; - - handlers - .write() - .await - .insert(scheme.into(), Box::new(handler)); - - Ok(()) - } - - /// Registers a new [`ConnectHandler`] for the specified scheme (e.g. "distant" or "ssh") - pub async fn register_connect_handler( - &self, - scheme: impl Into, - handler: impl ConnectHandler + 'static, - ) -> io::Result<()> { - let handlers = Weak::upgrade(&self.connect_handlers).ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - "Handler reference is no longer available", - ) - })?; - - handlers - .write() - .await - .insert(scheme.into(), Box::new(handler)); - - Ok(()) - } -} - -impl ServerRef for DistantManagerRef { - fn state(&self) -> &ServerState { - self.inner.state() - } - - fn is_finished(&self) -> bool { - self.inner.is_finished() - } - - fn abort(&self) { - self.inner.abort(); - } -} diff --git a/distant-core/tests/manager_tests.rs b/distant-core/tests/manager_tests.rs deleted file mode 100644 index 40527b8..0000000 --- a/distant-core/tests/manager_tests.rs +++ /dev/null @@ -1,96 +0,0 @@ -use distant_core::{ - net::{FramedTransport, InmemoryTransport, IntoSplit, OneshotListener, PlainCodec}, - BoxedDistantReader, BoxedDistantWriter, Destination, DistantApiServer, DistantChannelExt, - DistantManager, DistantManagerClient, DistantManagerClientConfig, DistantManagerConfig, 
Map, -}; -use std::io; - -/// Creates a client transport and server listener for our tests -/// that are connected together -async fn setup() -> ( - FramedTransport, - OneshotListener>, -) { - let (t1, t2) = InmemoryTransport::pair(100); - - let listener = OneshotListener::from_value(FramedTransport::new(t2, PlainCodec)); - let transport = FramedTransport::new(t1, PlainCodec); - (transport, listener) -} - -#[tokio::test] -async fn should_be_able_to_establish_a_single_connection_and_communicate() { - let (transport, listener) = setup().await; - - let config = DistantManagerConfig::default(); - let manager_ref = DistantManager::start(config, listener).expect("Failed to start manager"); - - // NOTE: To pass in a raw function, we HAVE to specify the types of the parameters manually, - // otherwise we get a compilation error about lifetime mismatches - manager_ref - .register_connect_handler("scheme", |_: &_, _: &_, _: &mut _| async { - use distant_core::net::ServerExt; - let (t1, t2) = FramedTransport::pair(100); - - // Spawn a server on one end - let _ = DistantApiServer::local(Default::default()) - .unwrap() - .start(OneshotListener::from_value(t2.into_split()))?; - - // Create a reader/writer pair on the other end - let (writer, reader) = t1.into_split(); - let writer: BoxedDistantWriter = Box::new(writer); - let reader: BoxedDistantReader = Box::new(reader); - Ok((writer, reader)) - }) - .await - .expect("Failed to register handler"); - - let config = DistantManagerClientConfig::with_empty_prompts(); - let mut client = - DistantManagerClient::new(config, transport).expect("Failed to connect to manager"); - - // Test establishing a connection to some remote server - let id = client - .connect( - "scheme://host".parse::().unwrap(), - "key=value".parse::().unwrap(), - ) - .await - .expect("Failed to connect to a remote server"); - - // Test retrieving list of connections - let list = client - .list() - .await - .expect("Failed to get list of connections"); - 
assert_eq!(list.len(), 1); - assert_eq!(list.get(&id).unwrap().to_string(), "scheme://host"); - - // Test retrieving information - let info = client - .info(id) - .await - .expect("Failed to get info about connection"); - assert_eq!(info.id, id); - assert_eq!(info.destination.to_string(), "scheme://host"); - assert_eq!(info.options, "key=value".parse::().unwrap()); - - // Create a new channel and request some data - let mut channel = client - .open_channel(id) - .await - .expect("Failed to open channel"); - let _ = channel - .system_info() - .await - .expect("Failed to get system information"); - - // Test killing a connection - client.kill(id).await.expect("Failed to kill connection"); - - // Test getting an error to ensure that serialization of that data works, - // which we do by trying to access a connection that no longer exists - let err = client.info(id).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::NotConnected); -} diff --git a/distant-core/tests/stress/distant/large_file.rs b/distant-core/tests/stress/distant/large_file.rs index 9b66fb4..feb69bc 100644 --- a/distant-core/tests/stress/distant/large_file.rs +++ b/distant-core/tests/stress/distant/large_file.rs @@ -2,6 +2,7 @@ use crate::stress::fixtures::*; use assert_fs::prelude::*; use distant_core::DistantChannelExt; use rstest::*; +use test_log::test; // 64KB is maximum TCP packet size const MAX_TCP_PACKET_BYTES: usize = 65535; @@ -10,7 +11,7 @@ const MAX_TCP_PACKET_BYTES: usize = 65535; const LARGE_FILE_LEN: usize = MAX_TCP_PACKET_BYTES * 10; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_handle_large_files(#[future] ctx: DistantClientCtx) { let ctx = ctx.await; let mut channel = ctx.client.clone_channel(); diff --git a/distant-core/tests/stress/distant/watch.rs b/distant-core/tests/stress/distant/watch.rs index a107204..1a54bce 100644 --- a/distant-core/tests/stress/distant/watch.rs +++ b/distant-core/tests/stress/distant/watch.rs @@ -2,11 +2,12 @@ use 
crate::stress::fixtures::*; use assert_fs::prelude::*; use distant_core::{data::ChangeKindSet, DistantChannelExt}; use rstest::*; +use test_log::test; const MAX_FILES: usize = 500; #[rstest] -#[tokio::test] +#[test(tokio::test)] #[ignore] async fn should_handle_large_volume_of_file_watching(#[future] ctx: DistantClientCtx) { let ctx = ctx.await; diff --git a/distant-core/tests/stress/fixtures.rs b/distant-core/tests/stress/fixtures.rs index 33d2657..ae06018 100644 --- a/distant-core/tests/stress/fixtures.rs +++ b/distant-core/tests/stress/fixtures.rs @@ -1,14 +1,13 @@ -use crate::stress::utils; -use distant_core::{DistantApiServer, DistantClient, LocalDistantApi}; -use distant_net::{ - PortRange, SecretKey, SecretKey32, TcpClientExt, TcpServerExt, XChaCha20Poly1305Codec, -}; +use distant_core::net::client::{Client, TcpConnector}; +use distant_core::net::common::authentication::{DummyAuthHandler, Verifier}; +use distant_core::net::common::PortRange; +use distant_core::net::server::Server; +use distant_core::{DistantApiServerHandler, DistantClient, LocalDistantApi}; use rstest::*; +use std::net::SocketAddr; use std::time::Duration; use tokio::sync::mpsc; -const LOG_PATH: &str = "/tmp/test.distant.server.log"; - pub struct DistantClientCtx { pub client: DistantClient, _done_tx: mpsc::Sender<()>, @@ -18,42 +17,43 @@ impl DistantClientCtx { pub async fn initialize() -> Self { let ip_addr = "127.0.0.1".parse().unwrap(); let (done_tx, mut done_rx) = mpsc::channel::<()>(1); - let (started_tx, mut started_rx) = mpsc::channel::<(u16, SecretKey32)>(1); + let (started_tx, mut started_rx) = mpsc::channel::(1); tokio::spawn(async move { - let logger = utils::init_logging(LOG_PATH); - let key = SecretKey::default(); - let codec = XChaCha20Poly1305Codec::from(key.clone()); - - if let Ok(api) = LocalDistantApi::initialize(Default::default()) { + if let Ok(api) = LocalDistantApi::initialize() { let port: PortRange = "0".parse().unwrap(); let port = { - let server_ref = 
DistantApiServer::new(api) - .start(ip_addr, port, codec) + let handler = DistantApiServerHandler::new(api); + let server_ref = Server::new() + .handler(handler) + .verifier(Verifier::none()) + .into_tcp_builder() + .start(ip_addr, port) .await .unwrap(); server_ref.port() }; - started_tx.send((port, key)).await.unwrap(); + started_tx.send(port).await.unwrap(); let _ = done_rx.recv().await; } - - logger.flush(); - logger.shutdown(); }); // Extract our server startup data if we succeeded - let (port, key) = started_rx.recv().await.unwrap(); + let port = started_rx.recv().await.unwrap(); // Now initialize our client - let client = DistantClient::connect_timeout( - format!("{}:{}", ip_addr, port).parse().unwrap(), - XChaCha20Poly1305Codec::from(key), - Duration::from_secs(1), - ) - .await - .unwrap(); + let client: DistantClient = Client::build() + .auth_handler(DummyAuthHandler) + .timeout(Duration::from_secs(1)) + .connector(TcpConnector::new( + format!("{}:{}", ip_addr, port) + .parse::() + .unwrap(), + )) + .connect() + .await + .unwrap(); DistantClientCtx { client, diff --git a/distant-core/tests/stress/mod.rs b/distant-core/tests/stress/mod.rs index 43b3708..fa0332a 100644 --- a/distant-core/tests/stress/mod.rs +++ b/distant-core/tests/stress/mod.rs @@ -1,3 +1,2 @@ mod distant; mod fixtures; -mod utils; diff --git a/distant-core/tests/stress/utils.rs b/distant-core/tests/stress/utils.rs deleted file mode 100644 index abdaa68..0000000 --- a/distant-core/tests/stress/utils.rs +++ /dev/null @@ -1,23 +0,0 @@ -use std::path::PathBuf; - -/// Initializes logging (should only call once) -pub fn init_logging(path: impl Into) -> flexi_logger::LoggerHandle { - use flexi_logger::{FileSpec, LevelFilter, LogSpecification, Logger}; - let modules = &["distant", "distant_core", "distant_ssh2"]; - - // Disable logging for everything but our binary, which is based on verbosity - let mut builder = LogSpecification::builder(); - builder.default(LevelFilter::Off); - - // For each 
module, configure logging - for module in modules { - builder.module(module, LevelFilter::Trace); - } - - // Create our logger, but don't initialize yet - let logger = Logger::with(builder.build()) - .format_for_files(flexi_logger::opt_format) - .log_to_file(FileSpec::try_from(path).expect("Failed to create log file spec")); - - logger.start().expect("Failed to initialize logger") -} diff --git a/distant-net/Cargo.toml b/distant-net/Cargo.toml index 7672b33..f696de9 100644 --- a/distant-net/Cargo.toml +++ b/distant-net/Cargo.toml @@ -3,7 +3,7 @@ name = "distant-net" description = "Network library for distant, providing implementations to support client/server architecture" categories = ["network-programming"] keywords = ["api", "async"] -version = "0.19.0" +version = "0.20.0" authors = ["Chip Senkbeil "] edition = "2021" homepage = "https://github.com/chipsenkbeil/distant" @@ -16,7 +16,8 @@ async-trait = "0.1.57" bytes = "1.2.1" chacha20poly1305 = "0.10.0" derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] } -futures = "0.3.21" +dyn-clone = "1.0.9" +flate2 = "1.0.24" hex = "0.4.3" hkdf = "0.12.3" log = "0.4.17" @@ -27,11 +28,13 @@ rmp-serde = "1.1.0" sha2 = "0.10.2" serde = { version = "1.0.142", features = ["derive"] } serde_bytes = "0.11.7" +strum = { version = "0.24.1", features = ["derive"] } tokio = { version = "1.20.1", features = ["full"] } -tokio-util = { version = "0.7.3", features = ["codec"] } # Optional dependencies based on features schemars = { version = "0.8.10", optional = true } [dev-dependencies] +env_logger = "0.9.1" tempfile = "3.3.0" +test-log = "0.2.11" diff --git a/distant-net/src/auth.rs b/distant-net/src/auth.rs deleted file mode 100644 index 1f7e103..0000000 --- a/distant-net/src/auth.rs +++ /dev/null @@ -1,122 +0,0 @@ -use derive_more::Display; -use serde::{Deserialize, Serialize}; -use 
std::collections::HashMap; - -mod client; -pub use client::*; - -mod handshake; -pub use handshake::*; - -mod server; -pub use server::*; - -/// Represents authentication messages that can be sent over the wire -/// -/// NOTE: Must use serde's content attribute with the tag attribute. Just the tag attribute will -/// cause deserialization to fail -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", tag = "type", content = "data")] -pub enum Auth { - /// Represents a request to perform an authentication handshake, - /// providing the public key and salt from one side in order to - /// derive the shared key - #[serde(rename = "auth_handshake")] - Handshake { - /// Bytes of the public key - #[serde(with = "serde_bytes")] - public_key: PublicKeyBytes, - - /// Randomly generated salt - #[serde(with = "serde_bytes")] - salt: Salt, - }, - - /// Represents the bytes of an encrypted message - /// - /// Underneath, will be one of either [`AuthRequest`] or [`AuthResponse`] - #[serde(rename = "auth_msg")] - Msg { - #[serde(with = "serde_bytes")] - encrypted_payload: Vec, - }, -} - -/// Represents authentication messages that act as initiators such as providing -/// a challenge, verifying information, presenting information, or highlighting an error -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", tag = "type")] -pub enum AuthRequest { - /// Represents a challenge comprising a series of questions to be presented - Challenge { - questions: Vec, - options: HashMap, - }, - - /// Represents an ask to verify some information - Verify { kind: AuthVerifyKind, text: String }, - - /// Represents some information to be presented - Info { text: String }, - - /// Represents some error that occurred - Error { kind: AuthErrorKind, text: String }, -} - -/// Represents authentication messages that are responses to auth requests such -/// as answers to challenges or verifying information -#[derive(Clone, Debug, Serialize, 
Deserialize)] -#[serde(rename_all = "snake_case", tag = "type")] -pub enum AuthResponse { - /// Represents the answers to a previously-asked challenge - Challenge { answers: Vec }, - - /// Represents the answer to a previously-asked verify - Verify { valid: bool }, -} - -/// Represents the type of verification being requested -#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -#[non_exhaustive] -pub enum AuthVerifyKind { - /// An ask to verify the host such as with SSH - #[display(fmt = "host")] - Host, -} - -/// Represents a single question in a challenge -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct AuthQuestion { - /// The text of the question - pub text: String, - - /// Any options information specific to a particular auth domain - /// such as including a username and instructions for SSH authentication - pub options: HashMap, -} - -impl AuthQuestion { - /// Creates a new question without any options data - pub fn new(text: impl Into) -> Self { - Self { - text: text.into(), - options: HashMap::new(), - } - } -} - -/// Represents the type of error encountered during authentication -#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum AuthErrorKind { - /// When the answer(s) to a challenge do not pass authentication - FailedChallenge, - - /// When verification during authentication fails - /// (e.g. 
a host is not allowed or blocked) - FailedVerification, - - /// When the error is unknown - Unknown, -} diff --git a/distant-net/src/auth/client.rs b/distant-net/src/auth/client.rs deleted file mode 100644 index 206fdb5..0000000 --- a/distant-net/src/auth/client.rs +++ /dev/null @@ -1,817 +0,0 @@ -use crate::{ - utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Client, - Codec, Handshake, XChaCha20Poly1305Codec, -}; -use bytes::BytesMut; -use log::*; -use std::{collections::HashMap, io}; - -pub struct AuthClient { - inner: Client, - codec: Option, - jit_handshake: bool, -} - -impl From> for AuthClient { - fn from(client: Client) -> Self { - Self { - inner: client, - codec: None, - jit_handshake: false, - } - } -} - -impl AuthClient { - /// Sends a request to the server to establish an encrypted connection - pub async fn handshake(&mut self) -> io::Result<()> { - let handshake = Handshake::default(); - - let response = self - .inner - .send(Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - }) - .await?; - - match response.payload { - Auth::Handshake { public_key, salt } => { - let key = handshake.handshake(public_key, salt)?; - self.codec.replace(XChaCha20Poly1305Codec::new(&key)); - Ok(()) - } - Auth::Msg { .. 
} => Err(io::Error::new( - io::ErrorKind::Other, - "Got unexpected encrypted message during handshake", - )), - } - } - - /// Perform a handshake only if jit is enabled and no handshake has succeeded yet - async fn jit_handshake(&mut self) -> io::Result<()> { - if self.will_jit_handshake() && !self.is_ready() { - self.handshake().await - } else { - Ok(()) - } - } - - /// Returns true if client has successfully performed a handshake - /// and is ready to communicate with the server - pub fn is_ready(&self) -> bool { - self.codec.is_some() - } - - /// Returns true if this client will perform a handshake just-in-time (JIT) prior to making a - /// request in the scenario where the client has not already performed a handshake - #[inline] - pub fn will_jit_handshake(&self) -> bool { - self.jit_handshake - } - - /// Sets the jit flag on this client with `true` indicating that this client will perform a - /// handshake just-in-time (JIT) prior to making a request in the scenario where the client has - /// not already performed a handshake - #[inline] - pub fn set_jit_handshake(&mut self, flag: bool) { - self.jit_handshake = flag; - } - - /// Provides a challenge to the server and returns the answers to the questions - /// asked by the client - pub async fn challenge( - &mut self, - questions: Vec, - options: HashMap, - ) -> io::Result> { - trace!( - "AuthClient::challenge(questions = {:?}, options = {:?})", - questions, - options - ); - - // Perform JIT handshake if enabled - self.jit_handshake().await?; - - let payload = AuthRequest::Challenge { questions, options }; - let encrypted_payload = self.serialize_and_encrypt(&payload)?; - let response = self.inner.send(Auth::Msg { encrypted_payload }).await?; - - match response.payload { - Auth::Msg { encrypted_payload } => { - match self.decrypt_and_deserialize(&encrypted_payload)? { - AuthResponse::Challenge { answers } => Ok(answers), - AuthResponse::Verify { .. 
} => Err(io::Error::new( - io::ErrorKind::Other, - "Got unexpected verify response during challenge", - )), - } - } - Auth::Handshake { .. } => Err(io::Error::new( - io::ErrorKind::Other, - "Got unexpected handshake during challenge", - )), - } - } - - /// Provides a verification request to the server and returns whether or not - /// the server approved - pub async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result { - trace!("AuthClient::verify(kind = {:?}, text = {:?})", kind, text); - - // Perform JIT handshake if enabled - self.jit_handshake().await?; - - let payload = AuthRequest::Verify { kind, text }; - let encrypted_payload = self.serialize_and_encrypt(&payload)?; - let response = self.inner.send(Auth::Msg { encrypted_payload }).await?; - - match response.payload { - Auth::Msg { encrypted_payload } => { - match self.decrypt_and_deserialize(&encrypted_payload)? { - AuthResponse::Verify { valid } => Ok(valid), - AuthResponse::Challenge { .. } => Err(io::Error::new( - io::ErrorKind::Other, - "Got unexpected challenge response during verify", - )), - } - } - Auth::Handshake { .. 
} => Err(io::Error::new( - io::ErrorKind::Other, - "Got unexpected handshake during verify", - )), - } - } - - /// Provides information to the server to use as it pleases with no response expected - pub async fn info(&mut self, text: String) -> io::Result<()> { - trace!("AuthClient::info(text = {:?})", text); - - // Perform JIT handshake if enabled - self.jit_handshake().await?; - - let payload = AuthRequest::Info { text }; - let encrypted_payload = self.serialize_and_encrypt(&payload)?; - self.inner.fire(Auth::Msg { encrypted_payload }).await - } - - /// Provides an error to the server to use as it pleases with no response expected - pub async fn error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> { - trace!("AuthClient::error(kind = {:?}, text = {:?})", kind, text); - - // Perform JIT handshake if enabled - self.jit_handshake().await?; - - let payload = AuthRequest::Error { kind, text }; - let encrypted_payload = self.serialize_and_encrypt(&payload)?; - self.inner.fire(Auth::Msg { encrypted_payload }).await - } - - fn serialize_and_encrypt(&mut self, payload: &AuthRequest) -> io::Result> { - let codec = self.codec.as_mut().ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - "Handshake must be performed first (client encrypt message)", - ) - })?; - - let mut encryped_payload = BytesMut::new(); - let payload = utils::serialize_to_vec(payload)?; - codec.encode(&payload, &mut encryped_payload)?; - Ok(encryped_payload.freeze().to_vec()) - } - - fn decrypt_and_deserialize(&mut self, payload: &[u8]) -> io::Result { - let codec = self.codec.as_mut().ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - "Handshake must be performed first (client decrypt message)", - ) - })?; - - let mut payload = BytesMut::from(payload); - match codec.decode(&mut payload)? 
{ - Some(payload) => utils::deserialize_from_slice::(&payload), - None => Err(io::Error::new( - io::ErrorKind::InvalidData, - "Incomplete message received", - )), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{Client, FramedTransport, Request, Response, TypedAsyncRead, TypedAsyncWrite}; - use serde::{de::DeserializeOwned, Serialize}; - - const TIMEOUT_MILLIS: u64 = 100; - - #[tokio::test] - async fn handshake_should_fail_if_get_unexpected_response_from_server() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { client.handshake().await }); - - // Get the request, but send a bad response - let request: Request = server.read().await.unwrap().unwrap(); - match request.payload { - Auth::Handshake { .. } => server - .write(Response::new( - request.id, - Auth::Msg { - encrypted_payload: Vec::new(), - }, - )) - .await - .unwrap(), - _ => panic!("Server received unexpected payload"), - } - - let result = task.await.unwrap(); - assert!(result.is_err(), "Handshake succeeded unexpectedly") - } - - #[tokio::test] - async fn challenge_should_fail_if_handshake_not_finished() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { client.challenge(Vec::new(), HashMap::new()).await }); - - // Wait for a request, failing if we get one as the failure - // should have prevented sending anything, but we should - tokio::select! 
{ - x = TypedAsyncRead::>::read(&mut server) => { - match x { - Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x), - Ok(None) => {}, - Err(x) => panic!("Unexpectedly failed on server side: {}", x), - } - }, - _ = wait_ms(TIMEOUT_MILLIS) => { - panic!("Should have gotten server closure as part of client exit"); - } - } - - // Verify that we got an error with the method - let result = task.await.unwrap(); - assert!(result.is_err(), "Challenge succeeded unexpectedly") - } - - #[tokio::test] - async fn challenge_should_fail_if_receive_wrong_response() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { - client.handshake().await.unwrap(); - client - .challenge( - vec![ - AuthQuestion::new("question1".to_string()), - AuthQuestion { - text: "question2".to_string(), - options: vec![("key2".to_string(), "value2".to_string())] - .into_iter() - .collect(), - }, - ], - vec![("key".to_string(), "value".to_string())] - .into_iter() - .collect(), - ) - .await - }); - - // Wait for a handshake request and set up our encryption codec - let request: Request = server.read().await.unwrap().unwrap(); - let mut codec = match request.payload { - Auth::Handshake { public_key, salt } => { - let handshake = Handshake::default(); - let key = handshake.handshake(public_key, salt).unwrap(); - server - .write(Response::new( - request.id, - Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - }, - )) - .await - .unwrap(); - XChaCha20Poly1305Codec::new(&key) - } - _ => panic!("Server received unexpected payload"), - }; - - // Wait for a challenge request and send back wrong response - let request: Request = server.read().await.unwrap().unwrap(); - match request.payload { - Auth::Msg { encrypted_payload } 
=> { - match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { - AuthRequest::Challenge { .. } => { - server - .write(Response::new( - request.id, - Auth::Msg { - encrypted_payload: serialize_and_encrypt( - &mut codec, - &AuthResponse::Verify { valid: true }, - ) - .unwrap(), - }, - )) - .await - .unwrap(); - } - _ => panic!("Server received wrong request type"), - } - } - _ => panic!("Server received unexpected payload"), - }; - - // Verify that we got an error with the method - let result = task.await.unwrap(); - assert!(result.is_err(), "Challenge succeeded unexpectedly") - } - - #[tokio::test] - async fn challenge_should_return_answers_received_from_server() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { - client.handshake().await.unwrap(); - client - .challenge( - vec![ - AuthQuestion::new("question1".to_string()), - AuthQuestion { - text: "question2".to_string(), - options: vec![("key2".to_string(), "value2".to_string())] - .into_iter() - .collect(), - }, - ], - vec![("key".to_string(), "value".to_string())] - .into_iter() - .collect(), - ) - .await - }); - - // Wait for a handshake request and set up our encryption codec - let request: Request = server.read().await.unwrap().unwrap(); - let mut codec = match request.payload { - Auth::Handshake { public_key, salt } => { - let handshake = Handshake::default(); - let key = handshake.handshake(public_key, salt).unwrap(); - server - .write(Response::new( - request.id, - Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - }, - )) - .await - .unwrap(); - XChaCha20Poly1305Codec::new(&key) - } - _ => panic!("Server received unexpected payload"), - }; - - // Wait for a challenge request and send back wrong 
response - let request: Request = server.read().await.unwrap().unwrap(); - match request.payload { - Auth::Msg { encrypted_payload } => { - match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { - AuthRequest::Challenge { questions, options } => { - assert_eq!( - questions, - vec![ - AuthQuestion::new("question1".to_string()), - AuthQuestion { - text: "question2".to_string(), - options: vec![("key2".to_string(), "value2".to_string())] - .into_iter() - .collect(), - }, - ], - ); - - assert_eq!( - options, - vec![("key".to_string(), "value".to_string())] - .into_iter() - .collect(), - ); - - server - .write(Response::new( - request.id, - Auth::Msg { - encrypted_payload: serialize_and_encrypt( - &mut codec, - &AuthResponse::Challenge { - answers: vec![ - "answer1".to_string(), - "answer2".to_string(), - ], - }, - ) - .unwrap(), - }, - )) - .await - .unwrap(); - } - _ => panic!("Server received wrong request type"), - } - } - _ => panic!("Server received unexpected payload"), - }; - - // Verify that we got the right results - let answers = task.await.unwrap().unwrap(); - assert_eq!(answers, vec!["answer1".to_string(), "answer2".to_string()]); - } - - #[tokio::test] - async fn verify_should_fail_if_handshake_not_finished() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { - client - .verify(AuthVerifyKind::Host, "some text".to_string()) - .await - }); - - // Wait for a request, failing if we get one as the failure - // should have prevented sending anything, but we should - tokio::select! 
{ - x = TypedAsyncRead::>::read(&mut server) => { - match x { - Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x), - Ok(None) => {}, - Err(x) => panic!("Unexpectedly failed on server side: {}", x), - } - }, - _ = wait_ms(TIMEOUT_MILLIS) => { - panic!("Should have gotten server closure as part of client exit"); - } - } - - // Verify that we got an error with the method - let result = task.await.unwrap(); - assert!(result.is_err(), "Verify succeeded unexpectedly") - } - - #[tokio::test] - async fn verify_should_fail_if_receive_wrong_response() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { - client.handshake().await.unwrap(); - client - .verify(AuthVerifyKind::Host, "some text".to_string()) - .await - }); - - // Wait for a handshake request and set up our encryption codec - let request: Request = server.read().await.unwrap().unwrap(); - let mut codec = match request.payload { - Auth::Handshake { public_key, salt } => { - let handshake = Handshake::default(); - let key = handshake.handshake(public_key, salt).unwrap(); - server - .write(Response::new( - request.id, - Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - }, - )) - .await - .unwrap(); - XChaCha20Poly1305Codec::new(&key) - } - _ => panic!("Server received unexpected payload"), - }; - - // Wait for a verify request and send back wrong response - let request: Request = server.read().await.unwrap().unwrap(); - match request.payload { - Auth::Msg { encrypted_payload } => { - match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { - AuthRequest::Verify { .. 
} => { - server - .write(Response::new( - request.id, - Auth::Msg { - encrypted_payload: serialize_and_encrypt( - &mut codec, - &AuthResponse::Challenge { - answers: Vec::new(), - }, - ) - .unwrap(), - }, - )) - .await - .unwrap(); - } - _ => panic!("Server received wrong request type"), - } - } - _ => panic!("Server received unexpected payload"), - }; - - // Verify that we got an error with the method - let result = task.await.unwrap(); - assert!(result.is_err(), "Verify succeeded unexpectedly") - } - - #[tokio::test] - async fn verify_should_return_valid_bool_received_from_server() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { - client.handshake().await.unwrap(); - client - .verify(AuthVerifyKind::Host, "some text".to_string()) - .await - }); - - // Wait for a handshake request and set up our encryption codec - let request: Request = server.read().await.unwrap().unwrap(); - let mut codec = match request.payload { - Auth::Handshake { public_key, salt } => { - let handshake = Handshake::default(); - let key = handshake.handshake(public_key, salt).unwrap(); - server - .write(Response::new( - request.id, - Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - }, - )) - .await - .unwrap(); - XChaCha20Poly1305Codec::new(&key) - } - _ => panic!("Server received unexpected payload"), - }; - - // Wait for a challenge request and send back wrong response - let request: Request = server.read().await.unwrap().unwrap(); - match request.payload { - Auth::Msg { encrypted_payload } => { - match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { - AuthRequest::Verify { kind, text } => { - assert_eq!(kind, AuthVerifyKind::Host); - assert_eq!(text, "some text"); - - server - 
.write(Response::new( - request.id, - Auth::Msg { - encrypted_payload: serialize_and_encrypt( - &mut codec, - &AuthResponse::Verify { valid: true }, - ) - .unwrap(), - }, - )) - .await - .unwrap(); - } - _ => panic!("Server received wrong request type"), - } - } - _ => panic!("Server received unexpected payload"), - }; - - // Verify that we got the right results - let valid = task.await.unwrap().unwrap(); - assert!(valid, "Got verify response, but valid was set incorrectly"); - } - - #[tokio::test] - async fn info_should_fail_if_handshake_not_finished() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { client.info("some text".to_string()).await }); - - // Wait for a request, failing if we get one as the failure - // should have prevented sending anything, but we should - tokio::select! 
{ - x = TypedAsyncRead::>::read(&mut server) => { - match x { - Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x), - Ok(None) => {}, - Err(x) => panic!("Unexpectedly failed on server side: {}", x), - } - }, - _ = wait_ms(TIMEOUT_MILLIS) => { - panic!("Should have gotten server closure as part of client exit"); - } - } - - // Verify that we got an error with the method - let result = task.await.unwrap(); - assert!(result.is_err(), "Info succeeded unexpectedly") - } - - #[tokio::test] - async fn info_should_send_the_server_a_request_but_not_wait_for_a_response() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { - client.handshake().await.unwrap(); - client.info("some text".to_string()).await - }); - - // Wait for a handshake request and set up our encryption codec - let request: Request = server.read().await.unwrap().unwrap(); - let mut codec = match request.payload { - Auth::Handshake { public_key, salt } => { - let handshake = Handshake::default(); - let key = handshake.handshake(public_key, salt).unwrap(); - server - .write(Response::new( - request.id, - Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - }, - )) - .await - .unwrap(); - XChaCha20Poly1305Codec::new(&key) - } - _ => panic!("Server received unexpected payload"), - }; - - // Wait for a request - let request: Request = server.read().await.unwrap().unwrap(); - match request.payload { - Auth::Msg { encrypted_payload } => { - match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { - AuthRequest::Info { text } => { - assert_eq!(text, "some text"); - } - _ => panic!("Server received wrong request type"), - } - } - _ => panic!("Server received unexpected payload"), - }; - - // Verify that we got 
the right results - task.await.unwrap().unwrap(); - } - - #[tokio::test] - async fn error_should_fail_if_handshake_not_finished() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { - client - .error(AuthErrorKind::FailedChallenge, "some text".to_string()) - .await - }); - - // Wait for a request, failing if we get one as the failure - // should have prevented sending anything, but we should - tokio::select! { - x = TypedAsyncRead::>::read(&mut server) => { - match x { - Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x), - Ok(None) => {}, - Err(x) => panic!("Unexpectedly failed on server side: {}", x), - } - }, - _ = wait_ms(TIMEOUT_MILLIS) => { - panic!("Should have gotten server closure as part of client exit"); - } - } - - // Verify that we got an error with the method - let result = task.await.unwrap(); - assert!(result.is_err(), "Error succeeded unexpectedly") - } - - #[tokio::test] - async fn error_should_send_the_server_a_request_but_not_wait_for_a_response() { - let (t, mut server) = FramedTransport::make_test_pair(); - let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap()); - - // We start a separate task for the client to avoid blocking since - // we also need to receive the client's request and respond - let task = tokio::spawn(async move { - client.handshake().await.unwrap(); - client - .error(AuthErrorKind::FailedChallenge, "some text".to_string()) - .await - }); - - // Wait for a handshake request and set up our encryption codec - let request: Request = server.read().await.unwrap().unwrap(); - let mut codec = match request.payload { - Auth::Handshake { public_key, salt } => { - let handshake = Handshake::default(); - let key = handshake.handshake(public_key, 
salt).unwrap(); - server - .write(Response::new( - request.id, - Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - }, - )) - .await - .unwrap(); - XChaCha20Poly1305Codec::new(&key) - } - _ => panic!("Server received unexpected payload"), - }; - - // Wait for a request - let request: Request = server.read().await.unwrap().unwrap(); - match request.payload { - Auth::Msg { encrypted_payload } => { - match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { - AuthRequest::Error { kind, text } => { - assert_eq!(kind, AuthErrorKind::FailedChallenge); - assert_eq!(text, "some text"); - } - _ => panic!("Server received wrong request type"), - } - } - _ => panic!("Server received unexpected payload"), - }; - - // Verify that we got the right results - task.await.unwrap().unwrap(); - } - - async fn wait_ms(ms: u64) { - use std::time::Duration; - tokio::time::sleep(Duration::from_millis(ms)).await; - } - - fn serialize_and_encrypt( - codec: &mut XChaCha20Poly1305Codec, - payload: &T, - ) -> io::Result> { - let mut encryped_payload = BytesMut::new(); - let payload = utils::serialize_to_vec(payload)?; - codec.encode(&payload, &mut encryped_payload)?; - Ok(encryped_payload.freeze().to_vec()) - } - - fn decrypt_and_deserialize( - codec: &mut XChaCha20Poly1305Codec, - payload: &[u8], - ) -> io::Result { - let mut payload = BytesMut::from(payload); - match codec.decode(&mut payload)? 
{ - Some(payload) => utils::deserialize_from_slice::(&payload), - None => Err(io::Error::new( - io::ErrorKind::InvalidData, - "Incomplete message received", - )), - } - } -} diff --git a/distant-net/src/auth/server.rs b/distant-net/src/auth/server.rs deleted file mode 100644 index 2c5f599..0000000 --- a/distant-net/src/auth/server.rs +++ /dev/null @@ -1,653 +0,0 @@ -use crate::{ - utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Codec, - Handshake, Server, ServerCtx, XChaCha20Poly1305Codec, -}; -use async_trait::async_trait; -use bytes::BytesMut; -use log::*; -use std::{collections::HashMap, io}; -use tokio::sync::RwLock; - -/// Type signature for a dynamic on_challenge function -pub type AuthChallengeFn = - dyn Fn(Vec, HashMap) -> Vec + Send + Sync; - -/// Type signature for a dynamic on_verify function -pub type AuthVerifyFn = dyn Fn(AuthVerifyKind, String) -> bool + Send + Sync; - -/// Type signature for a dynamic on_info function -pub type AuthInfoFn = dyn Fn(String) + Send + Sync; - -/// Type signature for a dynamic on_error function -pub type AuthErrorFn = dyn Fn(AuthErrorKind, String) + Send + Sync; - -/// Represents an [`AuthServer`] where all handlers are stored on the heap -pub type HeapAuthServer = - AuthServer, Box, Box, Box>; - -/// Server that handles authentication -pub struct AuthServer -where - ChallengeFn: Fn(Vec, HashMap) -> Vec + Send + Sync, - VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync, - InfoFn: Fn(String) + Send + Sync, - ErrorFn: Fn(AuthErrorKind, String) + Send + Sync, -{ - pub on_challenge: ChallengeFn, - pub on_verify: VerifyFn, - pub on_info: InfoFn, - pub on_error: ErrorFn, -} - -#[async_trait] -impl Server - for AuthServer -where - ChallengeFn: Fn(Vec, HashMap) -> Vec + Send + Sync, - VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync, - InfoFn: Fn(String) + Send + Sync, - ErrorFn: Fn(AuthErrorKind, String) + Send + Sync, -{ - type Request = Auth; - type Response = Auth; - 
type LocalData = RwLock>; - - async fn on_request(&self, ctx: ServerCtx) { - let reply = ctx.reply.clone(); - - match ctx.request.payload { - Auth::Handshake { public_key, salt } => { - trace!( - "Received handshake request from client, request id = {}", - ctx.request.id - ); - let handshake = Handshake::default(); - match handshake.handshake(public_key, salt) { - Ok(key) => { - ctx.local_data - .write() - .await - .replace(XChaCha20Poly1305Codec::new(&key)); - - trace!( - "Sending reciprocal handshake to client, response origin id = {}", - ctx.request.id - ); - if let Err(x) = reply - .send(Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - }) - .await - { - error!("[Conn {}] {}", ctx.connection_id, x); - } - } - Err(x) => { - error!("[Conn {}] {}", ctx.connection_id, x); - return; - } - } - } - Auth::Msg { - ref encrypted_payload, - } => { - trace!( - "Received auth msg, encrypted payload size = {}", - encrypted_payload.len() - ); - - // Attempt to decrypt the message so we can understand what to do - let request = match ctx.local_data.write().await.as_mut() { - Some(codec) => { - let mut payload = BytesMut::from(encrypted_payload.as_slice()); - match codec.decode(&mut payload) { - Ok(Some(payload)) => { - utils::deserialize_from_slice::(&payload) - } - Ok(None) => Err(io::Error::new( - io::ErrorKind::InvalidData, - "Incomplete message received", - )), - Err(x) => Err(x), - } - } - None => Err(io::Error::new( - io::ErrorKind::Other, - "Handshake must be performed first (server decrypt message)", - )), - }; - - let response = match request { - Ok(request) => match request { - AuthRequest::Challenge { questions, options } => { - trace!("Received challenge request"); - trace!("questions = {:?}", questions); - trace!("options = {:?}", options); - - let answers = (self.on_challenge)(questions, options); - AuthResponse::Challenge { answers } - } - AuthRequest::Verify { kind, text } => { - trace!("Received verify request"); - trace!("kind 
= {:?}", kind); - trace!("text = {:?}", text); - - let valid = (self.on_verify)(kind, text); - AuthResponse::Verify { valid } - } - AuthRequest::Info { text } => { - trace!("Received info request"); - trace!("text = {:?}", text); - - (self.on_info)(text); - return; - } - AuthRequest::Error { kind, text } => { - trace!("Received error request"); - trace!("kind = {:?}", kind); - trace!("text = {:?}", text); - - (self.on_error)(kind, text); - return; - } - }, - Err(x) => { - error!("[Conn {}] {}", ctx.connection_id, x); - return; - } - }; - - // Serialize and encrypt the message before sending it back - let encrypted_payload = match ctx.local_data.write().await.as_mut() { - Some(codec) => { - let mut encrypted_payload = BytesMut::new(); - - // Convert the response into bytes for us to send back - match utils::serialize_to_vec(&response) { - Ok(bytes) => match codec.encode(&bytes, &mut encrypted_payload) { - Ok(_) => Ok(encrypted_payload.freeze().to_vec()), - Err(x) => Err(x), - }, - Err(x) => Err(x), - } - } - None => Err(io::Error::new( - io::ErrorKind::Other, - "Handshake must be performed first (server encrypt messaage)", - )), - }; - - match encrypted_payload { - Ok(encrypted_payload) => { - if let Err(x) = reply.send(Auth::Msg { encrypted_payload }).await { - error!("[Conn {}] {}", ctx.connection_id, x); - return; - } - } - Err(x) => { - error!("[Conn {}] {}", ctx.connection_id, x); - return; - } - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - IntoSplit, MpscListener, MpscTransport, Request, Response, ServerExt, ServerRef, - TypedAsyncRead, TypedAsyncWrite, - }; - use tokio::sync::mpsc; - - const TIMEOUT_MILLIS: u64 = 100; - - #[tokio::test] - async fn should_not_reply_if_receive_encrypted_msg_without_handshake_first() { - let (mut t, _) = spawn_auth_server( - /* on_challenge */ |_, _| Vec::new(), - /* on_verify */ |_, _| false, - /* on_info */ |_| {}, - /* on_error */ |_, _| {}, - ) - .await - .expect("Failed to spawn server"); - 
- // Send an encrypted message before establishing a handshake - t.write(Request::new(Auth::Msg { - encrypted_payload: Vec::new(), - })) - .await - .expect("Failed to send request to server"); - - // Wait for a response, failing if we get one - tokio::select! { - x = t.read() => panic!("Unexpectedly resolved: {:?}", x), - _ = wait_ms(TIMEOUT_MILLIS) => {} - } - } - - #[tokio::test] - async fn should_reply_to_handshake_request_with_new_public_key_and_salt() { - let (mut t, _) = spawn_auth_server( - /* on_challenge */ |_, _| Vec::new(), - /* on_verify */ |_, _| false, - /* on_info */ |_| {}, - /* on_error */ |_, _| {}, - ) - .await - .expect("Failed to spawn server"); - - // Send a handshake - let handshake = Handshake::default(); - t.write(Request::new(Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - })) - .await - .expect("Failed to send request to server"); - - // Wait for a handshake response - tokio::select! { - x = t.read() => { - let response = x.expect("Request failed").expect("Response missing"); - match response.payload { - Auth::Handshake { .. } => {}, - Auth::Msg { .. 
} => panic!("Received unexpected encryped message during handshake"), - } - } - _ = wait_ms(TIMEOUT_MILLIS) => panic!("Ran out of time waiting on response"), - } - } - - #[tokio::test] - async fn should_not_reply_if_receive_invalid_encrypted_msg() { - let (mut t, _) = spawn_auth_server( - /* on_challenge */ |_, _| Vec::new(), - /* on_verify */ |_, _| false, - /* on_info */ |_| {}, - /* on_error */ |_, _| {}, - ) - .await - .expect("Failed to spawn server"); - - // Send a handshake - let handshake = Handshake::default(); - t.write(Request::new(Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - })) - .await - .expect("Failed to send request to server"); - - // Complete handshake - let key = match t.read().await.unwrap().unwrap().payload { - Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), - Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"), - }; - - // Send a bad chunk of data - let _codec = XChaCha20Poly1305Codec::new(&key); - t.write(Request::new(Auth::Msg { - encrypted_payload: vec![1, 2, 3, 4], - })) - .await - .unwrap(); - - // Wait for a response, failing if we get one - tokio::select! 
{ - x = t.read() => panic!("Unexpectedly resolved: {:?}", x), - _ = wait_ms(TIMEOUT_MILLIS) => {} - } - } - - #[tokio::test] - async fn should_invoke_appropriate_function_when_receive_challenge_request_and_reply() { - let (tx, mut rx) = mpsc::channel(1); - let (mut t, _) = spawn_auth_server( - /* on_challenge */ - move |questions, options| { - tx.try_send((questions, options)).unwrap(); - vec!["answer1".to_string(), "answer2".to_string()] - }, - /* on_verify */ |_, _| false, - /* on_info */ |_| {}, - /* on_error */ |_, _| {}, - ) - .await - .expect("Failed to spawn server"); - - // Send a handshake - let handshake = Handshake::default(); - t.write(Request::new(Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - })) - .await - .expect("Failed to send request to server"); - - // Complete handshake - let key = match t.read().await.unwrap().unwrap().payload { - Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), - Auth::Msg { .. 
} => panic!("Received unexpected encryped message during handshake"), - }; - - // Send an error request - let mut codec = XChaCha20Poly1305Codec::new(&key); - t.write(Request::new(Auth::Msg { - encrypted_payload: serialize_and_encrypt( - &mut codec, - &AuthRequest::Challenge { - questions: vec![ - AuthQuestion::new("question1".to_string()), - AuthQuestion { - text: "question2".to_string(), - options: vec![("key".to_string(), "value".to_string())] - .into_iter() - .collect(), - }, - ], - options: vec![("hello".to_string(), "world".to_string())] - .into_iter() - .collect(), - }, - ) - .unwrap(), - })) - .await - .unwrap(); - - // Verify that the handler was triggered - let (questions, options) = rx.recv().await.expect("Channel closed unexpectedly"); - assert_eq!( - questions, - vec![ - AuthQuestion::new("question1".to_string()), - AuthQuestion { - text: "question2".to_string(), - options: vec![("key".to_string(), "value".to_string())] - .into_iter() - .collect(), - } - ] - ); - assert_eq!( - options, - vec![("hello".to_string(), "world".to_string())] - .into_iter() - .collect() - ); - - // Wait for a response and verify that it matches what we expect - tokio::select! { - x = t.read() => { - let response = x.expect("Request failed").expect("Response missing"); - match response.payload { - Auth::Handshake { .. 
} => panic!("Received unexpected handshake"), - Auth::Msg { encrypted_payload } => { - match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { - AuthResponse::Challenge { answers } => - assert_eq!( - answers, - vec!["answer1".to_string(), "answer2".to_string()] - ), - _ => panic!("Got wrong response for verify"), - } - }, - } - } - _ = wait_ms(TIMEOUT_MILLIS) => {} - } - } - - #[tokio::test] - async fn should_invoke_appropriate_function_when_receive_verify_request_and_reply() { - let (tx, mut rx) = mpsc::channel(1); - let (mut t, _) = spawn_auth_server( - /* on_challenge */ |_, _| Vec::new(), - /* on_verify */ - move |kind, text| { - tx.try_send((kind, text)).unwrap(); - true - }, - /* on_info */ |_| {}, - /* on_error */ |_, _| {}, - ) - .await - .expect("Failed to spawn server"); - - // Send a handshake - let handshake = Handshake::default(); - t.write(Request::new(Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - })) - .await - .expect("Failed to send request to server"); - - // Complete handshake - let key = match t.read().await.unwrap().unwrap().payload { - Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), - Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"), - }; - - // Send an error request - let mut codec = XChaCha20Poly1305Codec::new(&key); - t.write(Request::new(Auth::Msg { - encrypted_payload: serialize_and_encrypt( - &mut codec, - &AuthRequest::Verify { - kind: AuthVerifyKind::Host, - text: "some text".to_string(), - }, - ) - .unwrap(), - })) - .await - .unwrap(); - - // Verify that the handler was triggered - let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly"); - assert_eq!(kind, AuthVerifyKind::Host); - assert_eq!(text, "some text"); - - // Wait for a response and verify that it matches what we expect - tokio::select! 
{ - x = t.read() => { - let response = x.expect("Request failed").expect("Response missing"); - match response.payload { - Auth::Handshake { .. } => panic!("Received unexpected handshake"), - Auth::Msg { encrypted_payload } => { - match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() { - AuthResponse::Verify { valid } => - assert!(valid, "Got verify, but valid was wrong"), - _ => panic!("Got wrong response for verify"), - } - }, - } - } - _ = wait_ms(TIMEOUT_MILLIS) => {} - } - } - - #[tokio::test] - async fn should_invoke_appropriate_function_when_receive_info_request() { - let (tx, mut rx) = mpsc::channel(1); - let (mut t, _) = spawn_auth_server( - /* on_challenge */ |_, _| Vec::new(), - /* on_verify */ |_, _| false, - /* on_info */ - move |text| { - tx.try_send(text).unwrap(); - }, - /* on_error */ |_, _| {}, - ) - .await - .expect("Failed to spawn server"); - - // Send a handshake - let handshake = Handshake::default(); - t.write(Request::new(Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - })) - .await - .expect("Failed to send request to server"); - - // Complete handshake - let key = match t.read().await.unwrap().unwrap().payload { - Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), - Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"), - }; - - // Send an error request - let mut codec = XChaCha20Poly1305Codec::new(&key); - t.write(Request::new(Auth::Msg { - encrypted_payload: serialize_and_encrypt( - &mut codec, - &AuthRequest::Info { - text: "some text".to_string(), - }, - ) - .unwrap(), - })) - .await - .unwrap(); - - // Verify that the handler was triggered - let text = rx.recv().await.expect("Channel closed unexpectedly"); - assert_eq!(text, "some text"); - - // Wait for a response, failing if we get one - tokio::select! 
{ - x = t.read() => panic!("Unexpectedly resolved: {:?}", x), - _ = wait_ms(TIMEOUT_MILLIS) => {} - } - } - - #[tokio::test] - async fn should_invoke_appropriate_function_when_receive_error_request() { - let (tx, mut rx) = mpsc::channel(1); - let (mut t, _) = spawn_auth_server( - /* on_challenge */ |_, _| Vec::new(), - /* on_verify */ |_, _| false, - /* on_info */ |_| {}, - /* on_error */ - move |kind, text| { - tx.try_send((kind, text)).unwrap(); - }, - ) - .await - .expect("Failed to spawn server"); - - // Send a handshake - let handshake = Handshake::default(); - t.write(Request::new(Auth::Handshake { - public_key: handshake.pk_bytes(), - salt: *handshake.salt(), - })) - .await - .expect("Failed to send request to server"); - - // Complete handshake - let key = match t.read().await.unwrap().unwrap().payload { - Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(), - Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"), - }; - - // Send an error request - let mut codec = XChaCha20Poly1305Codec::new(&key); - t.write(Request::new(Auth::Msg { - encrypted_payload: serialize_and_encrypt( - &mut codec, - &AuthRequest::Error { - kind: AuthErrorKind::FailedChallenge, - text: "some text".to_string(), - }, - ) - .unwrap(), - })) - .await - .unwrap(); - - // Verify that the handler was triggered - let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly"); - assert_eq!(kind, AuthErrorKind::FailedChallenge); - assert_eq!(text, "some text"); - - // Wait for a response, failing if we get one - tokio::select! 
{ - x = t.read() => panic!("Unexpectedly resolved: {:?}", x), - _ = wait_ms(TIMEOUT_MILLIS) => {} - } - } - - async fn wait_ms(ms: u64) { - use std::time::Duration; - tokio::time::sleep(Duration::from_millis(ms)).await; - } - - fn serialize_and_encrypt( - codec: &mut XChaCha20Poly1305Codec, - payload: &AuthRequest, - ) -> io::Result> { - let mut encryped_payload = BytesMut::new(); - let payload = utils::serialize_to_vec(payload)?; - codec.encode(&payload, &mut encryped_payload)?; - Ok(encryped_payload.freeze().to_vec()) - } - - fn decrypt_and_deserialize( - codec: &mut XChaCha20Poly1305Codec, - payload: &[u8], - ) -> io::Result { - let mut payload = BytesMut::from(payload); - match codec.decode(&mut payload)? { - Some(payload) => utils::deserialize_from_slice::(&payload), - None => Err(io::Error::new( - io::ErrorKind::InvalidData, - "Incomplete message received", - )), - } - } - - async fn spawn_auth_server( - on_challenge: ChallengeFn, - on_verify: VerifyFn, - on_info: InfoFn, - on_error: ErrorFn, - ) -> io::Result<( - MpscTransport, Response>, - Box, - )> - where - ChallengeFn: - Fn(Vec, HashMap) -> Vec + Send + Sync + 'static, - VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync + 'static, - InfoFn: Fn(String) + Send + Sync + 'static, - ErrorFn: Fn(AuthErrorKind, String) + Send + Sync + 'static, - { - let server = AuthServer { - on_challenge, - on_verify, - on_info, - on_error, - }; - - // Create a test listener where we will forward a connection - let (tx, listener) = MpscListener::channel(100); - - // Make bounded transport pair and send off one of them to act as our connection - let (transport, connection) = MpscTransport::, Response>::pair(100); - tx.send(connection.into_split()) - .await - .expect("Failed to feed listener a connection"); - - let server = server.start(listener)?; - Ok((transport, server)) - } -} diff --git a/distant-net/src/client.rs b/distant-net/src/client.rs index 8d1babb..fa3fba3 100644 --- a/distant-net/src/client.rs +++ 
b/distant-net/src/client.rs @@ -1,140 +1,452 @@ -use crate::{ - Codec, FramedTransport, IntoSplit, RawTransport, RawTransportRead, RawTransportWrite, Request, - Response, TypedAsyncRead, TypedAsyncWrite, +use crate::common::{ + Connection, FramedTransport, HeapSecretKey, InmemoryTransport, Interest, Reconnectable, + Transport, UntypedRequest, UntypedResponse, }; +use log::*; use serde::{de::DeserializeOwned, Serialize}; use std::{ + fmt, io, ops::{Deref, DerefMut}, sync::Arc, + time::Duration, }; use tokio::{ - io, - sync::mpsc, - task::{JoinError, JoinHandle}, + sync::{mpsc, oneshot}, + task::JoinHandle, }; +mod builder; +pub use builder::*; + mod channel; pub use channel::*; -mod ext; -pub use ext::*; +mod reconnect; +pub use reconnect::*; -/// Represents a client that can be used to send requests & receive responses from a server -pub struct Client -where - T: Send + Sync + Serialize + 'static, - U: Send + Sync + DeserializeOwned + 'static, -{ - /// Used to send requests to a server - channel: Channel, +mod shutdown; +pub use shutdown::*; + +/// Time to wait inbetween connection read/write when nothing was read or written on last pass +const SLEEP_DURATION: Duration = Duration::from_millis(1); - /// Contains the task that is running to send requests to a server - request_task: JoinHandle<()>, +/// Represents a client that can be used to send requests & receive responses from a server. +/// +/// ### Note +/// +/// This variant does not validate the payload of requests or responses being sent and received. +pub struct UntypedClient { + /// Used to send requests to a server. + channel: UntypedChannel, - /// Contains the task that is running to receive responses from a server - response_task: JoinHandle<()>, + /// Used to send shutdown request to inner task. + shutdown: Box, + + /// Contains the task that is running to send requests and receive responses from a server. 
+ task: JoinHandle>, } -impl Client -where - T: Send + Sync + Serialize, - U: Send + Sync + DeserializeOwned, -{ - /// Initializes a client using the provided reader and writer - pub fn new(mut writer: W, mut reader: R) -> io::Result +impl fmt::Debug for UntypedClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UntypedClient") + .field("channel", &self.channel) + .field("shutdown", &"...") + .field("task", &self.task) + .finish() + } +} + +impl UntypedClient { + /// Consumes the client, returning a typed variant. + pub fn into_typed_client(self) -> Client { + Client { + channel: self.channel.into_typed_channel(), + shutdown: self.shutdown, + task: self.task, + } + } + + /// Convert into underlying channel. + pub fn into_channel(self) -> UntypedChannel { + self.channel + } + + /// Clones the underlying channel for requests and returns the cloned instance. + pub fn clone_channel(&self) -> UntypedChannel { + self.channel.clone() + } + + /// Waits for the client to terminate, which resolves when the receiving end of the network + /// connection is closed (or the client is shutdown). Returns whether or not the client exited + /// successfully or due to an error. + pub async fn wait(self) -> io::Result<()> { + match self.task.await { + Ok(x) => x, + Err(x) => Err(io::Error::new(io::ErrorKind::Other, x)), + } + } + + /// Abort the client's current connection by forcing its tasks to abort. + pub fn abort(&self) { + self.task.abort(); + } + + /// Clones the underlying shutdown signaler for the client. This enables you to wait on the + /// client while still having the option to shut it down from somewhere else. + pub fn clone_shutdown(&self) -> Box { + self.shutdown.clone() + } + + /// Signal for the client to shutdown its connection cleanly. + pub async fn shutdown(&self) -> io::Result<()> { + self.shutdown.shutdown().await + } + + /// Returns true if client's underlying event processing has finished/terminated. 
+ pub fn is_finished(&self) -> bool { + self.task.is_finished() + } + + /// Spawns a client using the provided [`FramedTransport`] of [`InmemoryTransport`] and a + /// specific [`ReconnectStrategy`]. + /// + /// ### Note + /// + /// This will NOT perform any handshakes or authentication procedures nor will it replay any + /// missing frames. This is to be used when establishing a [`Client`] to be run internally + /// within a program. + pub fn spawn_inmemory( + transport: FramedTransport, + strategy: ReconnectStrategy, + ) -> Self { + let connection = Connection::Client { + id: rand::random(), + reauth_otp: HeapSecretKey::generate(32).unwrap(), + transport, + }; + Self::spawn(connection, strategy) + } + + /// Spawns a client using the provided [`Connection`]. + pub(crate) fn spawn(mut connection: Connection, mut strategy: ReconnectStrategy) -> Self where - R: TypedAsyncRead> + Send + 'static, - W: TypedAsyncWrite> + Send + 'static, + V: Transport + 'static, { let post_office = Arc::new(PostOffice::default()); let weak_post_office = Arc::downgrade(&post_office); + let (tx, mut rx) = mpsc::channel::>(1); + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::>>(1); + + // Ensure that our transport starts off clean (nothing in buffers or backup) + connection.clear(); // Start a task that continually checks for responses and delivers them using the // post office - let response_task = tokio::spawn(async move { + let shutdown_tx_2 = shutdown_tx.clone(); + let task = tokio::spawn(async move { + let mut needs_reconnect = false; + + // NOTE: We hold onto a copy of the shutdown sender, even though we will never use it, + // to prevent the channel from being closed. This is because we do a check to + // see if we get a shutdown signal or ready state, and closing the channel + // would cause recv() to resolve immediately and result in the task shutting + // down. 
+ let _shutdown_tx = shutdown_tx_2; + loop { - match reader.read().await { - Ok(Some(res)) => { - // Try to send response to appropriate mailbox - // TODO: How should we handle false response? Did logging in past - post_office.deliver_response(res).await; + // If we have flagged that a reconnect is needed, attempt to do so + if needs_reconnect { + info!("Client encountered issue, attempting to reconnect"); + if log::log_enabled!(log::Level::Debug) { + debug!("Using strategy {strategy:?}"); } - Ok(None) => { - break; + match strategy.reconnect(&mut connection).await { + Ok(x) => { + needs_reconnect = false; + x + } + Err(x) => { + error!("Unable to re-establish connection: {x}"); + break Err(x); + } } - Err(_) => { - break; + } + + let ready = tokio::select! { + // NOTE: This should NEVER return None as we never allow the channel to close. + cb = shutdown_rx.recv() => { + debug!("Client got shutdown signal, so exiting event loop"); + let cb = cb.expect("Impossible: shutdown channel closed!"); + let _ = cb.send(Ok(())); + break Ok(()); + } + result = connection.ready(Interest::READABLE | Interest::WRITABLE) => { + match result { + Ok(result) => result, + Err(x) => { + error!("Failed to examine ready state: {x}"); + needs_reconnect = true; + continue; + } + } + } + }; + + let mut read_blocked = !ready.is_readable(); + let mut write_blocked = !ready.is_writable(); + + if ready.is_readable() { + match connection.try_read_frame() { + Ok(Some(frame)) => { + match UntypedResponse::from_slice(frame.as_item()) { + Ok(response) => { + if log_enabled!(Level::Trace) { + trace!( + "Client receiving {}", + String::from_utf8_lossy(&response.to_bytes()) + .to_string() + ); + } + // Try to send response to appropriate mailbox + // TODO: This will block if full... is that a problem? + // TODO: How should we handle false response? 
Did logging in past + post_office + .deliver_untyped_response(response.into_owned()) + .await; + } + Err(x) => { + error!("Invalid response: {x}"); + } + } + } + Ok(None) => { + debug!("Connection closed"); + needs_reconnect = true; + continue; + } + Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true, + Err(x) => { + error!("Failed to read next frame: {x}"); + needs_reconnect = true; + continue; + } } } - } - }); - let (tx, mut rx) = mpsc::channel::>(1); - let request_task = tokio::spawn(async move { - while let Some(req) = rx.recv().await { - if writer.write(req).await.is_err() { - break; + if ready.is_writable() { + // If we get more data to write, attempt to write it, which will result in + // writing any queued bytes as well. Othewise, we attempt to flush any pending + // outgoing bytes that weren't sent earlier. + if let Ok(request) = rx.try_recv() { + if log_enabled!(Level::Trace) { + trace!( + "Client sending {}", + String::from_utf8_lossy(&request.to_bytes()).to_string() + ); + } + match connection.try_write_frame(request.to_bytes()) { + Ok(()) => (), + Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true, + Err(x) => { + error!("Failed to write frame: {x}"); + needs_reconnect = true; + continue; + } + } + } else { + // In the case of flushing, there are two scenarios in which we want to + // mark no write occurring: + // + // 1. When flush did not write any bytes, which can happen when the buffer + // is empty + // 2. 
When the call to write bytes blocks + match connection.try_flush() { + Ok(0) => write_blocked = true, + Ok(_) => (), + Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true, + Err(x) => { + error!("Failed to flush outgoing data: {x}"); + needs_reconnect = true; + continue; + } + } + } + } + + // If we did not read or write anything, sleep a bit to offload CPU usage + if read_blocked && write_blocked { + tokio::time::sleep(SLEEP_DURATION).await; } } }); - let channel = Channel { + let channel = UntypedChannel { tx, post_office: weak_post_office, }; - Ok(Self { + Self { channel, - request_task, - response_task, - }) + shutdown: Box::new(shutdown_tx), + task, + } } +} - /// Initializes a client using the provided framed transport - pub fn from_framed_transport(transport: FramedTransport) -> io::Result - where - TR: RawTransport + IntoSplit + 'static, - ::Read: RawTransportRead, - ::Write: RawTransportWrite, - C: Codec + Send + 'static, - { - let (writer, reader) = transport.into_split(); - Self::new(writer, reader) +impl Deref for UntypedClient { + type Target = UntypedChannel; + + fn deref(&self) -> &Self::Target { + &self.channel + } +} + +impl DerefMut for UntypedClient { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.channel } +} - /// Convert into underlying channel +impl From for UntypedChannel { + fn from(client: UntypedClient) -> Self { + client.channel + } +} + +/// Represents a client that can be used to send requests & receive responses from a server. +pub struct Client { + /// Used to send requests to a server. + channel: Channel, + + /// Used to send shutdown request to inner task. + shutdown: Box, + + /// Contains the task that is running to send requests and receive responses from a server. 
+ task: JoinHandle>, +} + +impl fmt::Debug for Client { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Client") + .field("channel", &self.channel) + .field("shutdown", &"...") + .field("task", &self.task) + .finish() + } +} + +impl Client +where + T: Send + Sync + Serialize + 'static, + U: Send + Sync + DeserializeOwned + 'static, +{ + /// Consumes the client, returning an untyped variant. + pub fn into_untyped_client(self) -> UntypedClient { + UntypedClient { + channel: self.channel.into_untyped_channel(), + shutdown: self.shutdown, + task: self.task, + } + } + + /// Spawns a client using the provided [`FramedTransport`] of [`InmemoryTransport`] and a + /// specific [`ReconnectStrategy`]. + /// + /// ### Note + /// + /// This will NOT perform any handshakes or authentication procedures nor will it replay any + /// missing frames. This is to be used when establishing a [`Client`] to be run internally + /// within a program. + pub fn spawn_inmemory( + transport: FramedTransport, + strategy: ReconnectStrategy, + ) -> Self { + UntypedClient::spawn_inmemory(transport, strategy).into_typed_client() + } +} + +impl Client<(), ()> { + /// Creates a new [`ClientBuilder`]. + pub fn build() -> ClientBuilder<(), ()> { + ClientBuilder::new() + } + + /// Creates a new [`ClientBuilder`] configured to use a [`TcpConnector`]. + pub fn tcp(connector: impl Into>) -> ClientBuilder<(), TcpConnector> { + ClientBuilder::new().connector(connector.into()) + } + + /// Creates a new [`ClientBuilder`] configured to use a [`UnixSocketConnector`]. + #[cfg(unix)] + pub fn unix_socket( + connector: impl Into, + ) -> ClientBuilder<(), UnixSocketConnector> { + ClientBuilder::new().connector(connector.into()) + } + + /// Creates a new [`ClientBuilder`] configured to use a local [`WindowsPipeConnector`]. 
+ #[cfg(windows)] + pub fn local_windows_pipe( + connector: impl Into, + ) -> ClientBuilder<(), WindowsPipeConnector> { + let mut connector = connector.into(); + connector.local = true; + ClientBuilder::new().connector(connector) + } + + /// Creates a new [`ClientBuilder`] configured to use a [`WindowsPipeConnector`]. + #[cfg(windows)] + pub fn windows_pipe( + connector: impl Into, + ) -> ClientBuilder<(), WindowsPipeConnector> { + ClientBuilder::new().connector(connector.into()) + } +} + +impl Client { + /// Convert into underlying channel. pub fn into_channel(self) -> Channel { self.channel } - /// Clones the underlying channel for requests and returns the cloned instance + /// Clones the underlying channel for requests and returns the cloned instance. pub fn clone_channel(&self) -> Channel { self.channel.clone() } - /// Waits for the client to terminate, which results when the receiving end of the network - /// connection is closed (or the client is shutdown) - pub async fn wait(self) -> Result<(), JoinError> { - tokio::try_join!(self.request_task, self.response_task).map(|_| ()) + /// Waits for the client to terminate, which resolves when the receiving end of the network + /// connection is closed (or the client is shutdown). Returns whether or not the client exited + /// successfully or due to an error. + pub async fn wait(self) -> io::Result<()> { + match self.task.await { + Ok(x) => x, + Err(x) => Err(io::Error::new(io::ErrorKind::Other, x)), + } } - /// Abort the client's current connection by forcing its tasks to abort + /// Abort the client's current connection by forcing its tasks to abort. pub fn abort(&self) { - self.request_task.abort(); - self.response_task.abort(); + self.task.abort(); + } + + /// Clones the underlying shutdown signaler for the client. This enables you to wait on the + /// client while still having the option to shut it down from somewhere else. 
+ pub fn clone_shutdown(&self) -> Box { + self.shutdown.clone() } - /// Returns true if client's underlying event processing has finished/terminated + /// Signal for the client to shutdown its connection cleanly. + pub async fn shutdown(&self) -> io::Result<()> { + self.shutdown.shutdown().await + } + + /// Returns true if client's underlying event processing has finished/terminated. pub fn is_finished(&self) -> bool { - self.request_task.is_finished() && self.response_task.is_finished() + self.task.is_finished() } } -impl Deref for Client -where - T: Send + Sync + Serialize + 'static, - U: Send + Sync + DeserializeOwned + 'static, -{ +impl Deref for Client { type Target = Channel; fn deref(&self) -> &Self::Target { @@ -142,22 +454,659 @@ where } } -impl DerefMut for Client -where - T: Send + Sync + Serialize + 'static, - U: Send + Sync + DeserializeOwned + 'static, -{ +impl DerefMut for Client { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.channel } } -impl From> for Channel -where - T: Send + Sync + Serialize + 'static, - U: Send + Sync + DeserializeOwned + 'static, -{ +impl From> for Channel { fn from(client: Client) -> Self { client.channel } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::{Ready, Request, Response, TestTransport}; + + mod typed { + use super::*; + use test_log::test; + type TestClient = Client; + + fn spawn_test_client( + connection: Connection, + strategy: ReconnectStrategy, + ) -> TestClient + where + T: Transport + 'static, + { + UntypedClient::spawn(connection, strategy).into_typed_client() + } + + /// Creates a new test transport whose operations do not panic, but do nothing. 
+ #[inline] + fn new_test_transport() -> TestTransport { + TestTransport { + f_try_read: Box::new(|_| Err(io::ErrorKind::WouldBlock.into())), + f_try_write: Box::new(|_| Err(io::ErrorKind::WouldBlock.into())), + f_ready: Box::new(|_| Ok(Ready::EMPTY)), + f_reconnect: Box::new(|| Ok(())), + } + } + + #[test(tokio::test)] + async fn should_write_queued_requests_as_outgoing_frames() { + let (client, mut server) = Connection::pair(100); + + let mut client = spawn_test_client(client, ReconnectStrategy::Fail); + client.fire(Request::new(1u8)).await.unwrap(); + client.fire(Request::new(2u8)).await.unwrap(); + client.fire(Request::new(3u8)).await.unwrap(); + + assert_eq!( + server + .read_frame_as::>() + .await + .unwrap() + .unwrap() + .payload, + 1 + ); + assert_eq!( + server + .read_frame_as::>() + .await + .unwrap() + .unwrap() + .payload, + 2 + ); + assert_eq!( + server + .read_frame_as::>() + .await + .unwrap() + .unwrap() + .payload, + 3 + ); + } + + #[test(tokio::test)] + async fn should_read_incoming_frames_as_responses_and_deliver_them_to_waiting_mailboxes() { + let (client, mut server) = Connection::pair(100); + + // NOTE: Spawn a separate task to handle the response so we do not deadlock + tokio::spawn(async move { + let request = server + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + server + .write_frame_for(&Response::new(request.id, 2u8)) + .await + .unwrap(); + }); + + let mut client = spawn_test_client(client, ReconnectStrategy::Fail); + assert_eq!(client.send(Request::new(1u8)).await.unwrap().payload, 2); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_fails_to_determine_state() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + spawn_test_client( + Connection::test_client({ + let mut transport = new_test_transport(); + + transport.f_ready = Box::new(|_| Err(io::ErrorKind::Other.into())); + + // Send a signal that the reconnect happened while marking it successful + transport.f_reconnect = 
Box::new(move || { + reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + interval: Duration::from_millis(50), + max_retries: None, + timeout: None, + }, + ); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_closed_by_server() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + spawn_test_client( + Connection::test_client({ + let mut transport = new_test_transport(); + + // Report back that we're readable to trigger try_read + transport.f_ready = Box::new(|_| Ok(Ready::READABLE)); + + // Report that no bytes were written, indicting the channel was closed + transport.f_try_read = Box::new(|_| Ok(0)); + + // Send a signal that the reconnect happened while marking it successful + transport.f_reconnect = Box::new(move || { + reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + interval: Duration::from_millis(50), + max_retries: None, + timeout: None, + }, + ); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_errors_while_reading_data() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + spawn_test_client( + Connection::test_client({ + let mut transport = new_test_transport(); + + // Report back that we're readable to trigger try_read + transport.f_ready = Box::new(|_| Ok(Ready::READABLE)); + + // Fail the read + transport.f_try_read = Box::new(|_| Err(io::ErrorKind::Other.into())); + + // Send a signal that the reconnect happened while marking it successful + transport.f_reconnect = Box::new(move || { + reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + interval: Duration::from_millis(50), + max_retries: None, + timeout: 
None, + }, + ); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_unable_to_send_new_request() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + let mut client = spawn_test_client( + Connection::test_client({ + let mut transport = new_test_transport(); + + // Report back that we're readable to trigger try_read + transport.f_ready = Box::new(|_| Ok(Ready::WRITABLE)); + + // Fail the write + transport.f_try_write = Box::new(|_| Err(io::ErrorKind::Other.into())); + + // Send a signal that the reconnect happened while marking it successful + transport.f_reconnect = Box::new(move || { + reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + interval: Duration::from_millis(50), + max_retries: None, + timeout: None, + }, + ); + + // Queue up a request to fail to send + client + .fire(Request::new(123u8)) + .await + .expect("Failed to queue request"); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_unable_to_flush_an_existing_request() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + let mut client = spawn_test_client( + Connection::test_client({ + let mut transport = new_test_transport(); + + // Report back that we're readable to trigger try_read + transport.f_ready = Box::new(|_| Ok(Ready::WRITABLE)); + + // Succeed partially with initial try_write, block on second call, and then + // fail during a try_flush + transport.f_try_write = Box::new(|buf| unsafe { + static mut CNT: u8 = 0; + CNT += 1; + if CNT == 1 { + Ok(buf.len() / 2) + } else if CNT == 2 { + Err(io::ErrorKind::WouldBlock.into()) + } else { + Err(io::ErrorKind::Other.into()) + } + }); + + // Send a signal that the reconnect happened while marking it successful + transport.f_reconnect = Box::new(move || { + 
reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + interval: Duration::from_millis(50), + max_retries: None, + timeout: None, + }, + ); + + // Queue up a request to fail to send + client + .fire(Request::new(123u8)) + .await + .expect("Failed to queue request"); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_exit_if_reconnect_strategy_has_failed_to_connect() { + let (client, server) = Connection::pair(100); + + // Spawn the client, verify the task is running, kill our server, and verify that the + // client does not block trying to reconnect + let client = spawn_test_client(client, ReconnectStrategy::Fail); + assert!(!client.is_finished(), "Client unexpectedly died"); + drop(server); + assert_eq!( + client.wait().await.unwrap_err().kind(), + io::ErrorKind::ConnectionAborted + ); + } + + #[test(tokio::test)] + async fn should_exit_if_shutdown_signal_detected() { + let (client, _server) = Connection::pair(100); + + let client = spawn_test_client(client, ReconnectStrategy::Fail); + client.shutdown().await.unwrap(); + + // NOTE: We wait for the client's task to conclude by using `wait` to ensure we do not + // have a race condition testing the task finished state. This will also verify + // that the task exited cleanly, rather than panicking. 
+ client.wait().await.unwrap(); + } + + #[test(tokio::test)] + async fn should_not_exit_if_shutdown_channel_is_closed() { + let (client, mut server) = Connection::pair(100); + + // NOTE: Spawn a separate task to handle the response so we do not deadlock + tokio::spawn(async move { + let request = server + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + server + .write_frame_for(&Response::new(request.id, 2u8)) + .await + .unwrap(); + }); + + // NOTE: We consume the client to produce a channel without maintaining the shutdown + // channel in order to ensure that dropping the client does not kill the task. + let mut channel = spawn_test_client(client, ReconnectStrategy::Fail).into_channel(); + assert_eq!(channel.send(Request::new(1u8)).await.unwrap().payload, 2); + } + } + + mod untyped { + use super::*; + use test_log::test; + type TestClient = UntypedClient; + + /// Creates a new test transport whose operations do not panic, but do nothing. + #[inline] + fn new_test_transport() -> TestTransport { + TestTransport { + f_try_read: Box::new(|_| Err(io::ErrorKind::WouldBlock.into())), + f_try_write: Box::new(|_| Err(io::ErrorKind::WouldBlock.into())), + f_ready: Box::new(|_| Ok(Ready::EMPTY)), + f_reconnect: Box::new(|| Ok(())), + } + } + + #[test(tokio::test)] + async fn should_write_queued_requests_as_outgoing_frames() { + let (client, mut server) = Connection::pair(100); + + let mut client = TestClient::spawn(client, ReconnectStrategy::Fail); + client + .fire(Request::new(1u8).to_untyped_request().unwrap()) + .await + .unwrap(); + client + .fire(Request::new(2u8).to_untyped_request().unwrap()) + .await + .unwrap(); + client + .fire(Request::new(3u8).to_untyped_request().unwrap()) + .await + .unwrap(); + + assert_eq!( + server + .read_frame_as::>() + .await + .unwrap() + .unwrap() + .payload, + 1 + ); + assert_eq!( + server + .read_frame_as::>() + .await + .unwrap() + .unwrap() + .payload, + 2 + ); + assert_eq!( + server + .read_frame_as::>() + .await + 
.unwrap() + .unwrap() + .payload, + 3 + ); + } + + #[test(tokio::test)] + async fn should_read_incoming_frames_as_responses_and_deliver_them_to_waiting_mailboxes() { + let (client, mut server) = Connection::pair(100); + + // NOTE: Spawn a separate task to handle the response so we do not deadlock + tokio::spawn(async move { + let request = server + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + server + .write_frame_for(&Response::new(request.id, 2u8)) + .await + .unwrap(); + }); + + let mut client = TestClient::spawn(client, ReconnectStrategy::Fail); + assert_eq!( + client + .send(Request::new(1u8).to_untyped_request().unwrap()) + .await + .unwrap() + .to_typed_response::() + .unwrap() + .payload, + 2 + ); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_fails_to_determine_state() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + TestClient::spawn( + Connection::test_client({ + let mut transport = new_test_transport(); + + transport.f_ready = Box::new(|_| Err(io::ErrorKind::Other.into())); + + // Send a signal that the reconnect happened while marking it successful + transport.f_reconnect = Box::new(move || { + reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + interval: Duration::from_millis(50), + max_retries: None, + timeout: None, + }, + ); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_closed_by_server() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + TestClient::spawn( + Connection::test_client({ + let mut transport = new_test_transport(); + + // Report back that we're readable to trigger try_read + transport.f_ready = Box::new(|_| Ok(Ready::READABLE)); + + // Report that no bytes were written, indicting the channel was closed + transport.f_try_read = Box::new(|_| Ok(0)); + + // Send a signal that the reconnect 
happened while marking it successful + transport.f_reconnect = Box::new(move || { + reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + interval: Duration::from_millis(50), + max_retries: None, + timeout: None, + }, + ); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_errors_while_reading_data() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + TestClient::spawn( + Connection::test_client({ + let mut transport = new_test_transport(); + + // Report back that we're readable to trigger try_read + transport.f_ready = Box::new(|_| Ok(Ready::READABLE)); + + // Fail the read + transport.f_try_read = Box::new(|_| Err(io::ErrorKind::Other.into())); + + // Send a signal that the reconnect happened while marking it successful + transport.f_reconnect = Box::new(move || { + reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + interval: Duration::from_millis(50), + max_retries: None, + timeout: None, + }, + ); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_unable_to_send_new_request() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + let mut client = TestClient::spawn( + Connection::test_client({ + let mut transport = new_test_transport(); + + // Report back that we're readable to trigger try_read + transport.f_ready = Box::new(|_| Ok(Ready::WRITABLE)); + + // Fail the write + transport.f_try_write = Box::new(|_| Err(io::ErrorKind::Other.into())); + + // Send a signal that the reconnect happened while marking it successful + transport.f_reconnect = Box::new(move || { + reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + 
interval: Duration::from_millis(50), + max_retries: None, + timeout: None, + }, + ); + + // Queue up a request to fail to send + client + .fire(Request::new(123u8).to_untyped_request().unwrap()) + .await + .expect("Failed to queue request"); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_attempt_to_reconnect_if_connection_unable_to_flush_an_existing_request() { + let (reconnect_tx, mut reconnect_rx) = mpsc::channel(1); + let mut client = TestClient::spawn( + Connection::test_client({ + let mut transport = new_test_transport(); + + // Report back that we're readable to trigger try_read + transport.f_ready = Box::new(|_| Ok(Ready::WRITABLE)); + + // Succeed partially with initial try_write, block on second call, and then + // fail during a try_flush + transport.f_try_write = Box::new(|buf| unsafe { + static mut CNT: u8 = 0; + CNT += 1; + if CNT == 1 { + Ok(buf.len() / 2) + } else if CNT == 2 { + Err(io::ErrorKind::WouldBlock.into()) + } else { + Err(io::ErrorKind::Other.into()) + } + }); + + // Send a signal that the reconnect happened while marking it successful + transport.f_reconnect = Box::new(move || { + reconnect_tx.try_send(()).expect("reconnect tx blocked"); + Ok(()) + }); + + transport + }), + ReconnectStrategy::FixedInterval { + interval: Duration::from_millis(50), + max_retries: None, + timeout: None, + }, + ); + + // Queue up a request to fail to send + client + .fire(Request::new(123u8).to_untyped_request().unwrap()) + .await + .expect("Failed to queue request"); + + reconnect_rx.recv().await.expect("Reconnect did not occur"); + } + + #[test(tokio::test)] + async fn should_exit_if_reconnect_strategy_has_failed_to_connect() { + let (client, server) = Connection::pair(100); + + // Spawn the client, verify the task is running, kill our server, and verify that the + // client does not block trying to reconnect + let client = TestClient::spawn(client, ReconnectStrategy::Fail); + 
assert!(!client.is_finished(), "Client unexpectedly died"); + drop(server); + assert_eq!( + client.wait().await.unwrap_err().kind(), + io::ErrorKind::ConnectionAborted + ); + } + + #[test(tokio::test)] + async fn should_exit_if_shutdown_signal_detected() { + let (client, _server) = Connection::pair(100); + + let client = TestClient::spawn(client, ReconnectStrategy::Fail); + client.shutdown().await.unwrap(); + + // NOTE: We wait for the client's task to conclude by using `wait` to ensure we do not + // have a race condition testing the task finished state. This will also verify + // that the task exited cleanly, rather than panicking. + client.wait().await.unwrap(); + } + + #[test(tokio::test)] + async fn should_not_exit_if_shutdown_channel_is_closed() { + let (client, mut server) = Connection::pair(100); + + // NOTE: Spawn a separate task to handle the response so we do not deadlock + tokio::spawn(async move { + let request = server + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + server + .write_frame_for(&Response::new(request.id, 2u8)) + .await + .unwrap(); + }); + + // NOTE: We consume the client to produce a channel without maintaining the shutdown + // channel in order to ensure that dropping the client does not kill the task. 
+ let mut channel = TestClient::spawn(client, ReconnectStrategy::Fail).into_channel(); + assert_eq!( + channel + .send(Request::new(1u8).to_untyped_request().unwrap()) + .await + .unwrap() + .to_typed_response::() + .unwrap() + .payload, + 2 + ); + } + } +} diff --git a/distant-net/src/client/builder.rs b/distant-net/src/client/builder.rs new file mode 100644 index 0000000..13a15f6 --- /dev/null +++ b/distant-net/src/client/builder.rs @@ -0,0 +1,142 @@ +mod tcp; +pub use tcp::*; + +#[cfg(unix)] +mod unix; + +#[cfg(unix)] +pub use unix::*; + +#[cfg(windows)] +mod windows; + +#[cfg(windows)] +pub use windows::*; + +use crate::client::{Client, ReconnectStrategy, UntypedClient}; +use crate::common::{authentication::AuthHandler, Connection, Transport}; +use async_trait::async_trait; +use std::{convert, io, time::Duration}; + +/// Interface that performs the connection to produce a [`Transport`] for use by the [`Client`]. +#[async_trait] +pub trait Connector { + /// Type of transport produced by the connection. + type Transport: Transport + 'static; + + async fn connect(self) -> io::Result; +} + +#[async_trait] +impl Connector for T { + type Transport = T; + + async fn connect(self) -> io::Result { + Ok(self) + } +} + +/// Builder for a [`Client`] or [`UntypedClient`]. 
+pub struct ClientBuilder { + auth_handler: H, + connector: C, + reconnect_strategy: ReconnectStrategy, + timeout: Option, +} + +impl ClientBuilder { + pub fn auth_handler(self, auth_handler: U) -> ClientBuilder { + ClientBuilder { + auth_handler, + connector: self.connector, + reconnect_strategy: self.reconnect_strategy, + timeout: self.timeout, + } + } + + pub fn connector(self, connector: U) -> ClientBuilder { + ClientBuilder { + auth_handler: self.auth_handler, + connector, + reconnect_strategy: self.reconnect_strategy, + timeout: self.timeout, + } + } + + pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> ClientBuilder { + ClientBuilder { + auth_handler: self.auth_handler, + connector: self.connector, + reconnect_strategy, + timeout: self.timeout, + } + } + + pub fn timeout(self, timeout: impl Into>) -> Self { + Self { + auth_handler: self.auth_handler, + connector: self.connector, + reconnect_strategy: self.reconnect_strategy, + timeout: timeout.into(), + } + } +} + +impl ClientBuilder<(), ()> { + pub fn new() -> Self { + Self { + auth_handler: (), + reconnect_strategy: ReconnectStrategy::default(), + connector: (), + timeout: None, + } + } +} + +impl Default for ClientBuilder<(), ()> { + fn default() -> Self { + Self::new() + } +} + +impl ClientBuilder +where + H: AuthHandler + Send, + C: Connector, +{ + /// Establishes a connection with a remote server using the configured [`Transport`] + /// and other settings, returning a new [`UntypedClient`] instance once the connection + /// is fully established and authenticated. 
+ pub async fn connect_untyped(self) -> io::Result { + let auth_handler = self.auth_handler; + let retry_strategy = self.reconnect_strategy; + let timeout = self.timeout; + + let f = async move { + let transport = match timeout { + Some(duration) => tokio::time::timeout(duration, self.connector.connect()) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity)?, + None => self.connector.connect().await?, + }; + let connection = Connection::client(transport, auth_handler).await?; + Ok(UntypedClient::spawn(connection, retry_strategy)) + }; + + match timeout { + Some(duration) => tokio::time::timeout(duration, f) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity), + None => f.await, + } + } + + /// Establishes a connection with a remote server using the configured [`Transport`] and other + /// settings, returning a new [`Client`] instance once the connection is fully established and + /// authenticated. + pub async fn connect(self) -> io::Result> { + Ok(self.connect_untyped().await?.into_typed_client()) + } +} diff --git a/distant-net/src/client/builder/tcp.rs b/distant-net/src/client/builder/tcp.rs new file mode 100644 index 0000000..ae7f345 --- /dev/null +++ b/distant-net/src/client/builder/tcp.rs @@ -0,0 +1,31 @@ +use super::Connector; +use crate::common::TcpTransport; +use async_trait::async_trait; +use std::io; +use tokio::net::ToSocketAddrs; + +/// Implementation of [`Connector`] to support connecting via TCP. 
+pub struct TcpConnector { + addr: T, +} + +impl TcpConnector { + pub fn new(addr: T) -> Self { + Self { addr } + } +} + +impl From for TcpConnector { + fn from(addr: T) -> Self { + Self::new(addr) + } +} + +#[async_trait] +impl Connector for TcpConnector { + type Transport = TcpTransport; + + async fn connect(self) -> io::Result { + TcpTransport::connect(self.addr).await + } +} diff --git a/distant-net/src/client/builder/unix.rs b/distant-net/src/client/builder/unix.rs new file mode 100644 index 0000000..b934e28 --- /dev/null +++ b/distant-net/src/client/builder/unix.rs @@ -0,0 +1,30 @@ +use super::Connector; +use crate::common::UnixSocketTransport; +use async_trait::async_trait; +use std::{io, path::PathBuf}; + +/// Implementation of [`Connector`] to support connecting via a Unix socket. +pub struct UnixSocketConnector { + path: PathBuf, +} + +impl UnixSocketConnector { + pub fn new(path: impl Into) -> Self { + Self { path: path.into() } + } +} + +impl> From for UnixSocketConnector { + fn from(path: T) -> Self { + Self::new(path) + } +} + +#[async_trait] +impl Connector for UnixSocketConnector { + type Transport = UnixSocketTransport; + + async fn connect(self) -> io::Result { + UnixSocketTransport::connect(self.path).await + } +} diff --git a/distant-net/src/client/builder/windows.rs b/distant-net/src/client/builder/windows.rs new file mode 100644 index 0000000..4b0b0ba --- /dev/null +++ b/distant-net/src/client/builder/windows.rs @@ -0,0 +1,50 @@ +use super::Connector; +use crate::common::WindowsPipeTransport; +use async_trait::async_trait; +use std::ffi::OsString; +use std::io; + +/// Implementation of [`Connector`] to support connecting via a Windows named pipe. +pub struct WindowsPipeConnector { + addr: OsString, + pub(crate) local: bool, +} + +impl WindowsPipeConnector { + /// Creates a new connector for a non-local pipe using the given `addr`. 
+ pub fn new(addr: impl Into) -> Self { + Self { + addr: addr.into(), + local: false, + } + } + + /// Creates a new connector for a local pipe using the given `name`. + pub fn local(name: impl Into) -> Self { + Self { + addr: name.into(), + local: true, + } + } +} + +impl> From for WindowsPipeConnector { + fn from(addr: T) -> Self { + Self::new(addr) + } +} + +#[async_trait] +impl Connector for WindowsPipeConnector { + type Transport = WindowsPipeTransport; + + async fn connect(self) -> io::Result { + if self.local { + let mut full_addr = OsString::from(r"\\.\pipe\"); + full_addr.push(self.addr); + WindowsPipeTransport::connect(full_addr).await + } else { + WindowsPipeTransport::connect(self.addr).await + } + } +} diff --git a/distant-net/src/client/channel.rs b/distant-net/src/client/channel.rs index 02827e2..fb86641 100644 --- a/distant-net/src/client/channel.rs +++ b/distant-net/src/client/channel.rs @@ -1,5 +1,7 @@ -use crate::{Request, Response}; -use std::{convert, io, sync::Weak}; +use crate::common::{Request, Response, UntypedRequest, UntypedResponse}; +use log::*; +use serde::{de::DeserializeOwned, Serialize}; +use std::{convert, fmt, io, marker::PhantomData, sync::Weak}; use tokio::{sync::mpsc, time::Duration}; mod mailbox; @@ -9,26 +11,181 @@ pub use mailbox::*; const CHANNEL_MAILBOX_CAPACITY: usize = 10000; /// Represents a sender of requests tied to a session, holding onto a weak reference of -/// mailboxes to relay responses, meaning that once the [`Session`] is closed or dropped, -/// any sent request will no longer be able to receive responses -pub struct Channel +/// mailboxes to relay responses, meaning that once the [`Client`] is closed or dropped, +/// any sent request will no longer be able to receive responses. 
+/// +/// [`Client`]: crate::client::Client +pub struct Channel { + inner: UntypedChannel, + _request: PhantomData, + _response: PhantomData, +} + +// NOTE: Implemented manually to avoid needing clone to be defined on generic types +impl Clone for Channel { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + _request: self._request, + _response: self._response, + } + } +} + +impl fmt::Debug for Channel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Channel") + .field("tx", &self.inner.tx) + .field("post_office", &self.inner.post_office) + .field("_request", &self._request) + .field("_response", &self._response) + .finish() + } +} + +impl Channel where - T: Send + Sync, - U: Send + Sync, + T: Send + Sync + Serialize + 'static, + U: Send + Sync + DeserializeOwned + 'static, { + /// Returns true if no more requests can be transferred + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } + + /// Consumes this channel, returning an untyped variant + pub fn into_untyped_channel(self) -> UntypedChannel { + self.inner + } + + /// Assigns a default mailbox for any response received that does not match another mailbox. + pub async fn assign_default_mailbox(&self, buffer: usize) -> io::Result>> { + Ok(map_to_typed_mailbox( + self.inner.assign_default_mailbox(buffer).await?, + )) + } + + /// Removes the default mailbox used for unmatched responses such that any response without a + /// matching mailbox will be dropped. 
+ pub async fn remove_default_mailbox(&self) -> io::Result<()> { + self.inner.remove_default_mailbox().await + } + + /// Sends a request and returns a mailbox that can receive one or more responses, failing if + /// unable to send a request or if the session's receiving line to the remote server has + /// already been severed + pub async fn mail(&mut self, req: impl Into>) -> io::Result>> { + Ok(map_to_typed_mailbox( + self.inner.mail(req.into().to_untyped_request()?).await?, + )) + } + + /// Sends a request and returns a mailbox, timing out after duration has passed + pub async fn mail_timeout( + &mut self, + req: impl Into>, + duration: impl Into>, + ) -> io::Result>> { + match duration.into() { + Some(duration) => tokio::time::timeout(duration, self.mail(req)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity), + None => self.mail(req).await, + } + } + + /// Sends a request and waits for a response, failing if unable to send a request or if + /// the session's receiving line to the remote server has already been severed + pub async fn send(&mut self, req: impl Into>) -> io::Result> { + // Send mail and get back a mailbox + let mut mailbox = self.mail(req).await?; + + // Wait for first response, and then drop the mailbox + mailbox + .next() + .await + .ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionAborted)) + } + + /// Sends a request and waits for a response, timing out after duration has passed + pub async fn send_timeout( + &mut self, + req: impl Into>, + duration: impl Into>, + ) -> io::Result> { + match duration.into() { + Some(duration) => tokio::time::timeout(duration, self.send(req)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity), + None => self.send(req).await, + } + } + + /// Sends a request without waiting for a response; this method is able to be used even + /// if the session's receiving line to the remote server has been severed + pub async 
fn fire(&mut self, req: impl Into>) -> io::Result<()> { + self.inner.fire(req.into().to_untyped_request()?).await + } + + /// Sends a request without waiting for a response, timing out after duration has passed + pub async fn fire_timeout( + &mut self, + req: impl Into>, + duration: impl Into>, + ) -> io::Result<()> { + match duration.into() { + Some(duration) => tokio::time::timeout(duration, self.fire(req)) + .await + .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) + .and_then(convert::identity), + None => self.fire(req).await, + } + } +} + +fn map_to_typed_mailbox( + mailbox: Mailbox>, +) -> Mailbox> { + mailbox.map_opt(|res| match res.to_typed_response() { + Ok(res) => Some(res), + Err(x) => { + if log::log_enabled!(Level::Trace) { + trace!( + "Invalid response payload: {}", + String::from_utf8_lossy(&res.payload) + ); + } + + error!( + "Unable to parse response payload into {}: {x}", + std::any::type_name::() + ); + None + } + }) +} + +/// Represents a sender of requests tied to a session, holding onto a weak reference of +/// mailboxes to relay responses, meaning that once the [`Client`] is closed or dropped, +/// any sent request will no longer be able to receive responses. +/// +/// In contrast to [`Channel`], this implementation is untyped, meaning that the payload of +/// requests and responses are not validated. 
+/// +/// [`Client`]: crate::client::Client +#[derive(Debug)] +pub struct UntypedChannel { /// Used to send requests to a server - pub(crate) tx: mpsc::Sender>, + pub(crate) tx: mpsc::Sender>, /// Collection of mailboxes for receiving responses to requests - pub(crate) post_office: Weak>>, + pub(crate) post_office: Weak>>, } // NOTE: Implemented manually to avoid needing clone to be defined on generic types -impl Clone for Channel -where - T: Send + Sync, - U: Send + Sync, -{ +impl Clone for UntypedChannel { fn clone(&self) -> Self { Self { tx: self.tx.clone(), @@ -37,31 +194,66 @@ where } } -impl Channel -where - T: Send + Sync, - U: Send + Sync + 'static, -{ +impl UntypedChannel { /// Returns true if no more requests can be transferred pub fn is_closed(&self) -> bool { self.tx.is_closed() } + /// Consumes this channel, returning a typed variant + pub fn into_typed_channel(self) -> Channel { + Channel { + inner: self, + _request: PhantomData, + _response: PhantomData, + } + } + + /// Assigns a default mailbox for any response received that does not match another mailbox. + pub async fn assign_default_mailbox( + &self, + buffer: usize, + ) -> io::Result>> { + match Weak::upgrade(&self.post_office) { + Some(post_office) => Ok(post_office.assign_default_mailbox(buffer).await), + None => Err(io::Error::new( + io::ErrorKind::NotConnected, + "Channel's post office is no longer available", + )), + } + } + + /// Removes the default mailbox used for unmatched responses such that any response without a + /// matching mailbox will be dropped. 
+ pub async fn remove_default_mailbox(&self) -> io::Result<()> { + match Weak::upgrade(&self.post_office) { + Some(post_office) => { + post_office.remove_default_mailbox().await; + Ok(()) + } + None => Err(io::Error::new( + io::ErrorKind::NotConnected, + "Channel's post office is no longer available", + )), + } + } + /// Sends a request and returns a mailbox that can receive one or more responses, failing if /// unable to send a request or if the session's receiving line to the remote server has /// already been severed - pub async fn mail(&mut self, req: impl Into>) -> io::Result>> { - let req = req.into(); - + pub async fn mail( + &mut self, + req: UntypedRequest<'_>, + ) -> io::Result>> { // First, create a mailbox using the request's id let mailbox = Weak::upgrade(&self.post_office) .ok_or_else(|| { io::Error::new( io::ErrorKind::NotConnected, - "Session's post office is no longer available", + "Channel's post office is no longer available", ) })? - .make_mailbox(req.id.clone(), CHANNEL_MAILBOX_CAPACITY) + .make_mailbox(req.id.clone().into_owned(), CHANNEL_MAILBOX_CAPACITY) .await; // Second, send the request @@ -74,9 +266,9 @@ where /// Sends a request and returns a mailbox, timing out after duration has passed pub async fn mail_timeout( &mut self, - req: impl Into>, + req: UntypedRequest<'_>, duration: impl Into>, - ) -> io::Result>> { + ) -> io::Result>> { match duration.into() { Some(duration) => tokio::time::timeout(duration, self.mail(req)) .await @@ -88,7 +280,7 @@ where /// Sends a request and waits for a response, failing if unable to send a request or if /// the session's receiving line to the remote server has already been severed - pub async fn send(&mut self, req: impl Into>) -> io::Result> { + pub async fn send(&mut self, req: UntypedRequest<'_>) -> io::Result> { // Send mail and get back a mailbox let mut mailbox = self.mail(req).await?; @@ -102,9 +294,9 @@ where /// Sends a request and waits for a response, timing out after duration has passed 
pub async fn send_timeout( &mut self, - req: impl Into>, + req: UntypedRequest<'_>, duration: impl Into>, - ) -> io::Result> { + ) -> io::Result> { match duration.into() { Some(duration) => tokio::time::timeout(duration, self.send(req)) .await @@ -116,9 +308,9 @@ where /// Sends a request without waiting for a response; this method is able to be used even /// if the session's receiving line to the remote server has been severed - pub async fn fire(&mut self, req: impl Into>) -> io::Result<()> { + pub async fn fire(&mut self, req: UntypedRequest<'_>) -> io::Result<()> { self.tx - .send(req.into()) + .send(req.into_owned()) .await .map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x.to_string())) } @@ -126,7 +318,7 @@ where /// Sends a request without waiting for a response, timing out after duration has passed pub async fn fire_timeout( &mut self, - req: impl Into>, + req: UntypedRequest<'_>, duration: impl Into>, ) -> io::Result<()> { match duration.into() { @@ -142,95 +334,227 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{Client, FramedTransport, TypedAsyncRead, TypedAsyncWrite}; - use std::time::Duration; - type TestClient = Client; + mod typed { + use super::*; + use std::sync::Arc; + use std::time::Duration; + use test_log::test; + + type TestChannel = Channel; + type Setup = ( + TestChannel, + mpsc::Receiver>, + Arc>>, + ); + + fn setup(buffer: usize) -> Setup { + let post_office = Arc::new(PostOffice::default()); + let (tx, rx) = mpsc::channel(buffer); + let channel = { + let post_office = Arc::downgrade(&post_office); + UntypedChannel { tx, post_office } + }; + + (channel.into_typed_channel(), rx, post_office) + } + + #[test(tokio::test)] + async fn mail_should_return_mailbox_that_receives_responses_until_post_office_drops_it() { + let (mut channel, _server, post_office) = setup(100); + + let req = Request::new(0); + let res = Response::new(req.id.clone(), 1); + + let mut mailbox = channel.mail(req).await.unwrap(); + + // Send and receive 
first response + assert!( + post_office + .deliver_untyped_response(res.to_untyped_response().unwrap().into_owned()) + .await, + "Failed to deliver: {res:?}" + ); + assert_eq!(mailbox.next().await, Some(res.clone())); + + // Send and receive second response + assert!( + post_office + .deliver_untyped_response(res.to_untyped_response().unwrap().into_owned()) + .await, + "Failed to deliver: {res:?}" + ); + assert_eq!(mailbox.next().await, Some(res.clone())); + + // Trigger the mailbox to wait BEFORE closing our mailbox to ensure that + // we don't get stuck if the mailbox was already waiting + let next_task = tokio::spawn(async move { mailbox.next().await }); + tokio::task::yield_now().await; + + // Close our specific mailbox + post_office.cancel(&res.origin_id).await; + + match next_task.await { + Ok(None) => {} + x => panic!("Unexpected response: {:?}", x), + } + } - #[tokio::test] - async fn mail_should_return_mailbox_that_receives_responses_until_transport_closes() { - let (t1, mut t2) = FramedTransport::make_test_pair(); - let session: TestClient = Client::from_framed_transport(t1).unwrap(); - let mut channel = session.clone_channel(); + #[test(tokio::test)] + async fn send_should_wait_until_response_received() { + let (mut channel, _server, post_office) = setup(100); + + let req = Request::new(0); + let res = Response::new(req.id.clone(), 1); + + let (actual, _) = tokio::join!( + channel.send(req), + post_office + .deliver_untyped_response(res.to_untyped_response().unwrap().into_owned()) + ); + match actual { + Ok(actual) => assert_eq!(actual, res), + x => panic!("Unexpected response: {:?}", x), + } + } - let req = Request::new(0); - let res = Response::new(req.id.clone(), 1); + #[test(tokio::test)] + async fn send_timeout_should_fail_if_response_not_received_in_time() { + let (mut channel, mut server, _post_office) = setup(100); - let mut mailbox = channel.mail(req).await.unwrap(); + let req = Request::new(0); + match channel.send_timeout(req, 
Duration::from_millis(30)).await { + Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut), + x => panic!("Unexpected response: {:?}", x), + } - // Get first response - match tokio::join!(mailbox.next(), t2.write(res.clone())) { - (Some(actual), _) => assert_eq!(actual, res), - x => panic!("Unexpected response: {:?}", x), + let _frame = server.recv().await.unwrap(); } - // Get second response - match tokio::join!(mailbox.next(), t2.write(res.clone())) { - (Some(actual), _) => assert_eq!(actual, res), - x => panic!("Unexpected response: {:?}", x), - } + #[test(tokio::test)] + async fn fire_should_send_request_and_not_wait_for_response() { + let (mut channel, mut server, _post_office) = setup(100); - // Trigger the mailbox to wait BEFORE closing our transport to ensure that - // we don't get stuck if the mailbox was already waiting - let next_task = tokio::spawn(async move { mailbox.next().await }); - tokio::task::yield_now().await; + let req = Request::new(0); + match channel.fire(req).await { + Ok(_) => {} + x => panic!("Unexpected response: {:?}", x), + } - drop(t2); - match next_task.await { - Ok(None) => {} - x => panic!("Unexpected response: {:?}", x), + let _frame = server.recv().await.unwrap(); } } - #[tokio::test] - async fn send_should_wait_until_response_received() { - let (t1, mut t2) = FramedTransport::make_test_pair(); - let session: TestClient = Client::from_framed_transport(t1).unwrap(); - let mut channel = session.clone_channel(); + mod untyped { + use super::*; + use std::sync::Arc; + use std::time::Duration; + use test_log::test; + + type TestChannel = UntypedChannel; + type Setup = ( + TestChannel, + mpsc::Receiver>, + Arc>>, + ); + + fn setup(buffer: usize) -> Setup { + let post_office = Arc::new(PostOffice::default()); + let (tx, rx) = mpsc::channel(buffer); + let channel = { + let post_office = Arc::downgrade(&post_office); + TestChannel { tx, post_office } + }; + + (channel, rx, post_office) + } - let req = Request::new(0); - let res = 
Response::new(req.id.clone(), 1); + #[test(tokio::test)] + async fn mail_should_return_mailbox_that_receives_responses_until_post_office_drops_it() { + let (mut channel, _server, post_office) = setup(100); + + let req = Request::new(0).to_untyped_request().unwrap().into_owned(); + let res = Response::new(req.id.clone().into_owned(), 1) + .to_untyped_response() + .unwrap() + .into_owned(); + + let mut mailbox = channel.mail(req).await.unwrap(); + + // Send and receive first response + assert!( + post_office.deliver_untyped_response(res.clone()).await, + "Failed to deliver: {res:?}" + ); + assert_eq!(mailbox.next().await, Some(res.clone())); + + // Send and receive second response + assert!( + post_office.deliver_untyped_response(res.clone()).await, + "Failed to deliver: {res:?}" + ); + assert_eq!(mailbox.next().await, Some(res.clone())); + + // Trigger the mailbox to wait BEFORE closing our mailbox to ensure that + // we don't get stuck if the mailbox was already waiting + let next_task = tokio::spawn(async move { mailbox.next().await }); + tokio::task::yield_now().await; + + // Close our specific mailbox + post_office + .cancel(&res.origin_id.clone().into_owned()) + .await; + + match next_task.await { + Ok(None) => {} + x => panic!("Unexpected response: {:?}", x), + } + } - let (actual, _) = tokio::join!(channel.send(req), t2.write(res.clone())); - match actual { - Ok(actual) => assert_eq!(actual, res), - x => panic!("Unexpected response: {:?}", x), + #[test(tokio::test)] + async fn send_should_wait_until_response_received() { + let (mut channel, _server, post_office) = setup(100); + + let req = Request::new(0).to_untyped_request().unwrap().into_owned(); + let res = Response::new(req.id.clone().into_owned(), 1) + .to_untyped_response() + .unwrap() + .into_owned(); + + let (actual, _) = tokio::join!( + channel.send(req), + post_office.deliver_untyped_response(res.clone()) + ); + match actual { + Ok(actual) => assert_eq!(actual, res), + x => panic!("Unexpected 
response: {:?}", x), + } } - } - #[tokio::test] - async fn send_timeout_should_fail_if_response_not_received_in_time() { - let (t1, mut t2) = FramedTransport::make_test_pair(); - let session: TestClient = Client::from_framed_transport(t1).unwrap(); - let mut channel = session.clone_channel(); + #[test(tokio::test)] + async fn send_timeout_should_fail_if_response_not_received_in_time() { + let (mut channel, mut server, _post_office) = setup(100); + + let req = Request::new(0).to_untyped_request().unwrap().into_owned(); + match channel.send_timeout(req, Duration::from_millis(30)).await { + Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut), + x => panic!("Unexpected response: {:?}", x), + } - let req = Request::new(0); - match channel.send_timeout(req, Duration::from_millis(30)).await { - Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut), - x => panic!("Unexpected response: {:?}", x), + let _frame = server.recv().await.unwrap(); } - let _req = TypedAsyncRead::>::read(&mut t2) - .await - .unwrap() - .unwrap(); - } + #[test(tokio::test)] + async fn fire_should_send_request_and_not_wait_for_response() { + let (mut channel, mut server, _post_office) = setup(100); - #[tokio::test] - async fn fire_should_send_request_and_not_wait_for_response() { - let (t1, mut t2) = FramedTransport::make_test_pair(); - let session: TestClient = Client::from_framed_transport(t1).unwrap(); - let mut channel = session.clone_channel(); + let req = Request::new(0).to_untyped_request().unwrap().into_owned(); + match channel.fire(req).await { + Ok(_) => {} + x => panic!("Unexpected response: {:?}", x), + } - let req = Request::new(0); - match channel.fire(req).await { - Ok(_) => {} - x => panic!("Unexpected response: {:?}", x), + let _frame = server.recv().await.unwrap(); } - - let _req = TypedAsyncRead::>::read(&mut t2) - .await - .unwrap() - .unwrap(); } } diff --git a/distant-net/src/client/channel/mailbox.rs b/distant-net/src/client/channel/mailbox.rs index ca3a7c4..872bc44 100644 
--- a/distant-net/src/client/channel/mailbox.rs +++ b/distant-net/src/client/channel/mailbox.rs @@ -1,4 +1,5 @@ -use crate::{Id, Response}; +use crate::common::{Id, Response, UntypedResponse}; +use async_trait::async_trait; use std::{ collections::HashMap, sync::{Arc, Weak}, @@ -6,13 +7,14 @@ use std::{ }; use tokio::{ io, - sync::{mpsc, Mutex}, + sync::{mpsc, Mutex, RwLock}, time, }; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct PostOffice { mailboxes: Arc>>>, + default_box: Arc>>>, } impl Default for PostOffice @@ -51,7 +53,10 @@ where } }); - Self { mailboxes } + Self { + mailboxes, + default_box: Arc::new(RwLock::new(None)), + } } /// Creates a new mailbox using the given id and buffer size for maximum values that @@ -60,7 +65,10 @@ where let (tx, rx) = mpsc::channel(buffer); self.mailboxes.lock().await.insert(id.clone(), tx); - Mailbox { id, rx } + Mailbox { + id, + rx: Box::new(rx), + } } /// Delivers some value to appropriate mailbox, returning false if no mailbox is found @@ -75,10 +83,54 @@ where } success + } else if let Some(tx) = self.default_box.read().await.as_ref() { + tx.send(value).await.is_ok() } else { false } } + + /// Creates a new default mailbox that will be used whenever no mailbox is found to deliver + /// mail. This will replace any existing default mailbox. + pub async fn assign_default_mailbox(&self, buffer: usize) -> Mailbox { + let (tx, rx) = mpsc::channel(buffer); + *self.default_box.write().await = Some(tx); + + Mailbox { + id: "".to_string(), + rx: Box::new(rx), + } + } + + /// Removes the default mailbox such that any mail without a matching mailbox will be dropped + /// instead of being delivered to a default mailbox. + pub async fn remove_default_mailbox(&self) { + *self.default_box.write().await = None; + } + + /// Returns true if the post office is using a default mailbox for all mail that does not map + /// to another mailbox. 
+ pub async fn has_default_mailbox(&self) -> bool { + self.default_box.read().await.is_some() + } + + /// Cancels delivery to the mailbox with the specified `id`. + pub async fn cancel(&self, id: &Id) { + self.mailboxes.lock().await.remove(id); + } + + /// Cancels delivery to the mailboxes with the specified `id`s. + pub async fn cancel_many(&self, ids: impl Iterator) { + let mut lock = self.mailboxes.lock().await; + for id in ids { + lock.remove(id); + } + } + + /// Cancels delivery to all mailboxes. + pub async fn cancel_all(&self) { + self.mailboxes.lock().await.clear(); + } } impl PostOffice> @@ -92,13 +144,120 @@ where } } +impl PostOffice> { + /// Delivers some response to appropriate mailbox, returning false if no mailbox is found + /// for the response's origin or if the mailbox is no longer receiving values + pub async fn deliver_untyped_response(&self, res: UntypedResponse<'static>) -> bool { + self.deliver(&res.origin_id.clone().into_owned(), res).await + } +} + +/// Error encountered when invoking [`try_recv`] for [`MailboxReceiver`]. 
+pub enum MailboxTryNextError { + Empty, + Closed, +} + +#[async_trait] +trait MailboxReceiver: Send + Sync { + type Output; + + fn try_recv(&mut self) -> Result; + + async fn recv(&mut self) -> Option; + + fn close(&mut self); +} + +#[async_trait] +impl MailboxReceiver for mpsc::Receiver { + type Output = T; + + fn try_recv(&mut self) -> Result { + match mpsc::Receiver::try_recv(self) { + Ok(x) => Ok(x), + Err(mpsc::error::TryRecvError::Empty) => Err(MailboxTryNextError::Empty), + Err(mpsc::error::TryRecvError::Disconnected) => Err(MailboxTryNextError::Closed), + } + } + + async fn recv(&mut self) -> Option { + mpsc::Receiver::recv(self).await + } + + fn close(&mut self) { + mpsc::Receiver::close(self) + } +} + +struct MappedMailboxReceiver { + rx: Box>, + f: Box U + Send + Sync>, +} + +#[async_trait] +impl MailboxReceiver for MappedMailboxReceiver { + type Output = U; + + fn try_recv(&mut self) -> Result { + match self.rx.try_recv() { + Ok(x) => Ok((self.f)(x)), + Err(x) => Err(x), + } + } + + async fn recv(&mut self) -> Option { + let value = self.rx.recv().await?; + Some((self.f)(value)) + } + + fn close(&mut self) { + self.rx.close() + } +} + +struct MappedOptMailboxReceiver { + rx: Box>, + f: Box Option + Send + Sync>, +} + +#[async_trait] +impl MailboxReceiver for MappedOptMailboxReceiver { + type Output = U; + + fn try_recv(&mut self) -> Result { + match self.rx.try_recv() { + Ok(x) => match (self.f)(x) { + Some(x) => Ok(x), + None => Err(MailboxTryNextError::Empty), + }, + Err(x) => Err(x), + } + } + + async fn recv(&mut self) -> Option { + // Continually receive a new value and convert it to Option + // until Option == Some(U) or we receive None from our inner receiver + loop { + let value = self.rx.recv().await?; + if let Some(x) = (self.f)(value) { + return Some(x); + } + } + } + + fn close(&mut self) { + self.rx.close() + } +} + /// Represents a destination for responses pub struct Mailbox { /// Represents id associated with the mailbox id: Id, /// 
Underlying mailbox storage - rx: mpsc::Receiver, + rx: Box>, } impl Mailbox { @@ -107,6 +266,11 @@ impl Mailbox { &self.id } + /// Tries to receive the next value in mailbox without blocking or waiting async + pub fn try_next(&mut self) -> Result { + self.rx.try_recv() + } + /// Receives next value in mailbox pub async fn next(&mut self) -> Option { self.rx.recv().await @@ -126,3 +290,31 @@ impl Mailbox { self.rx.close() } } + +impl Mailbox { + /// Maps the results of each mailbox value into a new type `U` + pub fn map(self, f: impl Fn(T) -> U + Send + Sync + 'static) -> Mailbox { + Mailbox { + id: self.id, + rx: Box::new(MappedMailboxReceiver { + rx: self.rx, + f: Box::new(f), + }), + } + } + + /// Maps the results of each mailbox value into a new type `U` by returning an `Option` + /// where the option is `None` in the case that `T` cannot be converted into `U` + pub fn map_opt( + self, + f: impl Fn(T) -> Option + Send + Sync + 'static, + ) -> Mailbox { + Mailbox { + id: self.id, + rx: Box::new(MappedOptMailboxReceiver { + rx: self.rx, + f: Box::new(f), + }), + } + } +} diff --git a/distant-net/src/client/ext/tcp.rs b/distant-net/src/client/ext/tcp.rs deleted file mode 100644 index e58a345..0000000 --- a/distant-net/src/client/ext/tcp.rs +++ /dev/null @@ -1,49 +0,0 @@ -use crate::{Client, Codec, FramedTransport, TcpTransport}; -use async_trait::async_trait; -use serde::{de::DeserializeOwned, Serialize}; -use std::{convert, net::SocketAddr}; -use tokio::{io, time::Duration}; - -#[async_trait] -pub trait TcpClientExt -where - T: Serialize + Send + Sync, - U: DeserializeOwned + Send + Sync, -{ - /// Connect to a remote TCP server using the provided information - async fn connect(addr: SocketAddr, codec: C) -> io::Result> - where - C: Codec + Send + 'static; - - /// Connect to a remote TCP server, timing out after duration has passed - async fn connect_timeout( - addr: SocketAddr, - codec: C, - duration: Duration, - ) -> io::Result> - where - C: Codec + Send + 
'static, - { - tokio::time::timeout(duration, Self::connect(addr, codec)) - .await - .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) - .and_then(convert::identity) - } -} - -#[async_trait] -impl TcpClientExt for Client -where - T: Send + Sync + Serialize + 'static, - U: Send + Sync + DeserializeOwned + 'static, -{ - /// Connect to a remote TCP server using the provided information - async fn connect(addr: SocketAddr, codec: C) -> io::Result> - where - C: Codec + Send + 'static, - { - let transport = TcpTransport::connect(addr).await?; - let transport = FramedTransport::new(transport, codec); - Self::from_framed_transport(transport) - } -} diff --git a/distant-net/src/client/ext/unix.rs b/distant-net/src/client/ext/unix.rs deleted file mode 100644 index 9188f53..0000000 --- a/distant-net/src/client/ext/unix.rs +++ /dev/null @@ -1,54 +0,0 @@ -use crate::{Client, Codec, FramedTransport, IntoSplit, UnixSocketTransport}; -use async_trait::async_trait; -use serde::{de::DeserializeOwned, Serialize}; -use std::{convert, path::Path}; -use tokio::{io, time::Duration}; - -#[async_trait] -pub trait UnixSocketClientExt -where - T: Serialize + Send + Sync, - U: DeserializeOwned + Send + Sync, -{ - /// Connect to a proxy unix socket - async fn connect(path: P, codec: C) -> io::Result> - where - P: AsRef + Send, - C: Codec + Send + 'static; - - /// Connect to a proxy unix socket, timing out after duration has passed - async fn connect_timeout( - path: P, - codec: C, - duration: Duration, - ) -> io::Result> - where - P: AsRef + Send, - C: Codec + Send + 'static, - { - tokio::time::timeout(duration, Self::connect(path, codec)) - .await - .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) - .and_then(convert::identity) - } -} - -#[async_trait] -impl UnixSocketClientExt for Client -where - T: Send + Sync + Serialize + 'static, - U: Send + Sync + DeserializeOwned + 'static, -{ - /// Connect to a proxy unix socket - async fn connect(path: P, codec: C) -> io::Result> - 
where - P: AsRef + Send, - C: Codec + Send + 'static, - { - let p = path.as_ref(); - let transport = UnixSocketTransport::connect(p).await?; - let transport = FramedTransport::new(transport, codec); - let (writer, reader) = transport.into_split(); - Ok(Client::new(writer, reader)?) - } -} diff --git a/distant-net/src/client/ext/windows.rs b/distant-net/src/client/ext/windows.rs deleted file mode 100644 index 5186caa..0000000 --- a/distant-net/src/client/ext/windows.rs +++ /dev/null @@ -1,86 +0,0 @@ -use crate::{Client, Codec, FramedTransport, IntoSplit, WindowsPipeTransport}; -use async_trait::async_trait; -use serde::{de::DeserializeOwned, Serialize}; -use std::{ - convert, - ffi::{OsStr, OsString}, -}; -use tokio::{io, time::Duration}; - -#[async_trait] -pub trait WindowsPipeClientExt -where - T: Serialize + Send + Sync, - U: DeserializeOwned + Send + Sync, -{ - /// Connect to a server listening on a Windows pipe at the specified address - /// using the given codec - async fn connect(addr: A, codec: C) -> io::Result> - where - A: AsRef + Send, - C: Codec + Send + 'static; - - /// Connect to a server listening on a Windows pipe at the specified address - /// via `\\.\pipe\{name}` using the given codec - async fn connect_local(name: N, codec: C) -> io::Result> - where - N: AsRef + Send, - C: Codec + Send + 'static, - { - let mut addr = OsString::from(r"\\.\pipe\"); - addr.push(name.as_ref()); - Self::connect(addr, codec).await - } - - /// Connect to a server listening on a Windows pipe at the specified address - /// using the given codec, timing out after duration has passed - async fn connect_timeout( - addr: A, - codec: C, - duration: Duration, - ) -> io::Result> - where - A: AsRef + Send, - C: Codec + Send + 'static, - { - tokio::time::timeout(duration, Self::connect(addr, codec)) - .await - .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) - .and_then(convert::identity) - } - - /// Connect to a server listening on a Windows pipe at the specified address 
- /// via `\\.\pipe\{name}` using the given codec, timing out after duration has passed - async fn connect_local_timeout( - name: N, - codec: C, - duration: Duration, - ) -> io::Result> - where - N: AsRef + Send, - C: Codec + Send + 'static, - { - let mut addr = OsString::from(r"\\.\pipe\"); - addr.push(name.as_ref()); - Self::connect_timeout(addr, codec, duration).await - } -} - -#[async_trait] -impl WindowsPipeClientExt for Client -where - T: Send + Sync + Serialize + 'static, - U: Send + Sync + DeserializeOwned + 'static, -{ - async fn connect(addr: A, codec: C) -> io::Result> - where - A: AsRef + Send, - C: Codec + Send + 'static, - { - let a = addr.as_ref(); - let transport = WindowsPipeTransport::connect(a).await?; - let transport = FramedTransport::new(transport, codec); - let (writer, reader) = transport.into_split(); - Ok(Client::new(writer, reader)?) - } -} diff --git a/distant-net/src/client/reconnect.rs b/distant-net/src/client/reconnect.rs new file mode 100644 index 0000000..f3d1ffd --- /dev/null +++ b/distant-net/src/client/reconnect.rs @@ -0,0 +1,208 @@ +use super::Reconnectable; +use std::io; +use std::time::Duration; + +/// Represents the strategy to apply when attempting to reconnect the client to the server. +#[derive(Clone, Debug)] +pub enum ReconnectStrategy { + /// A retry strategy that will fail immediately if a reconnect is attempted. + Fail, + + /// A retry strategy driven by exponential back-off. + ExponentialBackoff { + /// Represents the initial time to wait between reconnect attempts. + base: Duration, + + /// Factor to use when modifying the retry time, used as a multiplier. + factor: f64, + + /// Represents the maximum duration to wait between attempts. None indicates no limit. + max_duration: Option, + + /// Represents the maximum attempts to retry before failing. None indicates no limit. + max_retries: Option, + + /// Represents the maximum time to wait for a reconnect attempt. None indicates no limit. 
+ timeout: Option, + }, + + /// A retry strategy driven by the fibonacci series. + FibonacciBackoff { + /// Represents the initial time to wait between reconnect attempts. + base: Duration, + + /// Represents the maximum duration to wait between attempts. None indicates no limit. + max_duration: Option, + + /// Represents the maximum attempts to retry before failing. None indicates no limit. + max_retries: Option, + + /// Represents the maximum time to wait for a reconnect attempt. None indicates no limit. + timeout: Option, + }, + + /// A retry strategy driven by a fixed interval. + FixedInterval { + /// Represents the time between reconnect attempts. + interval: Duration, + + /// Represents the maximum attempts to retry before failing. None indicates no limit. + max_retries: Option, + + /// Represents the maximum time to wait for a reconnect attempt. None indicates no limit. + timeout: Option, + }, +} + +impl Default for ReconnectStrategy { + /// Creates a reconnect strategy that will immediately fail. 
+ fn default() -> Self { + Self::Fail + } +} + +impl ReconnectStrategy { + pub async fn reconnect(&mut self, reconnectable: &mut T) -> io::Result<()> { + // If our strategy is to immediately fail, do so + if self.is_fail() { + return Err(io::Error::from(io::ErrorKind::ConnectionAborted)); + } + + // Keep track of last sleep length for use in adjustment + let mut previous_sleep = None; + let mut current_sleep = self.initial_sleep_duration(); + + // Keep track of remaining retries + let mut retries_remaining = self.max_retries(); + + // Get timeout if strategy will employ one + let timeout = self.timeout(); + + // Get maximum allowed duration between attempts + let max_duration = self.max_duration(); + + // Continue trying to reconnect while we have more tries remaining, otherwise + // we will return the last error encountered + let mut result = Ok(()); + + while retries_remaining.is_none() || retries_remaining > Some(0) { + // Perform reconnect attempt + result = match timeout { + Some(timeout) => { + match tokio::time::timeout(timeout, reconnectable.reconnect()).await { + Ok(x) => x, + Err(x) => Err(x.into()), + } + } + None => reconnectable.reconnect().await, + }; + + // If reconnect was successful, we're done and we can exit + if result.is_ok() { + return Ok(()); + } + + // Decrement remaining retries if we have a limit + if let Some(remaining) = retries_remaining.as_mut() { + if *remaining > 0 { + *remaining -= 1; + } + } + + // Sleep before making next attempt + tokio::time::sleep(current_sleep).await; + + // Update our sleep duration + let next_sleep = self.adjust_sleep(previous_sleep, current_sleep); + previous_sleep = Some(current_sleep); + current_sleep = if let Some(duration) = max_duration { + std::cmp::min(next_sleep, duration) + } else { + next_sleep + }; + } + + result + } + + /// Returns true if this strategy is the fail variant. 
+ pub fn is_fail(&self) -> bool { + matches!(self, Self::Fail) + } + + /// Returns true if this strategy is the exponential backoff variant. + pub fn is_exponential_backoff(&self) -> bool { + matches!(self, Self::ExponentialBackoff { .. }) + } + + /// Returns true if this strategy is the fibonacci backoff variant. + pub fn is_fibonacci_backoff(&self) -> bool { + matches!(self, Self::FibonacciBackoff { .. }) + } + + /// Returns true if this strategy is the fixed interval variant. + pub fn is_fixed_interval(&self) -> bool { + matches!(self, Self::FixedInterval { .. }) + } + + /// Returns the maximum duration between reconnect attempts, or None if there is no limit. + pub fn max_duration(&self) -> Option { + match self { + ReconnectStrategy::Fail => None, + ReconnectStrategy::ExponentialBackoff { max_duration, .. } => *max_duration, + ReconnectStrategy::FibonacciBackoff { max_duration, .. } => *max_duration, + ReconnectStrategy::FixedInterval { .. } => None, + } + } + + /// Returns the maximum reconnect attempts the strategy will perform, or None if will attempt + /// forever. + pub fn max_retries(&self) -> Option { + match self { + ReconnectStrategy::Fail => None, + ReconnectStrategy::ExponentialBackoff { max_retries, .. } => *max_retries, + ReconnectStrategy::FibonacciBackoff { max_retries, .. } => *max_retries, + ReconnectStrategy::FixedInterval { max_retries, .. } => *max_retries, + } + } + + /// Returns the timeout per reconnect attempt that is associated with the strategy. + pub fn timeout(&self) -> Option { + match self { + ReconnectStrategy::Fail => None, + ReconnectStrategy::ExponentialBackoff { timeout, .. } => *timeout, + ReconnectStrategy::FibonacciBackoff { timeout, .. } => *timeout, + ReconnectStrategy::FixedInterval { timeout, .. } => *timeout, + } + } + + /// Returns the initial duration to sleep. 
+ fn initial_sleep_duration(&self) -> Duration { + match self { + ReconnectStrategy::Fail => Duration::new(0, 0), + ReconnectStrategy::ExponentialBackoff { base, .. } => *base, + ReconnectStrategy::FibonacciBackoff { base, .. } => *base, + ReconnectStrategy::FixedInterval { interval, .. } => *interval, + } + } + + /// Adjusts next sleep duration based on the strategy. + fn adjust_sleep(&self, prev: Option, curr: Duration) -> Duration { + match self { + ReconnectStrategy::Fail => Duration::new(0, 0), + ReconnectStrategy::ExponentialBackoff { factor, .. } => { + let next_millis = (curr.as_millis() as f64) * factor; + Duration::from_millis(if next_millis > (std::u64::MAX as f64) { + std::u64::MAX + } else { + next_millis as u64 + }) + } + ReconnectStrategy::FibonacciBackoff { .. } => { + let prev = prev.unwrap_or_else(|| Duration::new(0, 0)); + prev.checked_add(curr).unwrap_or(Duration::MAX) + } + ReconnectStrategy::FixedInterval { .. } => curr, + } + } +} diff --git a/distant-net/src/client/shutdown.rs b/distant-net/src/client/shutdown.rs new file mode 100644 index 0000000..3a0fe94 --- /dev/null +++ b/distant-net/src/client/shutdown.rs @@ -0,0 +1,36 @@ +use async_trait::async_trait; +use dyn_clone::DynClone; +use std::io; +use tokio::sync::{mpsc, oneshot}; + +/// Interface representing functionality to shut down an active client. +#[async_trait] +pub trait Shutdown: DynClone + Send + Sync { + /// Attempts to shutdown the client. 
+ async fn shutdown(&self) -> io::Result<()>; +} + +#[async_trait] +impl Shutdown for mpsc::Sender>> { + async fn shutdown(&self) -> io::Result<()> { + let (tx, rx) = oneshot::channel(); + match self.send(tx).await { + Ok(_) => match rx.await { + Ok(x) => x, + Err(_) => Err(already_shutdown()), + }, + Err(_) => Err(already_shutdown()), + } + } +} + +#[inline] +fn already_shutdown() -> io::Error { + io::Error::new(io::ErrorKind::Other, "Client already shutdown") +} + +impl Clone for Box { + fn clone(&self) -> Self { + dyn_clone::clone_box(&**self) + } +} diff --git a/distant-net/src/codec.rs b/distant-net/src/codec.rs deleted file mode 100644 index 9a544bd..0000000 --- a/distant-net/src/codec.rs +++ /dev/null @@ -1,38 +0,0 @@ -use bytes::BytesMut; -use std::io; -use tokio_util::codec::{Decoder, Encoder}; - -/// Represents abstraction of a codec that implements specific encoder and decoder for distant -pub trait Codec: - for<'a> Encoder<&'a [u8], Error = io::Error> + Decoder, Error = io::Error> + Clone -{ - fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()>; - fn decode(&mut self, src: &mut BytesMut) -> io::Result>>; -} - -macro_rules! 
impl_traits_for_codec { - ($type:ident) => { - impl<'a> tokio_util::codec::Encoder<&'a [u8]> for $type { - type Error = io::Error; - - fn encode(&mut self, item: &'a [u8], dst: &mut BytesMut) -> Result<(), Self::Error> { - Codec::encode(self, item, dst) - } - } - - impl tokio_util::codec::Decoder for $type { - type Item = Vec; - type Error = io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - Codec::decode(self, src) - } - } - }; -} - -mod plain; -pub use plain::PlainCodec; - -mod xchacha20poly1305; -pub use xchacha20poly1305::XChaCha20Poly1305Codec; diff --git a/distant-net/src/codec/plain.rs b/distant-net/src/codec/plain.rs deleted file mode 100644 index d0e6697..0000000 --- a/distant-net/src/codec/plain.rs +++ /dev/null @@ -1,193 +0,0 @@ -use crate::Codec; -use bytes::{Buf, BufMut, BytesMut}; -use std::convert::TryInto; -use tokio::io; - -/// Total bytes to use as the len field denoting a frame's size -const LEN_SIZE: usize = 8; - -/// Represents a codec that just ships messages back and forth with no encryption or authentication -#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] -pub struct PlainCodec; -impl_traits_for_codec!(PlainCodec); - -impl PlainCodec { - pub fn new() -> Self { - Self::default() - } -} - -impl Codec for PlainCodec { - fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> { - // Validate that we can fit the message plus nonce + - if item.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Empty item provided", - )); - } - - dst.reserve(8 + item.len()); - - // Add data in form of {LEN}{ITEM} - dst.put_u64((item.len()) as u64); - dst.put_slice(item); - - Ok(()) - } - - fn decode(&mut self, src: &mut BytesMut) -> io::Result>> { - // First, check if we have more data than just our frame's message length - if src.len() <= LEN_SIZE { - return Ok(None); - } - - // Second, retrieve total size of our frame's message - let msg_len = 
u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap()) as usize; - if msg_len == 0 { - // Ensure we advance to remove the frame - src.advance(LEN_SIZE); - - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Frame's msg cannot have length of 0", - )); - } - - // Third, check if we have all data for our frame; if not, exit early - if src.len() < msg_len + LEN_SIZE { - return Ok(None); - } - - // Fourth, get and return our item - let item = src[LEN_SIZE..(LEN_SIZE + msg_len)].to_vec(); - - // Fifth, advance so frame is no longer kept around - src.advance(LEN_SIZE + msg_len); - - Ok(Some(item)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn encode_should_fail_when_item_is_zero_bytes() { - let mut codec = PlainCodec::new(); - - let mut buf = BytesMut::new(); - let result = codec.encode(&[], &mut buf); - - match result { - Err(x) if x.kind() == io::ErrorKind::InvalidInput => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[test] - fn encode_should_build_a_frame_containing_a_length_and_item() { - let mut codec = PlainCodec::new(); - - let mut buf = BytesMut::new(); - codec - .encode(b"hello, world", &mut buf) - .expect("Failed to encode"); - - let len = buf.get_u64() as usize; - assert_eq!(len, 12, "Wrong length encoded"); - assert_eq!(buf.as_ref(), b"hello, world"); - } - - #[test] - fn decode_should_return_none_if_data_smaller_than_or_equal_to_item_length_field() { - let mut codec = PlainCodec::new(); - - let mut buf = BytesMut::new(); - buf.put_bytes(0, LEN_SIZE); - - let result = codec.decode(&mut buf); - assert!( - matches!(result, Ok(None)), - "Unexpected result: {:?}", - result - ); - } - - #[test] - fn decode_should_return_none_if_not_enough_data_for_frame() { - let mut codec = PlainCodec::new(); - - let mut buf = BytesMut::new(); - buf.put_u64(0); - - let result = codec.decode(&mut buf); - assert!( - matches!(result, Ok(None)), - "Unexpected result: {:?}", - result - ); - } - - #[test] - fn 
decode_should_fail_if_encoded_item_length_is_zero() { - let mut codec = PlainCodec::new(); - - let mut buf = BytesMut::new(); - buf.put_u64(0); - buf.put_u8(255); - - let result = codec.decode(&mut buf); - match result { - Err(x) if x.kind() == io::ErrorKind::InvalidData => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[test] - fn decode_should_advance_src_by_frame_size_even_if_item_length_is_zero() { - let mut codec = PlainCodec::new(); - - let mut buf = BytesMut::new(); - buf.put_u64(0); - buf.put_bytes(0, 3); - - assert!( - codec.decode(&mut buf).is_err(), - "Decode unexpectedly succeeded" - ); - assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf"); - } - - #[test] - fn decode_should_advance_src_by_frame_size_when_successful() { - let mut codec = PlainCodec::new(); - - // Add 3 extra bytes after a full frame - let mut buf = BytesMut::new(); - codec - .encode(b"hello, world", &mut buf) - .expect("Failed to encode"); - buf.put_bytes(0, 3); - - assert!(codec.decode(&mut buf).is_ok(), "Decode unexpectedly failed"); - assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf"); - } - - #[test] - fn decode_should_return_some_byte_vec_when_successful() { - let mut codec = PlainCodec::new(); - - let mut buf = BytesMut::new(); - codec - .encode(b"hello, world", &mut buf) - .expect("Failed to encode"); - - let item = codec - .decode(&mut buf) - .expect("Failed to decode") - .expect("Item not properly captured"); - assert_eq!(item, b"hello, world"); - } -} diff --git a/distant-net/src/codec/xchacha20poly1305.rs b/distant-net/src/codec/xchacha20poly1305.rs deleted file mode 100644 index 8dcc627..0000000 --- a/distant-net/src/codec/xchacha20poly1305.rs +++ /dev/null @@ -1,269 +0,0 @@ -use crate::{Codec, SecretKey, SecretKey32}; -use bytes::{Buf, BufMut, BytesMut}; -use chacha20poly1305::{aead::Aead, Key, KeyInit, XChaCha20Poly1305, XNonce}; -use std::{convert::TryInto, fmt}; -use tokio::io; - -/// Total bytes to use as the len field 
denoting a frame's size -const LEN_SIZE: usize = 8; - -/// Total bytes to use for nonce -const NONCE_SIZE: usize = 24; - -/// Represents the codec to encode & decode data while also encrypting/decrypting it -/// -/// Uses a 32-byte key internally -#[derive(Clone)] -pub struct XChaCha20Poly1305Codec { - cipher: XChaCha20Poly1305, -} -impl_traits_for_codec!(XChaCha20Poly1305Codec); - -impl XChaCha20Poly1305Codec { - pub fn new(key: &[u8]) -> Self { - let key = Key::from_slice(key); - let cipher = XChaCha20Poly1305::new(key); - Self { cipher } - } -} - -impl From for XChaCha20Poly1305Codec { - /// Create a new XChaCha20Poly1305 codec with a 32-byte key - fn from(secret_key: SecretKey32) -> Self { - Self::new(secret_key.unprotected_as_bytes()) - } -} - -impl fmt::Debug for XChaCha20Poly1305Codec { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("XChaCha20Poly1305Codec") - .field("cipher", &"**OMITTED**".to_string()) - .finish() - } -} - -impl Codec for XChaCha20Poly1305Codec { - fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> { - // Validate that we can fit the message plus nonce + - if item.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Empty item provided", - )); - } - // NOTE: As seen in orion, with a 24-bit nonce, it's safe to generate instead of - // maintaining a stateful counter due to its size (24-byte secret key generation - // will never panic) - let nonce_key = SecretKey::::generate().unwrap(); - let nonce = XNonce::from_slice(nonce_key.unprotected_as_bytes()); - - let ciphertext = self - .cipher - .encrypt(nonce, item) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?; - - dst.reserve(8 + nonce.len() + ciphertext.len()); - - // Add data in form of {LEN}{NONCE}{CIPHER TEXT} - dst.put_u64((nonce_key.len() + ciphertext.len()) as u64); - dst.put_slice(nonce.as_slice()); - dst.extend(ciphertext); - - Ok(()) - } - - fn decode(&mut self, src: &mut 
BytesMut) -> io::Result>> { - // First, check if we have more data than just our frame's message length - if src.len() <= LEN_SIZE { - return Ok(None); - } - - // Second, retrieve total size of our frame's message - let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap()) as usize; - if msg_len <= NONCE_SIZE { - // Ensure we advance to remove the frame - src.advance(LEN_SIZE + msg_len); - - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Frame's msg cannot have length less than 25", - )); - } - - // Third, check if we have all data for our frame; if not, exit early - if src.len() < msg_len + LEN_SIZE { - return Ok(None); - } - - // Fourth, retrieve the nonce used with the ciphertext - let nonce = XNonce::from_slice(&src[LEN_SIZE..(NONCE_SIZE + LEN_SIZE)]); - - // Fifth, acquire the encrypted & signed ciphertext - let ciphertext = &src[(NONCE_SIZE + LEN_SIZE)..(msg_len + LEN_SIZE)]; - - // Sixth, convert ciphertext back into our item - let item = self.cipher.decrypt(nonce, ciphertext); - - // Seventh, advance so frame is no longer kept around - src.advance(LEN_SIZE + msg_len); - - // Eighth, report an error if there is one - let item = - item.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?; - - Ok(Some(item)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn encode_should_fail_when_item_is_zero_bytes() { - let key = SecretKey32::default(); - let mut codec = XChaCha20Poly1305Codec::from(key); - - let mut buf = BytesMut::new(); - let result = codec.encode(&[], &mut buf); - - match result { - Err(x) if x.kind() == io::ErrorKind::InvalidInput => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[test] - fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() { - let key = SecretKey32::default(); - let mut codec = XChaCha20Poly1305Codec::from(key); - - let mut buf = BytesMut::new(); - codec - .encode(b"hello, world", &mut buf) - .expect("Failed to encode"); - - let len 
= buf.get_u64() as usize; - assert!(buf.len() > NONCE_SIZE, "Msg size not big enough"); - assert_eq!(len, buf.len(), "Msg size does not match attached size"); - } - - #[test] - fn decode_should_return_none_if_data_smaller_than_or_equal_to_frame_length_field() { - let key = SecretKey32::default(); - let mut codec = XChaCha20Poly1305Codec::from(key); - - let mut buf = BytesMut::new(); - buf.put_bytes(0, LEN_SIZE); - - let result = codec.decode(&mut buf); - assert!( - matches!(result, Ok(None)), - "Unexpected result: {:?}", - result - ); - } - - #[test] - fn decode_should_return_none_if_not_enough_data_for_frame() { - let key = SecretKey32::default(); - let mut codec = XChaCha20Poly1305Codec::from(key); - - let mut buf = BytesMut::new(); - buf.put_u64(0); - - let result = codec.decode(&mut buf); - assert!( - matches!(result, Ok(None)), - "Unexpected result: {:?}", - result - ); - } - - #[test] - fn decode_should_fail_if_encoded_frame_length_is_smaller_than_nonce_plus_data() { - let key = SecretKey32::default(); - let mut codec = XChaCha20Poly1305Codec::from(key); - - // NONCE_SIZE + 1 is minimum for frame length - let mut buf = BytesMut::new(); - buf.put_u64(NONCE_SIZE as u64); - buf.put_bytes(0, NONCE_SIZE); - - let result = codec.decode(&mut buf); - match result { - Err(x) if x.kind() == io::ErrorKind::InvalidData => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[test] - fn decode_should_advance_src_by_frame_size_even_if_frame_length_is_too_small() { - let key = SecretKey32::default(); - let mut codec = XChaCha20Poly1305Codec::from(key); - - // LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes - let mut buf = BytesMut::new(); - buf.put_u64(NONCE_SIZE as u64); - buf.put_bytes(0, NONCE_SIZE); - buf.put_bytes(0, 3); - - assert!( - codec.decode(&mut buf).is_err(), - "Decode unexpectedly succeeded" - ); - assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf"); - } - - #[test] - fn 
decode_should_advance_src_by_frame_size_even_if_decryption_fails() { - let key = SecretKey32::default(); - let mut codec = XChaCha20Poly1305Codec::from(key); - - // LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes - let mut buf = BytesMut::new(); - buf.put_u64((NONCE_SIZE + 12) as u64); - buf.put_bytes(0, NONCE_SIZE); - buf.put_slice(b"hello, world"); - buf.put_bytes(0, 3); - - assert!( - codec.decode(&mut buf).is_err(), - "Decode unexpectedly succeeded" - ); - assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf"); - } - - #[test] - fn decode_should_advance_src_by_frame_size_when_successful() { - let key = SecretKey32::default(); - let mut codec = XChaCha20Poly1305Codec::from(key); - - // Add 3 extra bytes after a full frame - let mut buf = BytesMut::new(); - codec - .encode(b"hello, world", &mut buf) - .expect("Failed to encode"); - buf.put_bytes(0, 3); - - assert!(codec.decode(&mut buf).is_ok(), "Decode unexpectedly failed"); - assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf"); - } - - #[test] - fn decode_should_return_some_byte_vec_when_successful() { - let key = SecretKey32::default(); - let mut codec = XChaCha20Poly1305Codec::from(key); - - let mut buf = BytesMut::new(); - codec - .encode(b"hello, world", &mut buf) - .expect("Failed to encode"); - - let item = codec - .decode(&mut buf) - .expect("Failed to decode") - .expect("Item not properly captured"); - assert_eq!(item, b"hello, world"); - } -} diff --git a/distant-net/src/common.rs b/distant-net/src/common.rs new file mode 100644 index 0000000..7f9ba94 --- /dev/null +++ b/distant-net/src/common.rs @@ -0,0 +1,20 @@ +mod any; +pub mod authentication; +mod connection; +mod destination; +mod listener; +mod map; +mod packet; +mod port; +mod transport; +pub(crate) mod utils; + +pub use any::*; +pub(crate) use connection::Connection; +pub use connection::ConnectionId; +pub use destination::*; +pub use listener::*; +pub use map::*; +pub use packet::*; +pub use 
port::*; +pub use transport::*; diff --git a/distant-net/src/any.rs b/distant-net/src/common/any.rs similarity index 100% rename from distant-net/src/any.rs rename to distant-net/src/common/any.rs diff --git a/distant-net/src/common/authentication.rs b/distant-net/src/common/authentication.rs new file mode 100644 index 0000000..2a18ccd --- /dev/null +++ b/distant-net/src/common/authentication.rs @@ -0,0 +1,10 @@ +mod authenticator; +mod handler; +mod keychain; +mod methods; +pub mod msg; + +pub use authenticator::*; +pub use handler::*; +pub use keychain::*; +pub use methods::*; diff --git a/distant-net/src/common/authentication/authenticator.rs b/distant-net/src/common/authentication/authenticator.rs new file mode 100644 index 0000000..a37555f --- /dev/null +++ b/distant-net/src/common/authentication/authenticator.rs @@ -0,0 +1,672 @@ +use super::{msg::*, AuthHandler}; +use crate::common::{utils, FramedTransport, Transport}; +use async_trait::async_trait; +use log::*; +use std::io; + +/// Represents an interface for authenticating with a server. +#[async_trait] +pub trait Authenticate { + /// Performs authentication by leveraging the `handler` for any received challenge. + async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()>; +} + +/// Represents an interface for submitting challenges for authentication. +#[async_trait] +pub trait Authenticator: Send { + /// Issues an initialization notice and returns the response indicating which authentication + /// methods to pursue + async fn initialize( + &mut self, + initialization: Initialization, + ) -> io::Result; + + /// Issues a challenge and returns the answers to the `questions` asked. + async fn challenge(&mut self, challenge: Challenge) -> io::Result; + + /// Requests verification of some `kind` and `text`, returning true if passed verification. + async fn verify(&mut self, verification: Verification) -> io::Result; + + /// Reports information with no response expected. 
+ async fn info(&mut self, info: Info) -> io::Result<()>; + + /// Reports an error occurred during authentication, consuming the authenticator since no more + /// challenges should be issued. + async fn error(&mut self, error: Error) -> io::Result<()>; + + /// Reports that the authentication has started for a specific method. + async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()>; + + /// Reports that the authentication has finished successfully, consuming the authenticator + /// since no more challenges should be issued. + async fn finished(&mut self) -> io::Result<()>; +} + +macro_rules! write_frame { + ($transport:expr, $data:expr) => {{ + let data = utils::serialize_to_vec(&$data)?; + if log_enabled!(Level::Trace) { + trace!("Writing data as frame: {data:?}"); + } + + $transport.write_frame(data).await? + }}; +} + +macro_rules! next_frame_as { + ($transport:expr, $type:ident, $variant:ident) => {{ + match { next_frame_as!($transport, $type) } { + $type::$variant(x) => x, + x => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Unexpected frame: {x:?}"), + )) + } + } + }}; + ($transport:expr, $type:ident) => {{ + let frame = $transport.read_frame().await?.ok_or_else(|| { + io::Error::new( + io::ErrorKind::UnexpectedEof, + concat!( + "Transport closed early waiting for frame of type ", + stringify!($type), + ), + ) + })?; + + match utils::deserialize_from_slice::<$type>(frame.as_item()) { + Ok(frame) => frame, + Err(x) => { + if log_enabled!(Level::Trace) { + trace!( + "Failed to deserialize frame item as {}: {:?}", + stringify!($type), + frame.as_item() + ); + } + + Err(x)?; + unreachable!(); + } + } + }}; +} + +#[async_trait] +impl Authenticate for FramedTransport +where + T: Transport, +{ + async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()> { + loop { + trace!("Authenticate::authenticate waiting on next authentication frame"); + match next_frame_as!(self, Authentication) { + 
Authentication::Initialization(x) => { + trace!("Authenticate::Initialization({x:?})"); + let response = handler.on_initialization(x).await?; + write_frame!(self, AuthenticationResponse::Initialization(response)); + } + Authentication::Challenge(x) => { + trace!("Authenticate::Challenge({x:?})"); + let response = handler.on_challenge(x).await?; + write_frame!(self, AuthenticationResponse::Challenge(response)); + } + Authentication::Verification(x) => { + trace!("Authenticate::Verify({x:?})"); + let response = handler.on_verification(x).await?; + write_frame!(self, AuthenticationResponse::Verification(response)); + } + Authentication::Info(x) => { + trace!("Authenticate::Info({x:?})"); + handler.on_info(x).await?; + } + Authentication::Error(x) => { + trace!("Authenticate::Error({x:?})"); + handler.on_error(x.clone()).await?; + + if x.is_fatal() { + return Err(x.into_io_permission_denied()); + } + } + Authentication::StartMethod(x) => { + trace!("Authenticate::StartMethod({x:?})"); + handler.on_start_method(x).await?; + } + Authentication::Finished => { + trace!("Authenticate::Finished"); + handler.on_finished().await?; + return Ok(()); + } + } + } + } +} + +#[async_trait] +impl Authenticator for FramedTransport +where + T: Transport, +{ + async fn initialize( + &mut self, + initialization: Initialization, + ) -> io::Result { + trace!("Authenticator::initialize({initialization:?})"); + write_frame!(self, Authentication::Initialization(initialization)); + let response = next_frame_as!(self, AuthenticationResponse, Initialization); + Ok(response) + } + + async fn challenge(&mut self, challenge: Challenge) -> io::Result { + trace!("Authenticator::challenge({challenge:?})"); + write_frame!(self, Authentication::Challenge(challenge)); + let response = next_frame_as!(self, AuthenticationResponse, Challenge); + Ok(response) + } + + async fn verify(&mut self, verification: Verification) -> io::Result { + trace!("Authenticator::verify({verification:?})"); + 
write_frame!(self, Authentication::Verification(verification)); + let response = next_frame_as!(self, AuthenticationResponse, Verification); + Ok(response) + } + + async fn info(&mut self, info: Info) -> io::Result<()> { + trace!("Authenticator::info({info:?})"); + write_frame!(self, Authentication::Info(info)); + Ok(()) + } + + async fn error(&mut self, error: Error) -> io::Result<()> { + trace!("Authenticator::error({error:?})"); + write_frame!(self, Authentication::Error(error)); + Ok(()) + } + + async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> { + trace!("Authenticator::start_method({start_method:?})"); + write_frame!(self, Authentication::StartMethod(start_method)); + Ok(()) + } + + async fn finished(&mut self) -> io::Result<()> { + trace!("Authenticator::finished()"); + write_frame!(self, Authentication::Finished); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::authentication::AuthMethodHandler; + use test_log::test; + use tokio::sync::mpsc; + + #[async_trait] + trait TestAuthHandler { + async fn on_initialization( + &mut self, + _: Initialization, + ) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_start_method(&mut self, _: StartMethod) -> io::Result<()> { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_finished(&mut self) -> io::Result<()> { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_challenge(&mut self, _: Challenge) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_verification(&mut self, _: Verification) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_info(&mut self, _: Info) -> io::Result<()> { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_error(&mut self, _: Error) -> io::Result<()> { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + } + + #[async_trait] + impl AuthHandler for T { + async fn 
on_initialization( + &mut self, + x: Initialization, + ) -> io::Result { + TestAuthHandler::on_initialization(self, x).await + } + + async fn on_start_method(&mut self, x: StartMethod) -> io::Result<()> { + TestAuthHandler::on_start_method(self, x).await + } + + async fn on_finished(&mut self) -> io::Result<()> { + TestAuthHandler::on_finished(self).await + } + } + + #[async_trait] + impl AuthMethodHandler for T { + async fn on_challenge(&mut self, x: Challenge) -> io::Result { + TestAuthHandler::on_challenge(self, x).await + } + + async fn on_verification(&mut self, x: Verification) -> io::Result { + TestAuthHandler::on_verification(self, x).await + } + + async fn on_info(&mut self, x: Info) -> io::Result<()> { + TestAuthHandler::on_info(self, x).await + } + + async fn on_error(&mut self, x: Error) -> io::Result<()> { + TestAuthHandler::on_error(self, x).await + } + } + + macro_rules! auth_handler { + (@no_challenge @no_verification @tx($tx:ident, $ty:ty) $($methods:item)*) => { + auth_handler! { + @tx($tx, $ty) + + async fn on_challenge(&mut self, _: Challenge) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_verification( + &mut self, + _: Verification, + ) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + $($methods)* + } + }; + (@no_challenge @tx($tx:ident, $ty:ty) $($methods:item)*) => { + auth_handler! { + @tx($tx, $ty) + + async fn on_challenge(&mut self, _: Challenge) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + $($methods)* + } + }; + (@no_verification @tx($tx:ident, $ty:ty) $($methods:item)*) => { + auth_handler! 
{ + @tx($tx, $ty) + + async fn on_verification( + &mut self, + _: Verification, + ) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + $($methods)* + } + }; + (@tx($tx:ident, $ty:ty) $($methods:item)*) => {{ + #[allow(dead_code)] + struct __InlineAuthHandler { + tx: mpsc::Sender<$ty>, + } + + #[async_trait] + impl TestAuthHandler for __InlineAuthHandler { + $($methods)* + } + + __InlineAuthHandler { tx: $tx } + }}; + } + + #[test(tokio::test)] + async fn authenticator_initialization_should_be_able_to_successfully_complete_round_trip() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + let (tx, _) = mpsc::channel(1); + + let task = tokio::spawn(async move { + t2.authenticate(auth_handler! { + @no_challenge + @no_verification + @tx(tx, ()) + + async fn on_initialization( + &mut self, + initialization: Initialization, + ) -> io::Result { + Ok(InitializationResponse { + methods: initialization.methods, + }) + } + }) + .await + .unwrap() + }); + + let response = t1 + .initialize(Initialization { + methods: vec!["test method".to_string()].into_iter().collect(), + }) + .await + .unwrap(); + + assert!( + !task.is_finished(), + "Auth handler unexpectedly finished without signal" + ); + + assert_eq!( + response, + InitializationResponse { + methods: vec!["test method".to_string()].into_iter().collect() + } + ); + } + + #[test(tokio::test)] + async fn authenticator_challenge_should_be_able_to_successfully_complete_round_trip() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + let (tx, _) = mpsc::channel(1); + + let task = tokio::spawn(async move { + t2.authenticate(auth_handler! 
{ + @no_verification + @tx(tx, ()) + + async fn on_challenge(&mut self, challenge: Challenge) -> io::Result { + assert_eq!(challenge.questions, vec![Question { + label: "label".to_string(), + text: "text".to_string(), + options: vec![("question_key".to_string(), "question_value".to_string())] + .into_iter() + .collect(), + }]); + assert_eq!( + challenge.options, + vec![("key".to_string(), "value".to_string())].into_iter().collect(), + ); + Ok(ChallengeResponse { + answers: vec!["some answer".to_string()].into_iter().collect(), + }) + } + }) + .await + .unwrap() + }); + + let response = t1 + .challenge(Challenge { + questions: vec![Question { + label: "label".to_string(), + text: "text".to_string(), + options: vec![("question_key".to_string(), "question_value".to_string())] + .into_iter() + .collect(), + }], + options: vec![("key".to_string(), "value".to_string())] + .into_iter() + .collect(), + }) + .await + .unwrap(); + + assert!( + !task.is_finished(), + "Auth handler unexpectedly finished without signal" + ); + + assert_eq!( + response, + ChallengeResponse { + answers: vec!["some answer".to_string()], + } + ); + } + + #[test(tokio::test)] + async fn authenticator_verification_should_be_able_to_successfully_complete_round_trip() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + let (tx, _) = mpsc::channel(1); + + let task = tokio::spawn(async move { + t2.authenticate(auth_handler! 
{ + @no_challenge + @tx(tx, ()) + + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result { + assert_eq!(verification.kind, VerificationKind::Host); + assert_eq!(verification.text, "some text"); + Ok(VerificationResponse { + valid: true, + }) + } + }) + .await + .unwrap() + }); + + let response = t1 + .verify(Verification { + kind: VerificationKind::Host, + text: "some text".to_string(), + }) + .await + .unwrap(); + + assert!( + !task.is_finished(), + "Auth handler unexpectedly finished without signal" + ); + + assert_eq!(response, VerificationResponse { valid: true }); + } + + #[test(tokio::test)] + async fn authenticator_info_should_be_able_to_be_sent_to_auth_handler() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + let (tx, mut rx) = mpsc::channel(1); + + let task = tokio::spawn(async move { + t2.authenticate(auth_handler! { + @no_challenge + @no_verification + @tx(tx, Info) + + async fn on_info( + &mut self, + info: Info, + ) -> io::Result<()> { + self.tx.send(info).await.unwrap(); + Ok(()) + } + }) + .await + .unwrap() + }); + + t1.info(Info { + text: "some text".to_string(), + }) + .await + .unwrap(); + + assert_eq!( + rx.recv().await.unwrap(), + Info { + text: "some text".to_string() + } + ); + + assert!( + !task.is_finished(), + "Auth handler unexpectedly finished without signal" + ); + } + + #[test(tokio::test)] + async fn authenticator_error_should_be_able_to_be_sent_to_auth_handler() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + let (tx, mut rx) = mpsc::channel(1); + + let task = tokio::spawn(async move { + t2.authenticate(auth_handler! 
{ + @no_challenge + @no_verification + @tx(tx, Error) + + async fn on_error(&mut self, error: Error) -> io::Result<()> { + self.tx.send(error).await.unwrap(); + Ok(()) + } + }) + .await + .unwrap() + }); + + t1.error(Error { + kind: ErrorKind::Error, + text: "some text".to_string(), + }) + .await + .unwrap(); + + assert_eq!( + rx.recv().await.unwrap(), + Error { + kind: ErrorKind::Error, + text: "some text".to_string(), + } + ); + + assert!( + !task.is_finished(), + "Auth handler unexpectedly finished without signal" + ); + } + + #[test(tokio::test)] + async fn auth_handler_received_error_should_fail_auth_handler_if_fatal() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + let (tx, mut rx) = mpsc::channel(1); + + let task = tokio::spawn(async move { + t2.authenticate(auth_handler! { + @no_challenge + @no_verification + @tx(tx, Error) + + async fn on_error(&mut self, error: Error) -> io::Result<()> { + self.tx.send(error).await.unwrap(); + Ok(()) + } + }) + .await + .unwrap() + }); + + t1.error(Error { + kind: ErrorKind::Fatal, + text: "some text".to_string(), + }) + .await + .unwrap(); + + assert_eq!( + rx.recv().await.unwrap(), + Error { + kind: ErrorKind::Fatal, + text: "some text".to_string(), + } + ); + + // Verify that the handler exited with an error + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn authenticator_start_method_should_be_able_to_be_sent_to_auth_handler() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + let (tx, mut rx) = mpsc::channel(1); + + let task = tokio::spawn(async move { + t2.authenticate(auth_handler! 
{ + @no_challenge + @no_verification + @tx(tx, StartMethod) + + async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> { + self.tx.send(start_method).await.unwrap(); + Ok(()) + } + }) + .await + .unwrap() + }); + + t1.start_method(StartMethod { + method: "some method".to_string(), + }) + .await + .unwrap(); + + assert_eq!( + rx.recv().await.unwrap(), + StartMethod { + method: "some method".to_string() + } + ); + + assert!( + !task.is_finished(), + "Auth handler unexpectedly finished without signal" + ); + } + + #[test(tokio::test)] + async fn authenticator_finished_should_be_able_to_be_sent_to_auth_handler() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + let (tx, mut rx) = mpsc::channel(1); + + let task = tokio::spawn(async move { + t2.authenticate(auth_handler! { + @no_challenge + @no_verification + @tx(tx, ()) + + async fn on_finished(&mut self) -> io::Result<()> { + self.tx.send(()).await.unwrap(); + Ok(()) + } + }) + .await + .unwrap() + }); + + t1.finished().await.unwrap(); + + // Verify that the callback was triggered + rx.recv().await.unwrap(); + + // Finished should signal that the handler completed successfully + task.await.unwrap(); + } +} diff --git a/distant-net/src/common/authentication/handler.rs b/distant-net/src/common/authentication/handler.rs new file mode 100644 index 0000000..945bf6d --- /dev/null +++ b/distant-net/src/common/authentication/handler.rs @@ -0,0 +1,343 @@ +use super::msg::*; +use crate::common::authentication::Authenticator; +use crate::common::HeapSecretKey; +use async_trait::async_trait; +use std::collections::HashMap; +use std::io; + +mod methods; +pub use methods::*; + +/// Interface for a handler of authentication requests for all methods. +#[async_trait] +pub trait AuthHandler: AuthMethodHandler + Send { + /// Callback when authentication is beginning, providing available authentication methods and + /// returning selected authentication methods to pursue. 
+ async fn on_initialization( + &mut self, + initialization: Initialization, + ) -> io::Result { + Ok(InitializationResponse { + methods: initialization.methods, + }) + } + + /// Callback when authentication starts for a specific method. + #[allow(unused_variables)] + async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> { + Ok(()) + } + + /// Callback when authentication is finished and no more requests will be received. + async fn on_finished(&mut self) -> io::Result<()> { + Ok(()) + } +} + +/// Dummy implementation of [`AuthHandler`] where any challenge or verification request will +/// instantly fail. +pub struct DummyAuthHandler; + +#[async_trait] +impl AuthHandler for DummyAuthHandler {} + +#[async_trait] +impl AuthMethodHandler for DummyAuthHandler { + async fn on_challenge(&mut self, _: Challenge) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_verification(&mut self, _: Verification) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_info(&mut self, _: Info) -> io::Result<()> { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } + + async fn on_error(&mut self, _: Error) -> io::Result<()> { + Err(io::Error::from(io::ErrorKind::Unsupported)) + } +} + +/// Implementation of [`AuthHandler`] that uses the same [`AuthMethodHandler`] for all methods. 
+pub struct SingleAuthHandler(Box); + +impl SingleAuthHandler { + pub fn new(method_handler: T) -> Self { + Self(Box::new(method_handler)) + } +} + +#[async_trait] +impl AuthHandler for SingleAuthHandler {} + +#[async_trait] +impl AuthMethodHandler for SingleAuthHandler { + async fn on_challenge(&mut self, challenge: Challenge) -> io::Result { + self.0.on_challenge(challenge).await + } + + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result { + self.0.on_verification(verification).await + } + + async fn on_info(&mut self, info: Info) -> io::Result<()> { + self.0.on_info(info).await + } + + async fn on_error(&mut self, error: Error) -> io::Result<()> { + self.0.on_error(error).await + } +} + +/// Implementation of [`AuthHandler`] that maintains a map of [`AuthMethodHandler`] implementations +/// for specific methods, invoking [`on_challenge`], [`on_verification`], [`on_info`], and +/// [`on_error`] for a specific handler based on an associated id. +/// +/// [`on_challenge`]: AuthMethodHandler::on_challenge +/// [`on_verification`]: AuthMethodHandler::on_verification +/// [`on_info`]: AuthMethodHandler::on_info +/// [`on_error`]: AuthMethodHandler::on_error +pub struct AuthHandlerMap { + active: String, + map: HashMap<&'static str, Box>, +} + +impl AuthHandlerMap { + /// Creates a new, empty map of auth method handlers. + pub fn new() -> Self { + Self { + active: String::new(), + map: HashMap::new(), + } + } + + /// Returns the `id` of the active [`AuthMethodHandler`]. + pub fn active_id(&self) -> &str { + &self.active + } + + /// Sets the active [`AuthMethodHandler`] by its `id`. + pub fn set_active_id(&mut self, id: impl Into) { + self.active = id.into(); + } + + /// Inserts the specified `handler` into the map, associating it with `id` for determining the + /// method that would trigger this handler. 
+ pub fn insert_method_handler( + &mut self, + id: &'static str, + handler: T, + ) -> Option> { + self.map.insert(id, Box::new(handler)) + } + + /// Removes a handler with the associated `id`. + pub fn remove_method_handler( + &mut self, + id: &'static str, + ) -> Option> { + self.map.remove(id) + } + + /// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`, + /// returning an error if no handler for the active id is found. + pub fn get_mut_active_method_handler_or_error( + &mut self, + ) -> io::Result<&mut (dyn AuthMethodHandler + 'static)> { + let id = self.active.clone(); + self.get_mut_active_method_handler().ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, format!("No active handler for {id}")) + }) + } + + /// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`. + pub fn get_mut_active_method_handler( + &mut self, + ) -> Option<&mut (dyn AuthMethodHandler + 'static)> { + // TODO: Optimize this + self.get_mut_method_handler(&self.active.clone()) + } + + /// Retrieves a mutable reference to the [`AuthMethodHandler`] with the specified `id`. + pub fn get_mut_method_handler( + &mut self, + id: &str, + ) -> Option<&mut (dyn AuthMethodHandler + 'static)> { + self.map.get_mut(id).map(|h| h.as_mut()) + } +} + +impl AuthHandlerMap { + /// Consumes the map, returning a new map that supports the `static_key` method. 
+ pub fn with_static_key(mut self, key: impl Into) -> Self { + self.insert_method_handler("static_key", StaticKeyAuthMethodHandler::simple(key)); + self + } +} + +impl Default for AuthHandlerMap { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl AuthHandler for AuthHandlerMap { + async fn on_initialization( + &mut self, + initialization: Initialization, + ) -> io::Result { + let methods = initialization + .methods + .into_iter() + .filter(|method| self.map.contains_key(method.as_str())) + .collect(); + + Ok(InitializationResponse { methods }) + } + + async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> { + self.set_active_id(start_method.method); + Ok(()) + } + + async fn on_finished(&mut self) -> io::Result<()> { + Ok(()) + } +} + +#[async_trait] +impl AuthMethodHandler for AuthHandlerMap { + async fn on_challenge(&mut self, challenge: Challenge) -> io::Result { + let handler = self.get_mut_active_method_handler_or_error()?; + handler.on_challenge(challenge).await + } + + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result { + let handler = self.get_mut_active_method_handler_or_error()?; + handler.on_verification(verification).await + } + + async fn on_info(&mut self, info: Info) -> io::Result<()> { + let handler = self.get_mut_active_method_handler_or_error()?; + handler.on_info(info).await + } + + async fn on_error(&mut self, error: Error) -> io::Result<()> { + let handler = self.get_mut_active_method_handler_or_error()?; + handler.on_error(error).await + } +} + +/// Implementation of [`AuthHandler`] that redirects all requests to an [`Authenticator`]. 
+pub struct ProxyAuthHandler<'a>(&'a mut dyn Authenticator); + +impl<'a> ProxyAuthHandler<'a> { + pub fn new(authenticator: &'a mut dyn Authenticator) -> Self { + Self(authenticator) + } +} + +#[async_trait] +impl<'a> AuthHandler for ProxyAuthHandler<'a> { + async fn on_initialization( + &mut self, + initialization: Initialization, + ) -> io::Result { + Authenticator::initialize(self.0, initialization).await + } + + async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> { + Authenticator::start_method(self.0, start_method).await + } + + async fn on_finished(&mut self) -> io::Result<()> { + Authenticator::finished(self.0).await + } +} + +#[async_trait] +impl<'a> AuthMethodHandler for ProxyAuthHandler<'a> { + async fn on_challenge(&mut self, challenge: Challenge) -> io::Result { + Authenticator::challenge(self.0, challenge).await + } + + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result { + Authenticator::verify(self.0, verification).await + } + + async fn on_info(&mut self, info: Info) -> io::Result<()> { + Authenticator::info(self.0, info).await + } + + async fn on_error(&mut self, error: Error) -> io::Result<()> { + Authenticator::error(self.0, error).await + } +} + +/// Implementation of [`AuthHandler`] that holds a mutable reference to another [`AuthHandler`] +/// trait object to use underneath. 
+pub struct DynAuthHandler<'a>(&'a mut dyn AuthHandler); + +impl<'a> DynAuthHandler<'a> { + pub fn new(handler: &'a mut dyn AuthHandler) -> Self { + Self(handler) + } +} + +impl<'a, T: AuthHandler> From<&'a mut T> for DynAuthHandler<'a> { + fn from(handler: &'a mut T) -> Self { + Self::new(handler as &mut dyn AuthHandler) + } +} + +#[async_trait] +impl<'a> AuthHandler for DynAuthHandler<'a> { + async fn on_initialization( + &mut self, + initialization: Initialization, + ) -> io::Result { + self.0.on_initialization(initialization).await + } + + async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> { + self.0.on_start_method(start_method).await + } + + async fn on_finished(&mut self) -> io::Result<()> { + self.0.on_finished().await + } +} + +#[async_trait] +impl<'a> AuthMethodHandler for DynAuthHandler<'a> { + async fn on_challenge(&mut self, challenge: Challenge) -> io::Result { + self.0.on_challenge(challenge).await + } + + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result { + self.0.on_verification(verification).await + } + + async fn on_info(&mut self, info: Info) -> io::Result<()> { + self.0.on_info(info).await + } + + async fn on_error(&mut self, error: Error) -> io::Result<()> { + self.0.on_error(error).await + } +} diff --git a/distant-net/src/common/authentication/handler/methods.rs b/distant-net/src/common/authentication/handler/methods.rs new file mode 100644 index 0000000..3fda4a7 --- /dev/null +++ b/distant-net/src/common/authentication/handler/methods.rs @@ -0,0 +1,33 @@ +use super::{ + Challenge, ChallengeResponse, Error, Info, Verification, VerificationKind, VerificationResponse, +}; +use async_trait::async_trait; +use std::io; + +/// Interface for a handler of authentication requests for a specific authentication method. +#[async_trait] +pub trait AuthMethodHandler: Send { + /// Callback when a challenge is received, returning answers to the given questions. 
+ async fn on_challenge(&mut self, challenge: Challenge) -> io::Result; + + /// Callback when a verification request is received, returning true if approvided or false if + /// unapproved. + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result; + + /// Callback when information is received. To fail, return an error from this function. + async fn on_info(&mut self, info: Info) -> io::Result<()>; + + /// Callback when an error is received. Regardless of the result returned, this will terminate + /// the authenticator. In the situation where a custom error would be preferred, have this + /// callback return an error. + async fn on_error(&mut self, error: Error) -> io::Result<()>; +} + +mod prompt; +pub use prompt::*; + +mod static_key; +pub use static_key::*; diff --git a/distant-net/src/common/authentication/handler/methods/prompt.rs b/distant-net/src/common/authentication/handler/methods/prompt.rs new file mode 100644 index 0000000..8d6ab53 --- /dev/null +++ b/distant-net/src/common/authentication/handler/methods/prompt.rs @@ -0,0 +1,88 @@ +use super::{ + AuthMethodHandler, Challenge, ChallengeResponse, Error, Info, Verification, VerificationKind, + VerificationResponse, +}; +use async_trait::async_trait; +use log::*; +use std::io; + +/// Blocking implementation of [`AuthMethodHandler`] that uses prompts to communicate challenge & +/// verification requests, receiving responses to relay back. 
+pub struct PromptAuthMethodHandler { + text_prompt: T, + password_prompt: U, +} + +impl PromptAuthMethodHandler { + pub fn new(text_prompt: T, password_prompt: U) -> Self { + Self { + text_prompt, + password_prompt, + } + } +} + +#[async_trait] +impl AuthMethodHandler for PromptAuthMethodHandler +where + T: Fn(&str) -> io::Result + Send + Sync + 'static, + U: Fn(&str) -> io::Result + Send + Sync + 'static, +{ + async fn on_challenge(&mut self, challenge: Challenge) -> io::Result { + trace!("on_challenge({challenge:?})"); + let mut answers = Vec::new(); + for question in challenge.questions.iter() { + // Contains all prompt lines including same line + let mut lines = question.text.split('\n').collect::>(); + + // Line that is prompt on same line as answer + let line = lines.pop().unwrap(); + + // Go ahead and display all other lines + for line in lines.into_iter() { + eprintln!("{}", line); + } + + // Get an answer from user input, or use a blank string as an answer + // if we fail to get input from the user + let answer = (self.password_prompt)(line).unwrap_or_default(); + + answers.push(answer); + } + Ok(ChallengeResponse { answers }) + } + + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result { + trace!("on_verify({verification:?})"); + match verification.kind { + VerificationKind::Host => { + eprintln!("{}", verification.text); + + let answer = (self.text_prompt)("Enter [y/N]> ")?; + trace!("Verify? 
Answer = '{answer}'"); + Ok(VerificationResponse { + valid: matches!(answer.trim(), "y" | "Y" | "yes" | "YES"), + }) + } + x => { + error!("Unsupported verify kind: {x}"); + Ok(VerificationResponse { valid: false }) + } + } + } + + async fn on_info(&mut self, info: Info) -> io::Result<()> { + trace!("on_info({info:?})"); + println!("{}", info.text); + Ok(()) + } + + async fn on_error(&mut self, error: Error) -> io::Result<()> { + trace!("on_error({error:?})"); + eprintln!("{}: {}", error.kind, error.text); + Ok(()) + } +} diff --git a/distant-net/src/common/authentication/handler/methods/static_key.rs b/distant-net/src/common/authentication/handler/methods/static_key.rs new file mode 100644 index 0000000..c6ce17c --- /dev/null +++ b/distant-net/src/common/authentication/handler/methods/static_key.rs @@ -0,0 +1,171 @@ +use super::{ + AuthMethodHandler, Challenge, ChallengeResponse, Error, Info, Verification, + VerificationResponse, +}; +use crate::common::HeapSecretKey; +use async_trait::async_trait; +use log::*; +use std::io; + +/// Implementation of [`AuthMethodHandler`] that answers challenge requests using a static +/// [`HeapSecretKey`]. All other portions of method authentication are handled by another +/// [`AuthMethodHandler`]. +pub struct StaticKeyAuthMethodHandler { + key: HeapSecretKey, + handler: Box, +} + +impl StaticKeyAuthMethodHandler { + /// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static + /// `key`. All other requests are passed to the `handler`. + pub fn new(key: impl Into, handler: T) -> Self { + Self { + key: key.into(), + handler: Box::new(handler), + } + } + + /// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static + /// `key`. All other requests are passed automatically, meaning that verification is always + /// approvide and info/errors are ignored. 
+ pub fn simple(key: impl Into) -> Self { + Self::new(key, { + struct __AuthMethodHandler; + + #[async_trait] + impl AuthMethodHandler for __AuthMethodHandler { + async fn on_challenge(&mut self, _: Challenge) -> io::Result { + unreachable!("on_challenge should be handled by StaticKeyAuthMethodHandler"); + } + + async fn on_verification( + &mut self, + _: Verification, + ) -> io::Result { + Ok(VerificationResponse { valid: true }) + } + + async fn on_info(&mut self, _: Info) -> io::Result<()> { + Ok(()) + } + + async fn on_error(&mut self, _: Error) -> io::Result<()> { + Ok(()) + } + } + + __AuthMethodHandler + }) + } +} + +#[async_trait] +impl AuthMethodHandler for StaticKeyAuthMethodHandler { + async fn on_challenge(&mut self, challenge: Challenge) -> io::Result { + trace!("on_challenge({challenge:?})"); + let mut answers = Vec::new(); + for question in challenge.questions.iter() { + // Only challenges with a "key" label are allowed, all else will fail + if question.label != "key" { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Only 'key' challenges are supported", + )); + } + answers.push(self.key.to_string()); + } + Ok(ChallengeResponse { answers }) + } + + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result { + trace!("on_verify({verification:?})"); + self.handler.on_verification(verification).await + } + + async fn on_info(&mut self, info: Info) -> io::Result<()> { + trace!("on_info({info:?})"); + self.handler.on_info(info).await + } + + async fn on_error(&mut self, error: Error) -> io::Result<()> { + trace!("on_error({error:?})"); + self.handler.on_error(error).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::authentication::msg::{ErrorKind, Question, VerificationKind}; + use test_log::test; + + #[test(tokio::test)] + async fn on_challenge_should_fail_if_non_key_question_received() { + let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap()); + 
+ handler + .on_challenge(Challenge { + questions: vec![Question::new("test")], + options: Default::default(), + }) + .await + .unwrap_err(); + } + + #[test(tokio::test)] + async fn on_challenge_should_answer_with_stringified_key_for_key_questions() { + let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap()); + + let response = handler + .on_challenge(Challenge { + questions: vec![Question::new("key")], + options: Default::default(), + }) + .await + .unwrap(); + assert_eq!(response.answers.len(), 1, "Wrong answer set received"); + assert!(!response.answers[0].is_empty(), "Empty answer being sent"); + } + + #[test(tokio::test)] + async fn on_verification_should_leverage_fallback_handler() { + let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap()); + + let response = handler + .on_verification(Verification { + kind: VerificationKind::Host, + text: "host".to_string(), + }) + .await + .unwrap(); + assert!(response.valid, "Unexpected result from fallback handler"); + } + + #[test(tokio::test)] + async fn on_info_should_leverage_fallback_handler() { + let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap()); + + handler + .on_info(Info { + text: "info".to_string(), + }) + .await + .unwrap(); + } + + #[test(tokio::test)] + async fn on_error_should_leverage_fallback_handler() { + let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap()); + + handler + .on_error(Error { + kind: ErrorKind::Error, + text: "text".to_string(), + }) + .await + .unwrap(); + } +} diff --git a/distant-net/src/common/authentication/keychain.rs b/distant-net/src/common/authentication/keychain.rs new file mode 100644 index 0000000..2018d15 --- /dev/null +++ b/distant-net/src/common/authentication/keychain.rs @@ -0,0 +1,156 @@ +use crate::common::HeapSecretKey; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// Represents the result 
of a request to the database. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum KeychainResult { + /// Id was not found in the database. + InvalidId, + + /// Password match for an id failed. + InvalidPassword, + + /// Successful match of id and password, removing from keychain and returning data `T`. + Ok(T), +} + +impl KeychainResult { + pub fn is_invalid_id(&self) -> bool { + matches!(self, Self::InvalidId) + } + + pub fn is_invalid_password(&self) -> bool { + matches!(self, Self::InvalidPassword) + } + + pub fn is_invalid(&self) -> bool { + matches!(self, Self::InvalidId | Self::InvalidPassword) + } + + pub fn is_ok(&self) -> bool { + matches!(self, Self::Ok(_)) + } + + pub fn into_ok(self) -> Option { + match self { + Self::Ok(x) => Some(x), + _ => None, + } + } +} + +impl From> for Option { + fn from(result: KeychainResult) -> Self { + result.into_ok() + } +} + +/// Manages keys with associated ids. Cloning will result in a copy pointing to the same underlying +/// storage, which enables support of managing the keys across multiple threads. +#[derive(Debug)] +pub struct Keychain { + map: Arc>>, +} + +impl Clone for Keychain { + fn clone(&self) -> Self { + Self { + map: Arc::clone(&self.map), + } + } +} + +impl Keychain { + /// Creates a new keychain without any keys. + pub fn new() -> Self { + Self { + map: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Stores a new `key` and `data` by a given `id`, returning the old data associated with the + /// id if there was one already registered. + pub async fn insert(&self, id: impl Into, key: HeapSecretKey, data: T) -> Option { + self.map + .write() + .await + .insert(id.into(), (key, data)) + .map(|(_, data)| data) + } + + /// Checks if there is an `id` stored within the keychain. + pub async fn has_id(&self, id: impl AsRef) -> bool { + self.map.read().await.contains_key(id.as_ref()) + } + + /// Checks if there is a key with the given `id` that matches the provided `key`. 
+ pub async fn has_key(&self, id: impl AsRef, key: impl PartialEq) -> bool { + self.map + .read() + .await + .get(id.as_ref()) + .map(|(k, _)| key.eq(k)) + .unwrap_or(false) + } + + /// Removes a key and its data by a given `id`, returning the data if the `id` exists. + pub async fn remove(&self, id: impl AsRef) -> Option { + self.map + .write() + .await + .remove(id.as_ref()) + .map(|(_, data)| data) + } + + /// Checks if there is a key with the given `id` that matches the provided `key`, returning the + /// data if the `id` exists and the `key` matches. + pub async fn remove_if_has_key( + &self, + id: impl AsRef, + key: impl PartialEq, + ) -> KeychainResult { + let id = id.as_ref(); + let mut lock = self.map.write().await; + + match lock.get(id) { + Some((k, _)) if key.eq(k) => KeychainResult::Ok(lock.remove(id).unwrap().1), + Some(_) => KeychainResult::InvalidPassword, + None => KeychainResult::InvalidId, + } + } +} + +impl Keychain<()> { + /// Stores a new `key by a given `id`. + pub async fn put(&self, id: impl Into, key: HeapSecretKey) { + self.insert(id, key, ()).await; + } +} + +impl Default for Keychain { + fn default() -> Self { + Self::new() + } +} + +impl From> for Keychain { + /// Creates a new keychain populated with the provided `map`. + fn from(map: HashMap) -> Self { + Self { + map: Arc::new(RwLock::new(map)), + } + } +} + +impl From> for Keychain<()> { + /// Creates a new keychain populated with the provided `map`. 
+ fn from(map: HashMap) -> Self { + Self::from( + map.into_iter() + .map(|(id, key)| (id, (key, ()))) + .collect::>(), + ) + } +} diff --git a/distant-net/src/common/authentication/methods.rs b/distant-net/src/common/authentication/methods.rs new file mode 100644 index 0000000..ee188bf --- /dev/null +++ b/distant-net/src/common/authentication/methods.rs @@ -0,0 +1,376 @@ +use super::{super::HeapSecretKey, msg::*, Authenticator}; +use async_trait::async_trait; +use log::*; +use std::collections::HashMap; +use std::io; + +mod none; +mod static_key; + +pub use none::*; +pub use static_key::*; + +/// Supports authenticating using a variety of methods +pub struct Verifier { + methods: HashMap<&'static str, Box>, +} + +impl Verifier { + pub fn new(methods: I) -> Self + where + I: IntoIterator>, + { + let mut m = HashMap::new(); + + for method in methods { + m.insert(method.id(), method); + } + + Self { methods: m } + } + + /// Creates a verifier with no methods. + pub fn empty() -> Self { + Self { + methods: HashMap::new(), + } + } + + /// Creates a verifier that uses the [`NoneAuthenticationMethod`] exclusively. + pub fn none() -> Self { + Self::new(vec![ + Box::new(NoneAuthenticationMethod::new()) as Box + ]) + } + + /// Creates a verifier that uses the [`StaticKeyAuthenticationMethod`] exclusively. + pub fn static_key(key: impl Into) -> Self { + Self::new(vec![ + Box::new(StaticKeyAuthenticationMethod::new(key)) as Box + ]) + } + + /// Returns an iterator over the ids of the methods supported by the verifier + pub fn methods(&self) -> impl Iterator + '_ { + self.methods.keys().copied() + } + + /// Attempts to verify by submitting challenges using the `authenticator` provided. Returns the + /// id of the authentication method that succeeded. Fails if no authentication method succeeds. 
+ pub async fn verify(&self, authenticator: &mut dyn Authenticator) -> io::Result<&'static str> { + // Initiate the process to get methods to use + let response = authenticator + .initialize(Initialization { + methods: self.methods.keys().map(ToString::to_string).collect(), + }) + .await?; + + for method in response.methods { + match self.methods.get(method.as_str()) { + Some(method) => { + // Report the authentication method + authenticator + .start_method(StartMethod { + method: method.id().to_string(), + }) + .await?; + + // Perform the actual authentication + if method.authenticate(authenticator).await.is_ok() { + authenticator.finished().await?; + return Ok(method.id()); + } + } + None => { + trace!("Skipping authentication {method} as it is not available or supported"); + } + } + } + + Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "No authentication method succeeded", + )) + } +} + +impl From>> for Verifier { + fn from(methods: Vec>) -> Self { + Self::new(methods) + } +} + +/// Represents an interface to authenticate using some method +#[async_trait] +pub trait AuthenticationMethod: Send + Sync { + /// Returns a unique id to distinguish the method from other methods + fn id(&self) -> &'static str; + + /// Performs authentication using the `authenticator` to submit challenges and other + /// information based on the authentication method + async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()>; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::FramedTransport; + use test_log::test; + + struct SuccessAuthenticationMethod; + + #[async_trait] + impl AuthenticationMethod for SuccessAuthenticationMethod { + fn id(&self) -> &'static str { + "success" + } + + async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> { + Ok(()) + } + } + + struct FailAuthenticationMethod; + + #[async_trait] + impl AuthenticationMethod for FailAuthenticationMethod { + fn id(&self) -> &'static str { + "fail" + 
} + + async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> { + Err(io::Error::from(io::ErrorKind::Other)) + } + } + + #[test(tokio::test)] + async fn verifier_should_fail_to_verify_if_initialization_fails() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame(b"invalid initialization response") + .await + .unwrap(); + + let methods: Vec> = + vec![Box::new(SuccessAuthenticationMethod)]; + let verifier = Verifier::from(methods); + verifier.verify(&mut t1).await.unwrap_err(); + } + + #[test(tokio::test)] + async fn verifier_should_fail_to_verify_if_fails_to_send_finished_indicator_after_success() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame_for(&AuthenticationResponse::Initialization( + InitializationResponse { + methods: vec![SuccessAuthenticationMethod.id().to_string()] + .into_iter() + .collect(), + }, + )) + .await + .unwrap(); + + // Then drop the transport so it cannot receive anything else + drop(t2); + + let methods: Vec> = + vec![Box::new(SuccessAuthenticationMethod)]; + let verifier = Verifier::from(methods); + assert_eq!( + verifier.verify(&mut t1).await.unwrap_err().kind(), + io::ErrorKind::WriteZero + ); + } + + #[test(tokio::test)] + async fn verifier_should_fail_to_verify_if_has_no_authentication_methods() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame_for(&AuthenticationResponse::Initialization( + InitializationResponse { + methods: vec![SuccessAuthenticationMethod.id().to_string()] + .into_iter() + .collect(), + }, + )) + .await + .unwrap(); + + let methods: Vec> = vec![]; + let verifier = Verifier::from(methods); + verifier.verify(&mut t1).await.unwrap_err(); + } + + #[test(tokio::test)] + async fn 
verifier_should_fail_to_verify_if_initialization_yields_no_valid_authentication_methods( + ) { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame_for(&AuthenticationResponse::Initialization( + InitializationResponse { + methods: vec!["other".to_string()].into_iter().collect(), + }, + )) + .await + .unwrap(); + + let methods: Vec> = + vec![Box::new(SuccessAuthenticationMethod)]; + let verifier = Verifier::from(methods); + verifier.verify(&mut t1).await.unwrap_err(); + } + + #[test(tokio::test)] + async fn verifier_should_fail_to_verify_if_no_authentication_method_succeeds() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame_for(&AuthenticationResponse::Initialization( + InitializationResponse { + methods: vec![FailAuthenticationMethod.id().to_string()] + .into_iter() + .collect(), + }, + )) + .await + .unwrap(); + + let methods: Vec> = vec![Box::new(FailAuthenticationMethod)]; + let verifier = Verifier::from(methods); + verifier.verify(&mut t1).await.unwrap_err(); + } + + #[test(tokio::test)] + async fn verifier_should_return_id_of_authentication_method_upon_success() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame_for(&AuthenticationResponse::Initialization( + InitializationResponse { + methods: vec![SuccessAuthenticationMethod.id().to_string()] + .into_iter() + .collect(), + }, + )) + .await + .unwrap(); + + let methods: Vec> = + vec![Box::new(SuccessAuthenticationMethod)]; + let verifier = Verifier::from(methods); + assert_eq!( + verifier.verify(&mut t1).await.unwrap(), + SuccessAuthenticationMethod.id() + ); + } + + #[test(tokio::test)] + async fn verifier_should_try_authentication_methods_in_order_until_one_succeeds() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the 
initialization request + t2.write_frame_for(&AuthenticationResponse::Initialization( + InitializationResponse { + methods: vec![ + FailAuthenticationMethod.id().to_string(), + SuccessAuthenticationMethod.id().to_string(), + ] + .into_iter() + .collect(), + }, + )) + .await + .unwrap(); + + let methods: Vec> = vec![ + Box::new(FailAuthenticationMethod), + Box::new(SuccessAuthenticationMethod), + ]; + let verifier = Verifier::from(methods); + assert_eq!( + verifier.verify(&mut t1).await.unwrap(), + SuccessAuthenticationMethod.id() + ); + } + + #[test(tokio::test)] + async fn verifier_should_send_start_method_before_attempting_each_method() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame_for(&AuthenticationResponse::Initialization( + InitializationResponse { + methods: vec![ + FailAuthenticationMethod.id().to_string(), + SuccessAuthenticationMethod.id().to_string(), + ] + .into_iter() + .collect(), + }, + )) + .await + .unwrap(); + + let methods: Vec> = vec![ + Box::new(FailAuthenticationMethod), + Box::new(SuccessAuthenticationMethod), + ]; + Verifier::from(methods).verify(&mut t1).await.unwrap(); + + // Check that we get a start method for each of the attempted methods + match t2.read_frame_as::().await.unwrap().unwrap() { + Authentication::Initialization(_) => (), + x => panic!("Unexpected response: {x:?}"), + } + match t2.read_frame_as::().await.unwrap().unwrap() { + Authentication::StartMethod(x) => assert_eq!(x.method, FailAuthenticationMethod.id()), + x => panic!("Unexpected response: {x:?}"), + } + match t2.read_frame_as::().await.unwrap().unwrap() { + Authentication::StartMethod(x) => { + assert_eq!(x.method, SuccessAuthenticationMethod.id()) + } + x => panic!("Unexpected response: {x:?}"), + } + } + + #[test(tokio::test)] + async fn verifier_should_send_finished_when_a_method_succeeds() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response 
to the initialization request + t2.write_frame_for(&AuthenticationResponse::Initialization( + InitializationResponse { + methods: vec![ + FailAuthenticationMethod.id().to_string(), + SuccessAuthenticationMethod.id().to_string(), + ] + .into_iter() + .collect(), + }, + )) + .await + .unwrap(); + + let methods: Vec> = vec![ + Box::new(FailAuthenticationMethod), + Box::new(SuccessAuthenticationMethod), + ]; + Verifier::from(methods).verify(&mut t1).await.unwrap(); + + // Clear out the initialization and start methods + t2.read_frame_as::().await.unwrap().unwrap(); + t2.read_frame_as::().await.unwrap().unwrap(); + t2.read_frame_as::().await.unwrap().unwrap(); + + match t2.read_frame_as::().await.unwrap().unwrap() { + Authentication::Finished => (), + x => panic!("Unexpected response: {x:?}"), + } + } +} diff --git a/distant-net/src/common/authentication/methods/none.rs b/distant-net/src/common/authentication/methods/none.rs new file mode 100644 index 0000000..757b479 --- /dev/null +++ b/distant-net/src/common/authentication/methods/none.rs @@ -0,0 +1,32 @@ +use super::{AuthenticationMethod, Authenticator}; +use async_trait::async_trait; +use std::io; + +/// Authenticaton method for a static secret key +#[derive(Clone, Debug)] +pub struct NoneAuthenticationMethod; + +impl NoneAuthenticationMethod { + #[inline] + pub fn new() -> Self { + Self + } +} + +impl Default for NoneAuthenticationMethod { + #[inline] + fn default() -> Self { + Self + } +} + +#[async_trait] +impl AuthenticationMethod for NoneAuthenticationMethod { + fn id(&self) -> &'static str { + "none" + } + + async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> { + Ok(()) + } +} diff --git a/distant-net/src/common/authentication/methods/static_key.rs b/distant-net/src/common/authentication/methods/static_key.rs new file mode 100644 index 0000000..bace4fd --- /dev/null +++ b/distant-net/src/common/authentication/methods/static_key.rs @@ -0,0 +1,129 @@ +use super::{AuthenticationMethod, 
Authenticator, Challenge, Error, Question}; +use crate::common::HeapSecretKey; +use async_trait::async_trait; +use std::io; + +/// Authenticaton method for a static secret key +#[derive(Clone, Debug)] +pub struct StaticKeyAuthenticationMethod { + key: HeapSecretKey, +} + +impl StaticKeyAuthenticationMethod { + #[inline] + pub fn new(key: impl Into) -> Self { + Self { key: key.into() } + } +} + +#[async_trait] +impl AuthenticationMethod for StaticKeyAuthenticationMethod { + fn id(&self) -> &'static str { + "static_key" + } + + async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()> { + let response = authenticator + .challenge(Challenge { + questions: vec![Question { + label: "key".to_string(), + text: "Provide a key: ".to_string(), + options: Default::default(), + }], + options: Default::default(), + }) + .await?; + + if response.answers.is_empty() { + return Err(Error::non_fatal("missing answer").into_io_permission_denied()); + } + + match response + .answers + .into_iter() + .next() + .unwrap() + .parse::() + { + Ok(key) if key == self.key => Ok(()), + _ => Err(Error::non_fatal("answer does not match key").into_io_permission_denied()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::{ + authentication::msg::{AuthenticationResponse, ChallengeResponse}, + FramedTransport, + }; + use test_log::test; + + #[test(tokio::test)] + async fn authenticate_should_fail_if_key_challenge_fails() { + let method = StaticKeyAuthenticationMethod::new(b"".to_vec()); + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up an invalid frame for our challenge to ensure it fails + t2.write_frame(b"invalid initialization response") + .await + .unwrap(); + + assert_eq!( + method.authenticate(&mut t1).await.unwrap_err().kind(), + io::ErrorKind::InvalidData + ); + } + + #[test(tokio::test)] + async fn authenticate_should_fail_if_no_answer_included_in_challenge_response() { + let method = 
StaticKeyAuthenticationMethod::new(b"".to_vec()); + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame_for(&AuthenticationResponse::Challenge(ChallengeResponse { + answers: Vec::new(), + })) + .await + .unwrap(); + + assert_eq!( + method.authenticate(&mut t1).await.unwrap_err().kind(), + io::ErrorKind::PermissionDenied + ); + } + + #[test(tokio::test)] + async fn authenticate_should_fail_if_answer_does_not_match_key() { + let method = StaticKeyAuthenticationMethod::new(b"answer".to_vec()); + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame_for(&AuthenticationResponse::Challenge(ChallengeResponse { + answers: vec![HeapSecretKey::from(b"some key".to_vec()).to_string()], + })) + .await + .unwrap(); + + assert_eq!( + method.authenticate(&mut t1).await.unwrap_err().kind(), + io::ErrorKind::PermissionDenied + ); + } + + #[test(tokio::test)] + async fn authenticate_should_succeed_if_answer_matches_key() { + let method = StaticKeyAuthenticationMethod::new(b"answer".to_vec()); + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up a response to the initialization request + t2.write_frame_for(&AuthenticationResponse::Challenge(ChallengeResponse { + answers: vec![HeapSecretKey::from(b"answer".to_vec()).to_string()], + })) + .await + .unwrap(); + + method.authenticate(&mut t1).await.unwrap(); + } +} diff --git a/distant-net/src/common/authentication/msg.rs b/distant-net/src/common/authentication/msg.rs new file mode 100644 index 0000000..ef6baf6 --- /dev/null +++ b/distant-net/src/common/authentication/msg.rs @@ -0,0 +1,216 @@ +use derive_more::{Display, Error, From}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Represents messages from an authenticator that act as initiators such as providing +/// a challenge, verifying information, presenting information, or 
highlighting an error +#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum Authentication { + /// Indicates the beginning of authentication, providing available methods + #[serde(rename = "auth_initialization")] + Initialization(Initialization), + + /// Indicates that authentication is starting for the specific `method` + #[serde(rename = "auth_start_method")] + StartMethod(StartMethod), + + /// Issues a challenge to be answered + #[serde(rename = "auth_challenge")] + Challenge(Challenge), + + /// Requests verification of some text + #[serde(rename = "auth_verification")] + Verification(Verification), + + /// Reports some information associated with authentication + #[serde(rename = "auth_info")] + Info(Info), + + /// Reports an error occurrred during authentication + #[serde(rename = "auth_error")] + Error(Error), + + /// Indicates that the authentication of all methods is finished + #[serde(rename = "auth_finished")] + Finished, +} + +/// Represents the beginning of the authentication procedure +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Initialization { + /// Available methods to use for authentication + pub methods: Vec, +} + +/// Represents the start of authentication for some method +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct StartMethod { + pub method: String, +} + +/// Represents a challenge comprising a series of questions to be presented +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Challenge { + pub questions: Vec, + pub options: HashMap, +} + +/// Represents an ask to verify some information +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Verification { + pub kind: VerificationKind, + pub text: String, +} + +/// Represents some information to be presented related to authentication +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub 
struct Info { + pub text: String, +} + +/// Represents authentication messages that are responses to authenticator requests such +/// as answers to challenges or verifying information +#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum AuthenticationResponse { + /// Contains response to initialization, providing details about which methods to use + #[serde(rename = "auth_initialization_response")] + Initialization(InitializationResponse), + + /// Contains answers to challenge request + #[serde(rename = "auth_challenge_response")] + Challenge(ChallengeResponse), + + /// Contains response to a verification request + #[serde(rename = "auth_verification_response")] + Verification(VerificationResponse), +} + +/// Represents a response to initialization to specify which authentication methods to pursue +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct InitializationResponse { + /// Methods to use (in order as provided) + pub methods: Vec, +} + +/// Represents the answers to a previously-asked challenge associated with authentication +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct ChallengeResponse { + /// Answers to challenge questions (in order relative to questions) + pub answers: Vec, +} + +/// Represents the answer to a previously-asked verification associated with authentication +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct VerificationResponse { + /// Whether or not the verification was deemed valid + pub valid: bool, +} + +/// Represents the type of verification being requested +#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum VerificationKind { + /// An ask to verify the host such as with SSH + #[display(fmt = "host")] + Host, + + /// When the verification is unknown (happens when other side is unaware of the kind) + 
#[display(fmt = "unknown")] + #[serde(other)] + Unknown, +} + +impl VerificationKind { + /// Returns all variants except "unknown" + pub const fn known_variants() -> &'static [Self] { + &[Self::Host] + } +} + +/// Represents a single question in a challenge associated with authentication +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Question { + /// Label associated with the question for more programmatic usage + pub label: String, + + /// The text of the question (used for display purposes) + pub text: String, + + /// Any options information specific to a particular auth domain + /// such as including a username and instructions for SSH authentication + pub options: HashMap, +} + +impl Question { + /// Creates a new question without any options data using `text` for both label and text + pub fn new(text: impl Into) -> Self { + let text = text.into(); + + Self { + label: text.clone(), + text, + options: HashMap::new(), + } + } +} + +/// Represents some error that occurred during authentication +#[derive(Clone, Debug, Display, Error, PartialEq, Eq, Serialize, Deserialize)] +#[display(fmt = "{}: {}", kind, text)] +pub struct Error { + /// Represents the kind of error + pub kind: ErrorKind, + + /// Description of the error + pub text: String, +} + +impl Error { + /// Creates a fatal error + pub fn fatal(text: impl Into) -> Self { + Self { + kind: ErrorKind::Fatal, + text: text.into(), + } + } + + /// Creates a non-fatal error + pub fn non_fatal(text: impl Into) -> Self { + Self { + kind: ErrorKind::Error, + text: text.into(), + } + } + + /// Returns true if error represents a fatal error, meaning that there is no recovery possible + /// from this error + pub fn is_fatal(&self) -> bool { + self.kind.is_fatal() + } + + /// Converts the error into a [`std::io::Error`] representing permission denied + pub fn into_io_permission_denied(self) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::PermissionDenied, self) + } +} + +/// 
Represents the type of error encountered during authentication +#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ErrorKind { + /// Error is unrecoverable + Fatal, + + /// Error is recoverable + Error, +} + +impl ErrorKind { + /// Returns true if error kind represents a fatal error, meaning that there is no recovery + /// possible from this error + pub fn is_fatal(self) -> bool { + matches!(self, Self::Fatal) + } +} diff --git a/distant-net/src/common/connection.rs b/distant-net/src/common/connection.rs new file mode 100644 index 0000000..64833dd --- /dev/null +++ b/distant-net/src/common/connection.rs @@ -0,0 +1,1291 @@ +use super::{ + authentication::{AuthHandler, Authenticate, Keychain, KeychainResult, Verifier}, + Backup, FramedTransport, HeapSecretKey, Reconnectable, Transport, +}; +use async_trait::async_trait; +use log::*; +use serde::{Deserialize, Serialize}; +use std::io; +use std::ops::{Deref, DerefMut}; +use tokio::sync::oneshot; + +#[cfg(test)] +use super::InmemoryTransport; + +/// Id of the connection +pub type ConnectionId = u32; + +/// Represents a connection from either the client or server side +#[derive(Debug)] +pub enum Connection { + /// Connection from the client side + Client { + /// Unique id associated with the connection + id: ConnectionId, + + /// One-time password (OTP) for use in reauthenticating with the server + reauth_otp: HeapSecretKey, + + /// Underlying transport used to communicate + transport: FramedTransport, + }, + + /// Connection from the server side + Server { + /// Unique id associated with the connection + id: ConnectionId, + + /// Used to send the backup into storage when the connection is dropped + tx: oneshot::Sender, + + /// Underlying transport used to communicate + transport: FramedTransport, + }, +} + +impl Deref for Connection { + type Target = FramedTransport; + + fn deref(&self) -> &Self::Target { + match self { + Self::Client { transport, 
.. } => transport, + Self::Server { transport, .. } => transport, + } + } +} + +impl DerefMut for Connection { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::Client { transport, .. } => transport, + Self::Server { transport, .. } => transport, + } + } +} + +impl Drop for Connection { + /// On drop for a server connection, the connection's backup will be sent via `tx`. For a + /// client connection, nothing happens. + fn drop(&mut self) { + match self { + Self::Client { .. } => (), + Self::Server { tx, transport, .. } => { + // NOTE: We grab the current backup state and store it using the tx, replacing + // the backup with a default and the tx with a disconnected one + let backup = std::mem::take(&mut transport.backup); + let tx = std::mem::replace(tx, oneshot::channel().0); + let _ = tx.send(backup); + } + } + } +} + +#[async_trait] +impl Reconnectable for Connection +where + T: Transport, +{ + /// Attempts to re-establish a connection. + /// + /// ### Client + /// + /// For a client, this means performing an actual [`reconnect`] on the underlying + /// [`Transport`], re-establishing an encrypted codec, submitting a request to the server to + /// reauthenticate using a previously-derived OTP, and refreshing the connection id and OTP for + /// use in a future reauthentication. + /// + /// ### Server + /// + /// For a server, this will fail as unsupported. 
+ /// + /// [`reconnect`]: Reconnectable::reconnect + async fn reconnect(&mut self) -> io::Result<()> { + async fn reconnect_client( + id: &mut ConnectionId, + reauth_otp: &mut HeapSecretKey, + transport: &mut FramedTransport, + ) -> io::Result<()> { + // Re-establish a raw connection + debug!("[Conn {id}] Re-establishing connection"); + Reconnectable::reconnect(transport).await?; + + // Perform a handshake to ensure that the connection is properly established and encrypted + debug!("[Conn {id}] Performing handshake"); + transport.client_handshake().await?; + + // Communicate that we are an existing connection + debug!("[Conn {id}] Performing re-authentication"); + transport + .write_frame_for(&ConnectType::Reconnect { + id: *id, + otp: reauth_otp.unprotected_as_bytes().to_vec(), + }) + .await?; + + // Receive the new id for the connection + // NOTE: If we fail re-authentication above, + // this will fail as the connection is dropped + debug!("[Conn {id}] Receiving new connection id"); + let new_id = transport + .read_frame_as::() + .await? + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Missing connection id frame") + })?; + debug!("[Conn {id}] Resetting id to {new_id}"); + *id = new_id; + + // Derive an OTP for reauthentication + debug!("[Conn {id}] Deriving future OTP for reauthentication"); + *reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); + + Ok(()) + } + + match self { + Self::Client { + id, + transport, + reauth_otp, + } => { + // Freeze our backup as we don't want the connection logic to alter it + transport.backup.freeze(); + + // Attempt to perform the reconnection and unfreeze our backup regardless of the + // result + let result = reconnect_client(id, reauth_otp, transport).await; + transport.backup.unfreeze(); + result?; + + // Perform synchronization + debug!("[Conn {id}] Synchronizing frame state"); + transport.synchronize().await?; + + Ok(()) + } + + Self::Server { .. 
} => Err(io::Error::new( + io::ErrorKind::Unsupported, + "Server connection cannot reconnect", + )), + } + } +} + +/// Type of connection to perform +#[derive(Debug, Serialize, Deserialize)] +enum ConnectType { + /// Indicates that the connection from client to server is no and not a reconnection + Connect, + + /// Indicates that the connection from client to server is a reconnection and should attempt to + /// use the connection id and OTP to authenticate + Reconnect { + /// Id of the connection to reauthenticate + id: ConnectionId, + + /// Raw bytes of the OTP + #[serde(with = "serde_bytes")] + otp: Vec, + }, +} + +impl Connection +where + T: Transport, +{ + /// Transforms a raw [`Transport`] into an established [`Connection`] from the client-side by + /// performing the following: + /// + /// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use + /// 2. Authenticates the established connection to ensure it is valid + /// 3. Restores pre-existing state using the provided backup, replaying any missing frames and + /// receiving any frames from the other side + pub async fn client(transport: T, handler: H) -> io::Result { + let id: ConnectionId = rand::random(); + + // Perform a handshake to ensure that the connection is properly established and encrypted + debug!("[Conn {id}] Performing handshake"); + let mut transport: FramedTransport = + FramedTransport::from_client_handshake(transport).await?; + + // Communicate that we are a new connection + debug!("[Conn {id}] Communicating that this is a new connection"); + transport.write_frame_for(&ConnectType::Connect).await?; + + // Receive the new id for the connection + let id = { + debug!("[Conn {id}] Receiving new connection id"); + let new_id = transport + .read_frame_as::() + .await? 
+ .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Missing connection id frame") + })?; + debug!("[Conn {id}] Resetting id to {new_id}"); + new_id + }; + + // Authenticate the transport with the server-side + debug!("[Conn {id}] Performing authentication"); + transport.authenticate(handler).await?; + + // Derive an OTP for reauthentication + debug!("[Conn {id}] Deriving future OTP for reauthentication"); + let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); + + Ok(Self::Client { + id, + reauth_otp, + transport, + }) + } + + /// Transforms a raw [`Transport`] into an established [`Connection`] from the server-side by + /// performing the following: + /// + /// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use + /// 2. Authenticates the established connection to ensure it is valid by either using the + /// given `verifier` or, if working with an existing client connection, will validate an OTP + /// from our database + /// 3. Restores pre-existing state using the provided backup, replaying any missing frames and + /// receiving any frames from the other side + pub async fn server( + transport: T, + verifier: &Verifier, + keychain: Keychain>, + ) -> io::Result { + let id: ConnectionId = rand::random(); + + // Perform a handshake to ensure that the connection is properly established and encrypted + debug!("[Conn {id}] Performing handshake"); + let mut transport: FramedTransport = + FramedTransport::from_server_handshake(transport).await?; + + // Receive a client id, look up to see if the client id exists already + // + // 1. If it already exists, wait for a password to follow, which is a one-time password used by + // the client. If the password is correct, then generate a new one-time client id and + // password for a future connection (only updating if the connection fully completes) and + // send it to the client, and then perform a replay situation + // + // 2. 
If it does not exist, ignore the client id and password. Generate a new client id to send + // to the client. Perform verification like usual. Then generate a one-time password and + // send it to the client. + debug!("[Conn {id}] Waiting for connection type"); + let connection_type = transport + .read_frame_as::() + .await? + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Missing connection type frame"))?; + + // Create a oneshot channel used to relay the backup when the connection is dropped + let (tx, rx) = oneshot::channel(); + + // Based on the connection type, we either try to find and validate an existing connection + // or we perform normal verification + match connection_type { + ConnectType::Connect => { + // Communicate the connection id + debug!("[Conn {id}] Telling other side to change connection id"); + transport.write_frame_for(&id).await?; + + // Perform authentication to ensure the connection is valid + debug!("[Conn {id}] Verifying connection"); + verifier.verify(&mut transport).await?; + + // Derive an OTP for reauthentication + debug!("[Conn {id}] Deriving future OTP for reauthentication"); + let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); + + // Store the id, OTP, and backup retrieval in our database + keychain.insert(id.to_string(), reauth_otp, rx).await; + } + ConnectType::Reconnect { id: other_id, otp } => { + let reauth_otp = HeapSecretKey::from(otp); + + debug!("[Conn {id}] Checking if {other_id} exists and has matching OTP"); + match keychain + .remove_if_has_key(other_id.to_string(), reauth_otp) + .await + { + KeychainResult::Ok(x) => { + // Communicate the connection id + debug!("[Conn {id}] Telling other side to change connection id"); + transport.write_frame_for(&id).await?; + + // Derive an OTP for reauthentication + debug!("[Conn {id}] Deriving future OTP for reauthentication"); + let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); + + // Grab the old backup and swap it into our 
transport + debug!("[Conn {id}] Acquiring backup for existing connection"); + match x.await { + Ok(backup) => { + transport.backup = backup; + } + Err(_) => { + warn!("[Conn {id}] Missing backup"); + } + } + + // Synchronize using the provided backup + debug!("[Conn {id}] Synchronizing frame state"); + transport.synchronize().await?; + + // Store the id, OTP, and backup retrieval in our database + keychain.insert(id.to_string(), reauth_otp, rx).await; + } + KeychainResult::InvalidPassword => { + return Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "Invalid OTP for reconnect", + )); + } + KeychainResult::InvalidId => { + return Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "Invalid id for reconnect", + )); + } + } + } + } + + Ok(Self::Server { id, tx, transport }) + } +} + +#[cfg(test)] +impl Connection { + /// Establishes a pair of [`Connection`]s using [`InmemoryTransport`] underneath, returning + /// them in the form (client, server). + /// + /// ### Note + /// + /// This skips handshakes, authentication, and backup processing. These connections cannot be + /// reconnected and have no encryption. + pub fn pair(buffer: usize) -> (Self, Self) { + let id = rand::random::(); + let (t1, t2) = FramedTransport::pair(buffer); + + let client = Connection::Client { + id, + reauth_otp: HeapSecretKey::generate(32).unwrap(), + transport: t1, + }; + + let server = Connection::Server { + id, + tx: oneshot::channel().0, + transport: t2, + }; + + (client, server) + } +} + +#[cfg(test)] +impl Connection { + /// Returns the id of the connection. + pub fn id(&self) -> ConnectionId { + match self { + Self::Client { id, .. } => *id, + Self::Server { id, .. } => *id, + } + } + + /// Returns the OTP associated with the connection, or none if connection is server-side. + pub fn otp(&self) -> Option<&HeapSecretKey> { + match self { + Self::Client { reauth_otp, .. } => Some(reauth_otp), + Self::Server { .. 
} => None, + } + } + + /// Returns a reference to the underlying transport. + pub fn transport(&self) -> &FramedTransport { + match self { + Self::Client { transport, .. } => transport, + Self::Server { transport, .. } => transport, + } + } + + /// Returns a mutable reference to the underlying transport. + pub fn mut_transport(&mut self) -> &mut FramedTransport { + match self { + Self::Client { transport, .. } => transport, + Self::Server { transport, .. } => transport, + } + } +} + +#[cfg(test)] +impl Connection { + pub fn test_client(transport: T) -> Self { + Self::Client { + id: rand::random(), + reauth_otp: HeapSecretKey::generate(32).unwrap(), + transport: FramedTransport::plain(transport), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::{ + authentication::{msg::Challenge, Authenticator, DummyAuthHandler}, + Frame, + }; + use std::sync::Arc; + use test_log::test; + + #[test(tokio::test)] + async fn client_should_fail_if_codec_handshake_fails() { + let (mut t1, t2) = FramedTransport::pair(100); + + // Spawn a task to perform the client connection so we don't deadlock while simulating the + // server actions on the other side + let task = tokio::spawn(async move { + Connection::client(t2.into_inner(), DummyAuthHandler) + .await + .unwrap() + }); + + // Send garbage to fail the handshake + t1.write_frame(Frame::new(b"invalid")).await.unwrap(); + + // Client should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn client_should_fail_if_unable_to_receive_connection_id_from_server() { + let (mut t1, t2) = FramedTransport::pair(100); + + // Spawn a task to perform the client connection so we don't deadlock while simulating the + // server actions on the other side + let task = tokio::spawn(async move { + Connection::client(t2.into_inner(), DummyAuthHandler) + .await + .unwrap() + }); + + // Perform first step of connection by establishing the codec + t1.server_handshake().await.unwrap(); + + // Receive a type that 
indicates a new connection + let ct = t1.read_frame_as::().await.unwrap().unwrap(); + assert!( + matches!(ct, ConnectType::Connect), + "Unexpected connect type: {ct:?}" + ); + + // Drop to cause id retrieval on client to fail + drop(t1); + + // Client should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn client_should_fail_if_authentication_fails() { + let (mut t1, t2) = FramedTransport::pair(100); + + // Spawn a task to perform the client connection so we don't deadlock while simulating the + // server actions on the other side + let task = tokio::spawn(async move { + Connection::client(t2.into_inner(), DummyAuthHandler) + .await + .unwrap() + }); + + // Perform first step of connection by establishing the codec + t1.server_handshake().await.unwrap(); + + // Receive a type that indicates a new connection + let ct = t1.read_frame_as::().await.unwrap().unwrap(); + assert!( + matches!(ct, ConnectType::Connect), + "Unexpected connect type: {ct:?}" + ); + + // Send a connection id as second step of connection + t1.write_frame_for(&rand::random::()) + .await + .unwrap(); + + // Perform an authentication request that will fail on the client side, which will + // cause the client to drop and therefore this transport to fail in getting a response + t1.challenge(Challenge { + questions: Vec::new(), + options: Default::default(), + }) + .await + .unwrap_err(); + + // Client should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn client_should_fail_if_unable_to_exchange_otp_for_reauthentication() { + let (mut t1, t2) = FramedTransport::pair(100); + + // Spawn a task to perform the client connection so we don't deadlock while simulating the + // server actions on the other side + let task = tokio::spawn(async move { + Connection::client(t2.into_inner(), DummyAuthHandler) + .await + .unwrap() + }); + + // Perform first step of connection by establishing the codec + t1.server_handshake().await.unwrap(); + + // Receive a type that 
indicates a new connection + let ct = t1.read_frame_as::().await.unwrap().unwrap(); + assert!( + matches!(ct, ConnectType::Connect), + "Unexpected connect type: {ct:?}" + ); + + // Send a connection id as second step of connection + t1.write_frame_for(&rand::random::()) + .await + .unwrap(); + + // Perform verification as third step using none method, which should always succeed + // without challenging + Verifier::none().verify(&mut t1).await.unwrap(); + + // Send garbage to fail the key exchange + t1.write_frame(Frame::new(b"invalid")).await.unwrap(); + + // Client should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn client_should_succeed_if_establishes_connection_with_server() { + let (mut t1, t2) = FramedTransport::pair(100); + + // Spawn a task to perform the client connection so we don't deadlock while simulating the + // server actions on the other side + let task = tokio::spawn(async move { + Connection::client(t2.into_inner(), DummyAuthHandler) + .await + .unwrap() + }); + + // Perform first step of connection by establishing the codec + t1.server_handshake().await.unwrap(); + + // Receive a type that indicates a new connection + let ct = t1.read_frame_as::().await.unwrap().unwrap(); + assert!( + matches!(ct, ConnectType::Connect), + "Unexpected connect type: {ct:?}" + ); + + // Send a connection id as second step of connection + t1.write_frame_for(&rand::random::()) + .await + .unwrap(); + + // Perform verification as third step using none method, which should always succeed + // without challenging + Verifier::none().verify(&mut t1).await.unwrap(); + + // Perform fourth step of key exchange for OTP + let otp = t1.exchange_keys().await.unwrap().into_heap_secret_key(); + + // Client should succeed and have an OTP that matches the server-side version + let client = task.await.unwrap(); + assert_eq!(client.otp(), Some(&otp)); + } + + #[test(tokio::test)] + async fn server_should_fail_if_codec_handshake_fails() { + let (mut t1, t2) = 
FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn(async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + .unwrap() + }); + + // Send garbage to fail the handshake + t1.write_frame(Frame::new(b"invalid")).await.unwrap(); + + // Server should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn server_should_fail_if_unable_to_receive_connect_type() { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn(async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + .unwrap() + }); + + // Perform first step of completing client-side of handshake + t1.client_handshake().await.unwrap(); + + // Send some garbage that is not the connection type + t1.write_frame(Frame::new(b"hello")).await.unwrap(); + + // Server should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn server_should_fail_if_unable_to_verify_new_client() { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::static_key(HeapSecretKey::generate(32).unwrap()); + let keychain = Keychain::new(); + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn(async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + .unwrap() + }); + + // Perform first step of completing client-side of handshake + t1.client_handshake().await.unwrap(); + + // Send type to indicate a new connection + t1.write_frame_for(&ConnectType::Connect).await.unwrap(); + + // Receive the 
connection id + let _id = t1.read_frame_as::().await.unwrap().unwrap(); + + // Fail verification using the dummy handler that will fail when asked for a static key + t1.authenticate(DummyAuthHandler).await.unwrap_err(); + + // Drop the transport so we kill the server-side connection + // NOTE: If we don't drop here, the above authentication failure won't kill the server + drop(t1); + + // Server should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn server_should_fail_if_unable_to_exchange_otp_for_reauthentication_with_new_client() { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn(async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + .unwrap() + }); + + // Perform first step of completing client-side of handshake + t1.client_handshake().await.unwrap(); + + // Send type to indicate a new connection + t1.write_frame_for(&ConnectType::Connect).await.unwrap(); + + // Receive the connection id + let _id = t1.read_frame_as::().await.unwrap().unwrap(); + + // Pass verification using the dummy handler since our verifier supports no authentication + t1.authenticate(DummyAuthHandler).await.unwrap(); + + // Send some garbage to fail the exchange + t1.write_frame(Frame::new(b"hello")).await.unwrap(); + + // Server should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn server_should_fail_if_existing_client_id_is_invalid() { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn(async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + 
.unwrap() + }); + + // Perform first step of completing client-side of handshake + t1.client_handshake().await.unwrap(); + + // Send type to indicate an existing connection, which should cause the server-side to fail + // because there is no matching id + t1.write_frame_for(&ConnectType::Reconnect { + id: 1234, + otp: HeapSecretKey::generate(32) + .unwrap() + .unprotected_into_bytes(), + }) + .await + .unwrap(); + + // Server should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn server_should_fail_if_existing_client_otp_is_invalid() { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + + keychain + .insert( + 1234.to_string(), + HeapSecretKey::generate(32).unwrap(), + oneshot::channel().1, + ) + .await; + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn(async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + .unwrap() + }); + + // Perform first step of completing client-side of handshake + t1.client_handshake().await.unwrap(); + + // Send type to indicate an existing connection, which should cause the server-side to fail + // because the OTP is wrong for the given id + t1.write_frame_for(&ConnectType::Reconnect { + id: 1234, + otp: HeapSecretKey::generate(32) + .unwrap() + .unprotected_into_bytes(), + }) + .await + .unwrap(); + + // Server should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn server_should_fail_if_unable_to_exchange_otp_for_reauthentication_with_existing_client( + ) { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + let key = HeapSecretKey::generate(32).unwrap(); + + keychain + .insert(1234.to_string(), key.clone(), oneshot::channel().1) + .await; + + // Spawn a task to perform the server connection so we don't deadlock while 
 simulating the + // client actions on the other side + let task = tokio::spawn(async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + .unwrap() + }); + + // Perform first step of completing client-side of handshake + t1.client_handshake().await.unwrap(); + + // Send type to indicate an existing connection, using the correct id and OTP that + // were stored in the keychain during setup + t1.write_frame_for(&ConnectType::Reconnect { + id: 1234, + otp: key.unprotected_into_bytes(), + }) + .await + .unwrap(); + + // Receive a new client id + let _id = t1.read_frame_as::().await.unwrap().unwrap(); + + // Send garbage to fail the otp exchange + t1.write_frame(Frame::new(b"hello")).await.unwrap(); + + // Server should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn server_should_fail_if_unable_to_synchronize_with_existing_client() { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + let key = HeapSecretKey::generate(32).unwrap(); + + keychain + .insert(1234.to_string(), key.clone(), oneshot::channel().1) + .await; + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn(async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + .unwrap() + }); + + // Perform first step of completing client-side of handshake + t1.client_handshake().await.unwrap(); + + // Send type to indicate an existing connection, using the correct id and OTP that + // were stored in the keychain during setup + t1.write_frame_for(&ConnectType::Reconnect { + id: 1234, + otp: key.unprotected_into_bytes(), + }) + .await + .unwrap(); + + // Receive a new client id + let _id = t1.read_frame_as::().await.unwrap().unwrap(); + + // Perform otp exchange + let _otp = t1.exchange_keys().await.unwrap(); + + // Send garbage to fail synchronization + 
t1.write_frame(b"hello").await.unwrap(); + + // Server should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn server_should_succeed_if_establishes_connection_with_new_client() { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn({ + let keychain = keychain.clone(); + async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + .unwrap() + } + }); + + // Perform first step of completing client-side of handshake + t1.client_handshake().await.unwrap(); + + // Send type to indicate a new connection + t1.write_frame_for(&ConnectType::Connect).await.unwrap(); + + // Receive the connection id + let id = t1.read_frame_as::().await.unwrap().unwrap(); + + // Pass verification using the dummy handler since our verifier supports no authentication + t1.authenticate(DummyAuthHandler).await.unwrap(); + + // Perform otp exchange + let otp = t1.exchange_keys().await.unwrap(); + + // Server connection should be established, and have received some replayed frames + let server = task.await.unwrap(); + + // Validate the connection ids match + assert_eq!(server.id(), id); + + // Validate the OTP was stored in our keychain + assert!( + keychain + .has_key(id.to_string(), otp.into_heap_secret_key()) + .await, + "Missing OTP" + ); + } + + #[test(tokio::test)] + async fn server_should_succeed_if_establishes_connection_with_existing_client() { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + let key = HeapSecretKey::generate(32).unwrap(); + + keychain + .insert(1234.to_string(), key.clone(), { + // Create a custom backup we'll use to replay frames from the server-side + let mut backup = Backup::new(); + + backup.push_frame(Frame::new(b"hello")); + 
 backup.push_frame(Frame::new(b"world")); + backup.increment_sent_cnt(); + backup.increment_sent_cnt(); + + let (tx, rx) = oneshot::channel(); + tx.send(backup).unwrap(); + rx + }) + .await; + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn({ + let keychain = keychain.clone(); + async move { + Connection::server(t2.into_inner(), &verifier, keychain) + .await + .unwrap() + } + }); + + // Perform first step of completing client-side of handshake + t1.client_handshake().await.unwrap(); + + // Send type to indicate an existing connection, using the correct id and OTP that + // were stored in the keychain during setup + t1.write_frame_for(&ConnectType::Reconnect { + id: 1234, + otp: key.unprotected_into_bytes(), + }) + .await + .unwrap(); + + // Receive a new client id + let id = t1.read_frame_as::().await.unwrap().unwrap(); + + // Perform otp exchange + let otp = t1.exchange_keys().await.unwrap(); + + // Queue up some frames to send to the server + t1.backup.clear(); + t1.backup.push_frame(Frame::new(b"foo")); + t1.backup.push_frame(Frame::new(b"bar")); + t1.backup.increment_sent_cnt(); + t1.backup.increment_sent_cnt(); + + // Perform synchronization + t1.synchronize().await.unwrap(); + + // Verify that we received frames from the server + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world"); + + // Server connection should be established, and have received some replayed frames + let mut server = task.await.unwrap(); + assert_eq!(server.read_frame().await.unwrap().unwrap(), b"foo"); + assert_eq!(server.read_frame().await.unwrap().unwrap(), b"bar"); + + // Validate the connection ids match + assert_eq!(server.id(), id); + + // Check that our old connection id is no longer contained in the keychain + assert!(!keychain.has_id("1234").await, "Old OTP still exists"); + + // Validate 
the OTP was stored in our keychain + assert!( + keychain + .has_key(id.to_string(), otp.into_heap_secret_key()) + .await, + "Missing OTP" + ); + } + + #[test(tokio::test)] + async fn client_server_new_connection_e2e_should_establish_connection() { + let (t1, t2) = InmemoryTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + + // Spawn a task to perform the server connection so we don't deadlock + let task = tokio::spawn(async move { + Connection::server(t2, &verifier, keychain) + .await + .expect("Failed to connect from server") + }); + + // Perform the client-side of the connection + let mut client = Connection::client(t1, DummyAuthHandler) + .await + .expect("Failed to connect from client"); + let mut server = task.await.unwrap(); + + // Test out the connection + client.write_frame(Frame::new(b"hello")).await.unwrap(); + assert_eq!(server.read_frame().await.unwrap().unwrap(), b"hello"); + server.write_frame(Frame::new(b"goodbye")).await.unwrap(); + assert_eq!(client.read_frame().await.unwrap().unwrap(), b"goodbye"); + } + + /// Helper utility to set up for a client reconnection + async fn setup_reconnect_scenario() -> ( + Connection, + InmemoryTransport, + Arc, + Keychain>, + ) { + let (t1, t2) = InmemoryTransport::pair(100); + let verifier = Arc::new(Verifier::none()); + let keychain = Keychain::new(); + + // Spawn a task to perform the server connection so we don't deadlock + let task = { + let verifier = Arc::clone(&verifier); + let keychain = keychain.clone(); + tokio::spawn(async move { + Connection::server(t2, &verifier, keychain) + .await + .expect("Failed to connect from server") + }) + }; + + // Perform the client-side of the connection + let mut client = Connection::client(t1, DummyAuthHandler) + .await + .expect("Failed to connect from client"); + + // Ensure the server is established and then drop it + let server = task.await.unwrap(); + drop(server); + + // Create a new inmemory transport and link it to the 
client + let mut t2 = InmemoryTransport::pair(100).0; + t2.link(client.mut_transport().as_mut_inner(), 100); + + (client, t2, verifier, keychain) + } + + #[test(tokio::test)] + async fn reconnect_should_fail_if_client_side_connection_handshake_fails() { + let (mut client, transport, _verifier, _keychain) = setup_reconnect_scenario().await; + let mut transport = FramedTransport::plain(transport); + + // Spawn a task to perform the client reconnection so we don't deadlock + let task = tokio::spawn(async move { client.reconnect().await.unwrap() }); + + // Send garbage to fail handshake from server-side + transport.write_frame(b"hello").await.unwrap(); + + // Client should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn reconnect_should_fail_if_client_side_connection_unable_to_receive_new_connection_id() { + let (mut client, transport, _verifier, _keychain) = setup_reconnect_scenario().await; + let mut transport = FramedTransport::plain(transport); + + // Spawn a task to perform the client reconnection so we don't deadlock + let task = tokio::spawn(async move { client.reconnect().await.unwrap() }); + + // Perform first step of completing server-side of handshake + transport.server_handshake().await.unwrap(); + + // Drop transport to cause client to fail in not receiving connection id + drop(transport); + + // Client should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn reconnect_should_fail_if_client_side_connection_unable_to_exchange_otp_with_server() { + let (mut client, transport, _verifier, keychain) = setup_reconnect_scenario().await; + let mut transport = FramedTransport::plain(transport); + + // Spawn a task to perform the client reconnection so we don't deadlock + let task = tokio::spawn(async move { client.reconnect().await.unwrap() }); + + // Perform first step of completing server-side of handshake + transport.server_handshake().await.unwrap(); + + // Receive reconnect data from client-side + let (id, otp) = 
match transport.read_frame_as::().await { + Ok(Some(ConnectType::Reconnect { id, otp })) => (id, HeapSecretKey::from(otp)), + x => panic!("Unexpected result: {x:?}"), + }; + + // Verify the id and OTP matches the one stored into our keychain from the setup + assert!( + keychain.has_key(id.to_string(), otp).await, + "Wrong id or OTP" + ); + + // Send a new id back to the client connection + transport + .write_frame_for(&rand::random::()) + .await + .unwrap(); + + // Send garbage to fail the key exchange for new OTP + transport.write_frame(Frame::new(b"hello")).await.unwrap(); + + // Client should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn reconnect_should_fail_if_client_side_connection_unable_to_synchronize_with_server() { + let (mut client, transport, _verifier, keychain) = setup_reconnect_scenario().await; + let mut transport = FramedTransport::plain(transport); + + // Spawn a task to perform the client reconnection so we don't deadlock + let task = tokio::spawn(async move { client.reconnect().await.unwrap() }); + + // Perform first step of completing server-side of handshake + transport.server_handshake().await.unwrap(); + + // Receive reconnect data from client-side + let (id, otp) = match transport.read_frame_as::().await { + Ok(Some(ConnectType::Reconnect { id, otp })) => (id, HeapSecretKey::from(otp)), + x => panic!("Unexpected result: {x:?}"), + }; + + // Verify the id and OTP matches the one stored into our keychain from the setup + assert!( + keychain.has_key(id.to_string(), otp).await, + "Wrong id or OTP" + ); + + // Send a new id back to the client connection + transport + .write_frame_for(&rand::random::()) + .await + .unwrap(); + + // Send garbage to fail the key exchange for new OTP + transport.write_frame(Frame::new(b"hello")).await.unwrap(); + + // Client should fail + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn 
reconnect_should_succeed_if_client_side_connection_fully_connects_and_synchronizes_with_server( + ) { + let (mut client, transport, _verifier, keychain) = setup_reconnect_scenario().await; + let mut transport = FramedTransport::plain(transport); + + // Copy client backup for verification later + let client_backup = client.transport().backup.clone(); + + // Spawn a task to perform the client reconnection so we don't deadlock + let task = tokio::spawn(async move { + client.reconnect().await.unwrap(); + client + }); + + // Perform first step of completing server-side of handshake + transport.server_handshake().await.unwrap(); + + // Receive reconnect data from client-side + let (id, otp) = match transport.read_frame_as::().await { + Ok(Some(ConnectType::Reconnect { id, otp })) => (id, HeapSecretKey::from(otp)), + x => panic!("Unexpected result: {x:?}"), + }; + + // Retrieve server backup + let backup = keychain + .remove_if_has_key(id.to_string(), otp) + .await + .into_ok() + .expect("Invalid id or OTP") + .await + .expect("Failed to retrieve backup"); + + // Send a new id back to the client connection + transport + .write_frame_for(&rand::random::()) + .await + .unwrap(); + + // Perform key exchange + let otp = transport.exchange_keys().await.unwrap(); + + // Perform synchronization after restoring backup + transport.backup = backup; + transport.synchronize().await.unwrap(); + + // Client should succeed + let mut client = task.await.unwrap(); + assert_eq!(client.otp(), Some(&otp.into_heap_secret_key())); + + // Verify client backup sent/received count was not modified (stored frames may be + // truncated, though) + assert_eq!( + client.transport().backup.sent_cnt(), + client_backup.sent_cnt(), + "Client backup sent cnt altered" + ); + assert_eq!( + client.transport().backup.received_cnt(), + client_backup.received_cnt(), + "Client backup received cnt altered" + ); + + // Verify that client can send a frame and receive a frame, and that there is + // nothing 
unexpected in the buffers on either side + client.write_frame(Frame::new(b"hello")).await.unwrap(); + assert_eq!(transport.read_frame().await.unwrap().unwrap(), b"hello"); + transport.write_frame(Frame::new(b"goodbye")).await.unwrap(); + assert_eq!(client.read_frame().await.unwrap().unwrap(), b"goodbye"); + } + + #[test(tokio::test)] + async fn reconnect_should_fail_if_connection_is_server_side() { + let mut connection = Connection::Server { + id: rand::random(), + tx: oneshot::channel().0, + transport: FramedTransport::pair(100).0, + }; + + assert_eq!( + connection.reconnect().await.unwrap_err().kind(), + io::ErrorKind::Unsupported + ); + } + + #[test(tokio::test)] + async fn client_server_returning_connection_e2e_should_reestablish_connection() { + let (mut client, transport, verifier, keychain) = setup_reconnect_scenario().await; + + // Spawn a task to perform the server reconnection so we don't deadlock + let task = tokio::spawn(async move { + Connection::server(transport, &verifier, keychain) + .await + .expect("Failed to connect from server") + }); + + // Reconnect and verify that the connection still works + client + .reconnect() + .await + .expect("Failed to reconnect from client"); + + // Ensure the server is established and then drop it + let mut server = task.await.unwrap(); + + // Test out the connection + client.write_frame(Frame::new(b"hello")).await.unwrap(); + assert_eq!(server.read_frame().await.unwrap().unwrap(), b"hello"); + server.write_frame(Frame::new(b"goodbye")).await.unwrap(); + assert_eq!(client.read_frame().await.unwrap().unwrap(), b"goodbye"); + } +} diff --git a/distant-core/src/manager/data/destination.rs b/distant-net/src/common/destination.rs similarity index 92% rename from distant-core/src/manager/data/destination.rs rename to distant-net/src/common/destination.rs index 4913d67..af01179 100644 --- a/distant-core/src/manager/data/destination.rs +++ b/distant-net/src/common/destination.rs @@ -1,4 +1,4 @@ -use 
crate::serde_str::{deserialize_from_str, serialize_to_str}; +use super::utils::{deserialize_from_str, serialize_to_str}; use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; use std::{fmt, hash::Hash, str::FromStr}; @@ -38,17 +38,8 @@ pub struct Destination { } impl Destination { - /// Returns true if destination represents a distant server - pub fn is_distant(&self) -> bool { - self.scheme_eq("distant") - } - - /// Returns true if destination represents an ssh server - pub fn is_ssh(&self) -> bool { - self.scheme_eq("ssh") - } - - fn scheme_eq(&self, s: &str) -> bool { + /// Returns true if the destination's scheme represents the specified (case-insensitive). + pub fn scheme_eq(&self, s: &str) -> bool { match self.scheme.as_ref() { Some(scheme) => scheme.eq_ignore_ascii_case(s), None => false, @@ -58,13 +49,13 @@ impl Destination { impl AsRef for &Destination { fn as_ref(&self) -> &Destination { - *self + self } } impl AsMut for &mut Destination { fn as_mut(&mut self) -> &mut Destination { - *self + self } } diff --git a/distant-core/src/manager/data/destination/host.rs b/distant-net/src/common/destination/host.rs similarity index 99% rename from distant-core/src/manager/data/destination/host.rs rename to distant-net/src/common/destination/host.rs index 30cd5f7..49d61f7 100644 --- a/distant-core/src/manager/data/destination/host.rs +++ b/distant-net/src/common/destination/host.rs @@ -1,4 +1,4 @@ -use crate::serde_str::{deserialize_from_str, serialize_to_str}; +use super::{deserialize_from_str, serialize_to_str}; use derive_more::{Display, Error, From}; use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; use std::{ @@ -109,7 +109,7 @@ impl FromStr for Host { /// ### Examples /// /// ``` - /// # use distant_core::Host; + /// # use distant_net::common::Host; /// # use std::net::{Ipv4Addr, Ipv6Addr}; /// // IPv4 address /// assert_eq!("127.0.0.1".parse(), Ok(Host::Ipv4(Ipv4Addr::new(127, 0, 0, 1)))); diff --git 
a/distant-core/src/manager/data/destination/parser.rs b/distant-net/src/common/destination/parser.rs similarity index 100% rename from distant-core/src/manager/data/destination/parser.rs rename to distant-net/src/common/destination/parser.rs diff --git a/distant-net/src/listener.rs b/distant-net/src/common/listener.rs similarity index 100% rename from distant-net/src/listener.rs rename to distant-net/src/common/listener.rs diff --git a/distant-net/src/listener/mapped.rs b/distant-net/src/common/listener/mapped.rs similarity index 97% rename from distant-net/src/listener/mapped.rs rename to distant-net/src/common/listener/mapped.rs index 82e0cfc..55c4d51 100644 --- a/distant-net/src/listener/mapped.rs +++ b/distant-net/src/common/listener/mapped.rs @@ -1,4 +1,4 @@ -use crate::Listener; +use super::Listener; use async_trait::async_trait; use std::io; diff --git a/distant-net/src/listener/mpsc.rs b/distant-net/src/common/listener/mpsc.rs similarity index 97% rename from distant-net/src/listener/mpsc.rs rename to distant-net/src/common/listener/mpsc.rs index fe70779..05937a4 100644 --- a/distant-net/src/listener/mpsc.rs +++ b/distant-net/src/common/listener/mpsc.rs @@ -1,4 +1,4 @@ -use crate::Listener; +use super::Listener; use async_trait::async_trait; use derive_more::From; use std::io; diff --git a/distant-net/src/listener/oneshot.rs b/distant-net/src/common/listener/oneshot.rs similarity index 96% rename from distant-net/src/listener/oneshot.rs rename to distant-net/src/common/listener/oneshot.rs index 98d8e01..1db1cd6 100644 --- a/distant-net/src/listener/oneshot.rs +++ b/distant-net/src/common/listener/oneshot.rs @@ -1,4 +1,4 @@ -use crate::Listener; +use super::Listener; use async_trait::async_trait; use derive_more::From; use std::io; @@ -48,9 +48,10 @@ impl Listener for OneshotListener { #[cfg(test)] mod tests { use super::*; + use test_log::test; use tokio::task::JoinHandle; - #[tokio::test] + #[test(tokio::test)] async fn 
from_value_should_return_value_on_first_call_to_accept() { let mut listener = OneshotListener::from_value("hello world"); assert_eq!(listener.accept().await.unwrap(), "hello world"); @@ -60,7 +61,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn channel_should_return_a_oneshot_sender_to_feed_first_call_to_accept() { let (tx, mut listener) = OneshotListener::channel(); let accept_task: JoinHandle<(io::Result<&str>, io::Result<&str>)> = diff --git a/distant-net/src/listener/tcp.rs b/distant-net/src/common/listener/tcp.rs similarity index 90% rename from distant-net/src/listener/tcp.rs rename to distant-net/src/common/listener/tcp.rs index dc681f4..4160718 100644 --- a/distant-net/src/listener/tcp.rs +++ b/distant-net/src/common/listener/tcp.rs @@ -1,4 +1,5 @@ -use crate::{Listener, PortRange, TcpTransport}; +use super::Listener; +use crate::common::{PortRange, TcpTransport}; use async_trait::async_trait; use std::{fmt, io, net::IpAddr}; use tokio::net::TcpListener as TokioTcpListener; @@ -64,14 +65,12 @@ impl Listener for TcpListener { #[cfg(test)] mod tests { use super::*; + use crate::common::TransportExt; use std::net::{Ipv6Addr, SocketAddr}; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - sync::oneshot, - task::JoinHandle, - }; + use test_log::test; + use tokio::{sync::oneshot, task::JoinHandle}; - #[tokio::test] + #[test(tokio::test)] async fn should_fail_to_bind_if_port_already_bound() { let addr = IpAddr::V6(Ipv6Addr::LOCALHOST); let port = 0; // Ephemeral port @@ -91,8 +90,8 @@ mod tests { )); } - #[tokio::test] - async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() { + #[test(tokio::test)] + async fn should_be_able_to_receive_connections_and_read_and_write_data_with_them() { let (tx, rx) = oneshot::channel(); // Spawn a task that will wait for two connections and then @@ -109,7 +108,7 @@ mod tests { .map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string()))?; // Get first connection - let mut 
conn_1 = listener.accept().await?; + let conn_1 = listener.accept().await?; // Send some data to the first connection (12 bytes) conn_1.write_all(b"hello conn 1").await?; @@ -120,7 +119,7 @@ mod tests { assert_eq!(&buf, b"hello server 1"); // Get second connection - let mut conn_2 = listener.accept().await?; + let conn_2 = listener.accept().await?; // Send some data on to second connection (12 bytes) conn_2.write_all(b"hello conn 2").await?; @@ -139,7 +138,7 @@ mod tests { // Connect to the listener twice, sending some bytes and receiving some bytes from each let mut buf: [u8; 12] = [0; 12]; - let mut conn = TcpTransport::connect(&address) + let conn = TcpTransport::connect(&address) .await .expect("Conn 1 failed to connect"); conn.write_all(b"hello server 1") @@ -150,7 +149,7 @@ mod tests { .expect("Conn 1 failed to read"); assert_eq!(&buf, b"hello conn 1"); - let mut conn = TcpTransport::connect(&address) + let conn = TcpTransport::connect(&address) .await .expect("Conn 2 failed to connect"); conn.write_all(b"hello server 2") diff --git a/distant-net/src/listener/unix.rs b/distant-net/src/common/listener/unix.rs similarity index 92% rename from distant-net/src/listener/unix.rs rename to distant-net/src/common/listener/unix.rs index 21f8bac..b46e9a3 100644 --- a/distant-net/src/listener/unix.rs +++ b/distant-net/src/common/listener/unix.rs @@ -1,4 +1,5 @@ -use crate::{Listener, UnixSocketTransport}; +use super::Listener; +use crate::common::UnixSocketTransport; use async_trait::async_trait; use std::{ fmt, io, @@ -94,14 +95,12 @@ impl Listener for UnixSocketListener { #[cfg(test)] mod tests { use super::*; + use crate::common::TransportExt; use tempfile::NamedTempFile; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - sync::oneshot, - task::JoinHandle, - }; + use test_log::test; + use tokio::{sync::oneshot, task::JoinHandle}; - #[tokio::test] + #[test(tokio::test)] async fn should_succeed_to_bind_if_file_exists_at_path_but_nothing_listening() { // Generate a 
socket path let path = NamedTempFile::new() @@ -114,7 +113,7 @@ mod tests { .expect("Unexpectedly failed to bind to existing file"); } - #[tokio::test] + #[test(tokio::test)] async fn should_fail_to_bind_if_socket_already_bound() { // Generate a socket path and delete the file after let path = NamedTempFile::new() @@ -133,8 +132,8 @@ mod tests { .expect_err("Unexpectedly succeeded in binding to same socket"); } - #[tokio::test] - async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() { + #[test(tokio::test)] + async fn should_be_able_to_receive_connections_and_read_and_write_data_with_them() { let (tx, rx) = oneshot::channel(); // Spawn a task that will wait for two connections and then @@ -154,7 +153,7 @@ mod tests { .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?; // Get first connection - let mut conn_1 = listener.accept().await?; + let conn_1 = listener.accept().await?; // Send some data to the first connection (12 bytes) conn_1.write_all(b"hello conn 1").await?; @@ -165,7 +164,7 @@ mod tests { assert_eq!(&buf, b"hello server 1"); // Get second connection - let mut conn_2 = listener.accept().await?; + let conn_2 = listener.accept().await?; // Send some data on to second connection (12 bytes) conn_2.write_all(b"hello conn 2").await?; @@ -184,7 +183,7 @@ mod tests { // Connect to the listener twice, sending some bytes and receiving some bytes from each let mut buf: [u8; 12] = [0; 12]; - let mut conn = UnixSocketTransport::connect(&path) + let conn = UnixSocketTransport::connect(&path) .await .expect("Conn 1 failed to connect"); conn.write_all(b"hello server 1") @@ -195,7 +194,7 @@ mod tests { .expect("Conn 1 failed to read"); assert_eq!(&buf, b"hello conn 1"); - let mut conn = UnixSocketTransport::connect(&path) + let conn = UnixSocketTransport::connect(&path) .await .expect("Conn 2 failed to connect"); conn.write_all(b"hello server 2") diff --git a/distant-net/src/listener/windows.rs 
b/distant-net/src/common/listener/windows.rs similarity index 89% rename from distant-net/src/listener/windows.rs rename to distant-net/src/common/listener/windows.rs index ef30f4e..e0a12c6 100644 --- a/distant-net/src/listener/windows.rs +++ b/distant-net/src/common/listener/windows.rs @@ -1,4 +1,5 @@ -use crate::{Listener, NamedPipe, WindowsPipeTransport}; +use super::Listener; +use crate::common::{NamedPipe, WindowsPipeTransport}; use async_trait::async_trait; use std::{ ffi::{OsStr, OsString}, @@ -66,13 +67,11 @@ impl Listener for WindowsPipeListener { #[cfg(test)] mod tests { use super::*; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - sync::oneshot, - task::JoinHandle, - }; + use crate::common::TransportExt; + use test_log::test; + use tokio::{sync::oneshot, task::JoinHandle}; - #[tokio::test] + #[test(tokio::test)] async fn should_fail_to_bind_if_pipe_already_bound() { // Generate a pipe name let name = format!("test_pipe_{}", rand::random::()); @@ -86,8 +85,8 @@ mod tests { .expect_err("Unexpectedly succeeded in binding to same pipe"); } - #[tokio::test] - async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() { + #[test(tokio::test)] + async fn should_be_able_to_receive_connections_and_read_and_write_data_with_them() { let (tx, rx) = oneshot::channel(); // Spawn a task that will wait for two connections and then @@ -104,7 +103,7 @@ mod tests { .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; // Get first connection - let mut conn_1 = listener.accept().await?; + let conn_1 = listener.accept().await?; // Send some data to the first connection (12 bytes) conn_1.write_all(b"hello conn 1").await?; @@ -115,7 +114,7 @@ mod tests { assert_eq!(&buf, b"hello server 1"); // Get second connection - let mut conn_2 = listener.accept().await?; + let conn_2 = listener.accept().await?; // Send some data on to second connection (12 bytes) conn_2.write_all(b"hello conn 2").await?; @@ -134,7 +133,7 @@ mod tests { // Connect to the 
listener twice, sending some bytes and receiving some bytes from each let mut buf: [u8; 12] = [0; 12]; - let mut conn = WindowsPipeTransport::connect_local(&name) + let conn = WindowsPipeTransport::connect_local(&name) .await .expect("Conn 1 failed to connect"); conn.write_all(b"hello server 1") @@ -145,7 +144,7 @@ mod tests { .expect("Conn 1 failed to read"); assert_eq!(&buf, b"hello conn 1"); - let mut conn = WindowsPipeTransport::connect_local(&name) + let conn = WindowsPipeTransport::connect_local(&name) .await .expect("Conn 2 failed to connect"); conn.write_all(b"hello server 2") diff --git a/distant-core/src/data/map.rs b/distant-net/src/common/map.rs similarity index 97% rename from distant-core/src/data/map.rs rename to distant-net/src/common/map.rs index 2d55185..8410588 100644 --- a/distant-core/src/data/map.rs +++ b/distant-net/src/common/map.rs @@ -1,4 +1,4 @@ -use crate::serde_str::{deserialize_from_str, serialize_to_str}; +use crate::common::utils::{deserialize_from_str, serialize_to_str}; use derive_more::{Display, Error, From, IntoIterator}; use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; use std::{ @@ -198,6 +198,13 @@ impl<'de> Deserialize<'de> for Map { } } +/// Generates a new [`Map`] of key/value pairs based on literals. +/// +/// ``` +/// use distant_net::map; +/// +/// let _map = map!("key" -> "value", "key2" -> "value2"); +/// ``` #[macro_export] macro_rules! map { ($($key:literal -> $value:literal),* $(,)?) => {{ @@ -207,7 +214,7 @@ macro_rules! 
map { _map.insert($key.to_string(), $value.to_string()); )* - $crate::Map::from(_map) + $crate::common::Map::from(_map) }}; } diff --git a/distant-net/src/common/packet.rs b/distant-net/src/common/packet.rs new file mode 100644 index 0000000..f78bbfe --- /dev/null +++ b/distant-net/src/common/packet.rs @@ -0,0 +1,628 @@ +/// Represents a generic id type +pub type Id = String; + +mod request; +mod response; + +pub use request::*; +pub use response::*; + +#[derive(Clone, Debug, PartialEq, Eq)] +enum MsgPackStrParseError { + InvalidFormat, + Utf8Error(std::str::Utf8Error), +} + +/// Writes the given str to the end of `buf` as the str's msgpack representation. +/// +/// # Panics +/// +/// Panics if `s.len() >= 2 ^ 32` as the maximum str length for a msgpack str is `(2 ^ 32) - 1`. +fn write_str_msg_pack(s: &str, buf: &mut Vec) { + assert!( + s.len() < 2usize.pow(32), + "str cannot be longer than (2^32)-1 bytes" + ); + + if s.len() < 32 { + buf.push(s.len() as u8 | 0b10100000); + } else if s.len() < 2usize.pow(8) { + buf.push(0xd9); + buf.push(s.len() as u8); + } else if s.len() < 2usize.pow(16) { + buf.push(0xda); + for b in (s.len() as u16).to_be_bytes() { + buf.push(b); + } + } else { + buf.push(0xdb); + for b in (s.len() as u32).to_be_bytes() { + buf.push(b); + } + } + + buf.extend_from_slice(s.as_bytes()); +} + +/// Parse msgpack str, returning remaining bytes and str on success, or error on failure. 
+fn parse_msg_pack_str(input: &[u8]) -> Result<(&[u8], &str), MsgPackStrParseError> { + let ilen = input.len(); + if ilen == 0 { + return Err(MsgPackStrParseError::InvalidFormat); + } + + // * fixstr using 0xa0 - 0xbf to mark the start of the str where < 32 bytes + // * str 8 (0xd9) if up to (2^8)-1 bytes, using next byte for len + // * str 16 (0xda) if up to (2^16)-1 bytes, using next two bytes for len + // * str 32 (0xdb) if up to (2^32)-1 bytes, using next four bytes for len + let (input, len): (&[u8], usize) = if input[0] >= 0xa0 && input[0] <= 0xbf { + (&input[1..], (input[0] & 0b00011111).into()) + } else if input[0] == 0xd9 && ilen > 2 { + (&input[2..], input[1].into()) + } else if input[0] == 0xda && ilen > 3 { + (&input[3..], u16::from_be_bytes([input[1], input[2]]).into()) + } else if input[0] == 0xdb && ilen > 5 { + ( + &input[5..], + u32::from_be_bytes([input[1], input[2], input[3], input[4]]) + .try_into() + .unwrap(), + ) + } else { + return Err(MsgPackStrParseError::InvalidFormat); + }; + + let s = match std::str::from_utf8(&input[..len]) { + Ok(s) => s, + Err(x) => return Err(MsgPackStrParseError::Utf8Error(x)), + }; + + Ok((&input[len..], s)) +} + +#[cfg(test)] +mod tests { + use super::*; + + mod write_str_msg_pack { + use super::*; + + #[test] + fn should_support_fixstr() { + // 0-byte str + let mut buf = Vec::new(); + write_str_msg_pack("", &mut buf); + assert_eq!(buf, &[0xa0]); + + // 1-byte str + let mut buf = Vec::new(); + write_str_msg_pack("a", &mut buf); + assert_eq!(buf, &[0xa1, b'a']); + + // 2-byte str + let mut buf = Vec::new(); + write_str_msg_pack("ab", &mut buf); + assert_eq!(buf, &[0xa2, b'a', b'b']); + + // 3-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abc", &mut buf); + assert_eq!(buf, &[0xa3, b'a', b'b', b'c']); + + // 4-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcd", &mut buf); + assert_eq!(buf, &[0xa4, b'a', b'b', b'c', b'd']); + + // 5-byte str + let mut buf = Vec::new(); + 
write_str_msg_pack("abcde", &mut buf); + assert_eq!(buf, &[0xa5, b'a', b'b', b'c', b'd', b'e']); + + // 6-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdef", &mut buf); + assert_eq!(buf, &[0xa6, b'a', b'b', b'c', b'd', b'e', b'f']); + + // 7-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefg", &mut buf); + assert_eq!(buf, &[0xa7, b'a', b'b', b'c', b'd', b'e', b'f', b'g']); + + // 8-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefgh", &mut buf); + assert_eq!(buf, &[0xa8, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h']); + + // 9-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghi", &mut buf); + assert_eq!( + buf, + &[0xa9, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i'] + ); + + // 10-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghij", &mut buf); + assert_eq!( + buf, + &[0xaa, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j'] + ); + + // 11-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijk", &mut buf); + assert_eq!( + buf, + &[0xab, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k'] + ); + + // 12-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijkl", &mut buf); + assert_eq!( + buf, + &[0xac, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l'] + ); + + // 13-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklm", &mut buf); + assert_eq!( + buf, + &[ + 0xad, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm' + ] + ); + + // 14-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmn", &mut buf); + assert_eq!( + buf, + &[ + 0xae, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n' + ] + ); + + // 15-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmno", &mut buf); + assert_eq!( + buf, + &[ + 0xaf, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', 
b'n', b'o' + ] + ); + + // 16-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnop", &mut buf); + assert_eq!( + buf, + &[ + 0xb0, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p' + ] + ); + + // 17-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopq", &mut buf); + assert_eq!( + buf, + &[ + 0xb1, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q' + ] + ); + + // 18-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqr", &mut buf); + assert_eq!( + buf, + &[ + 0xb2, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r' + ] + ); + + // 19-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrs", &mut buf); + assert_eq!( + buf, + &[ + 0xb3, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's' + ] + ); + + // 20-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrst", &mut buf); + assert_eq!( + buf, + &[ + 0xb4, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't' + ] + ); + + // 21-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstu", &mut buf); + assert_eq!( + buf, + &[ + 0xb5, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u' + ] + ); + + // 22-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuv", &mut buf); + assert_eq!( + buf, + &[ + 0xb6, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v' + ] + ); + + // 23-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuvw", &mut buf); + assert_eq!( + buf, + &[ + 0xb7, b'a', b'b', b'c', 
b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w' + ] + ); + + // 24-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuvwx", &mut buf); + assert_eq!( + buf, + &[ + 0xb8, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x' + ] + ); + + // 25-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuvwxy", &mut buf); + assert_eq!( + buf, + &[ + 0xb9, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y' + ] + ); + + // 26-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuvwxyz", &mut buf); + assert_eq!( + buf, + &[ + 0xba, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', + b'z' + ] + ); + + // 27-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuvwxyz0", &mut buf); + assert_eq!( + buf, + &[ + 0xbb, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', + b'z', b'0' + ] + ); + + // 28-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuvwxyz01", &mut buf); + assert_eq!( + buf, + &[ + 0xbc, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', + b'z', b'0', b'1' + ] + ); + + // 29-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuvwxyz012", &mut buf); + assert_eq!( + buf, + &[ + 0xbd, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', + b'z', b'0', b'1', 
b'2' + ] + ); + + // 30-byte str + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuvwxyz0123", &mut buf); + assert_eq!( + buf, + &[ + 0xbe, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', + b'z', b'0', b'1', b'2', b'3' + ] + ); + + // 31-byte str is maximum len of fixstr + let mut buf = Vec::new(); + write_str_msg_pack("abcdefghijklmnopqrstuvwxyz01234", &mut buf); + assert_eq!( + buf, + &[ + 0xbf, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', + b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', + b'z', b'0', b'1', b'2', b'3', b'4' + ] + ); + } + + #[test] + fn should_support_str_8() { + let input = "a".repeat(32); + let mut buf = Vec::new(); + write_str_msg_pack(&input, &mut buf); + assert_eq!(buf[0], 0xd9); + assert_eq!(buf[1], input.len() as u8); + assert_eq!(&buf[2..], input.as_bytes()); + + let input = "a".repeat(2usize.pow(8) - 1); + let mut buf = Vec::new(); + write_str_msg_pack(&input, &mut buf); + assert_eq!(buf[0], 0xd9); + assert_eq!(buf[1], input.len() as u8); + assert_eq!(&buf[2..], input.as_bytes()); + } + + #[test] + fn should_support_str_16() { + let input = "a".repeat(2usize.pow(8)); + let mut buf = Vec::new(); + write_str_msg_pack(&input, &mut buf); + assert_eq!(buf[0], 0xda); + assert_eq!(&buf[1..3], &(input.len() as u16).to_be_bytes()); + assert_eq!(&buf[3..], input.as_bytes()); + + let input = "a".repeat(2usize.pow(16) - 1); + let mut buf = Vec::new(); + write_str_msg_pack(&input, &mut buf); + assert_eq!(buf[0], 0xda); + assert_eq!(&buf[1..3], &(input.len() as u16).to_be_bytes()); + assert_eq!(&buf[3..], input.as_bytes()); + } + + #[test] + fn should_support_str_32() { + let input = "a".repeat(2usize.pow(16)); + let mut buf = Vec::new(); + write_str_msg_pack(&input, &mut buf); + assert_eq!(buf[0], 0xdb); + assert_eq!(&buf[1..5], &(input.len() as u32).to_be_bytes()); + 
assert_eq!(&buf[5..], input.as_bytes()); + } + } + + mod parse_msg_pack_str { + use super::*; + + #[test] + fn should_be_able_to_parse_fixstr() { + // Empty str + let (input, s) = parse_msg_pack_str(&[0xa0]).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, ""); + + // Single character + let (input, s) = parse_msg_pack_str(&[0xa1, b'a']).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, "a"); + + // 31 byte str + let (input, s) = parse_msg_pack_str(&[ + 0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', + ]) + .unwrap(); + assert!(input.is_empty()); + assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + + // Verify that we only consume up to fixstr length + assert_eq!(parse_msg_pack_str(&[0xa0, b'a']).unwrap().0, b"a"); + assert_eq!( + parse_msg_pack_str(&[ + 0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', b'b' + ]) + .unwrap() + .0, + b"b" + ); + } + + #[test] + fn should_be_able_to_parse_str_8() { + // 32 byte str + let (input, s) = parse_msg_pack_str(&[ + 0xd9, 32, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', + b'a', b'a', b'a', b'a', b'a', b'a', + ]) + .unwrap(); + assert!(input.is_empty()); + assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + + // 2^8 - 1 (255) byte str + let test_str = "a".repeat(2usize.pow(8) - 1); + let mut input = vec![0xd9, 255]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); + + // Verify that we only consume up to 2^8 - 1 length + let mut input = vec![0xd9, 255]; + input.extend_from_slice(test_str.as_bytes()); + 
input.extend_from_slice(b"hello"); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert_eq!(input, b"hello"); + assert_eq!(s, test_str); + } + + #[test] + fn should_be_able_to_parse_str_16() { + // 2^8 byte str (256) + let test_str = "a".repeat(2usize.pow(8)); + let mut input = vec![0xda, 1, 0]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); + + // 2^16 - 1 (65535) byte str + let test_str = "a".repeat(2usize.pow(16) - 1); + let mut input = vec![0xda, 255, 255]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); + + // Verify that we only consume up to 2^16 - 1 length + let mut input = vec![0xda, 255, 255]; + input.extend_from_slice(test_str.as_bytes()); + input.extend_from_slice(b"hello"); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert_eq!(input, b"hello"); + assert_eq!(s, test_str); + } + + #[test] + fn should_be_able_to_parse_str_32() { + // 2^16 byte str + let test_str = "a".repeat(2usize.pow(16)); + let mut input = vec![0xdb, 0, 1, 0, 0]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); + + // NOTE: We are not going to run the below tests, not because they aren't valid but + // because this generates a 4GB str which takes 20+ seconds to run + + // 2^32 - 1 byte str (4294967295 bytes) + /* let test_str = "a".repeat(2usize.pow(32) - 1); + let mut input = vec![0xdb, 255, 255, 255, 255]; + input.extend_from_slice(test_str.as_bytes()); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert!(input.is_empty()); + assert_eq!(s, test_str); */ + + // Verify that we only consume up to 2^32 - 1 length + /* let mut input = vec![0xdb, 255, 255, 255, 255]; + input.extend_from_slice(test_str.as_bytes()); + 
input.extend_from_slice(b"hello"); + let (input, s) = parse_msg_pack_str(&input).unwrap(); + assert_eq!(input, b"hello"); + assert_eq!(s, test_str); */ + } + + #[test] + fn should_fail_parsing_str_with_invalid_length() { + // Make sure that parse doesn't fail looking for bytes after str 8 len + assert_eq!( + parse_msg_pack_str(&[0xd9]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xd9, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + + // Make sure that parse doesn't fail looking for bytes after str 16 len + assert_eq!( + parse_msg_pack_str(&[0xda]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xda, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xda, 0, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + + // Make sure that parse doesn't fail looking for bytes after str 32 len + assert_eq!( + parse_msg_pack_str(&[0xdb]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xdb, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xdb, 0, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xdb, 0, 0, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + assert_eq!( + parse_msg_pack_str(&[0xdb, 0, 0, 0, 0]), + Err(MsgPackStrParseError::InvalidFormat) + ); + } + + #[test] + fn should_fail_parsing_other_types() { + assert_eq!( + parse_msg_pack_str(&[0xc3]), // Boolean (true) + Err(MsgPackStrParseError::InvalidFormat) + ); + } + + #[test] + fn should_fail_if_empty_input() { + assert_eq!( + parse_msg_pack_str(&[]), + Err(MsgPackStrParseError::InvalidFormat) + ); + } + + #[test] + fn should_fail_if_str_is_not_utf8() { + assert!(matches!( + parse_msg_pack_str(&[0xa4, 0, 159, 146, 150]), + Err(MsgPackStrParseError::Utf8Error(_)) + )); + } + } +} diff --git a/distant-net/src/packet/request.rs 
b/distant-net/src/common/packet/request.rs similarity index 87% rename from distant-net/src/packet/request.rs rename to distant-net/src/common/packet/request.rs index 71511f9..b7b950a 100644 --- a/distant-net/src/packet/request.rs +++ b/distant-net/src/common/packet/request.rs @@ -1,5 +1,6 @@ -use super::{parse_msg_pack_str, Id}; -use crate::utils; +use super::{parse_msg_pack_str, write_str_msg_pack, Id}; +use crate::common::utils; +use derive_more::{Display, Error}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{borrow::Cow, io, str}; @@ -37,6 +38,14 @@ where pub fn to_payload_vec(&self) -> io::Result> { utils::serialize_to_vec(&self.payload) } + + /// Attempts to convert a typed request to an untyped request + pub fn to_untyped_request(&self) -> io::Result { + Ok(UntypedRequest { + id: Cow::Borrowed(&self.id), + payload: Cow::Owned(self.to_payload_vec()?), + }) + } } impl Request @@ -63,7 +72,7 @@ impl From for Request { } /// Error encountered when attempting to parse bytes as an untyped request -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq, Hash)] pub enum UntypedRequestParseError { /// When the bytes do not represent a request WrongType, @@ -119,6 +128,24 @@ impl<'a> UntypedRequest<'a> { } } + /// Updates the id of the request to the given `id`. + pub fn set_id(&mut self, id: impl Into) { + self.id = Cow::Owned(id.into()); + } + + /// Allocates a new collection of bytes representing the request. 
+ pub fn to_bytes(&self) -> Vec { + let mut bytes = vec![0x82]; + + write_str_msg_pack("id", &mut bytes); + write_str_msg_pack(&self.id, &mut bytes); + + write_str_msg_pack("payload", &mut bytes); + bytes.extend_from_slice(&self.payload); + + bytes + } + /// Parses a collection of bytes, returning a partial request if it can be potentially /// represented as a [`Request`] depending on the payload, or the original bytes if it does not /// represent a [`Request`] @@ -169,6 +196,7 @@ impl<'a> UntypedRequest<'a> { #[cfg(test)] mod tests { use super::*; + use test_log::test; const TRUE_BYTE: u8 = 0xc3; const NEVER_USED_BYTE: u8 = 0xc1; @@ -182,6 +210,19 @@ mod tests { /// fixstr of 4 bytes with str "test" const TEST_STR_BYTES: &[u8] = &[0xa4, 0x74, 0x65, 0x73, 0x74]; + #[test] + fn untyped_request_should_support_converting_to_bytes() { + let bytes = Request { + id: "some id".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + let untyped_request = UntypedRequest::from_slice(&bytes).unwrap(); + assert_eq!(untyped_request.to_bytes(), bytes); + } + #[test] fn untyped_request_should_support_parsing_from_request_bytes_with_valid_payload() { let bytes = Request { diff --git a/distant-net/src/packet/response.rs b/distant-net/src/common/packet/response.rs similarity index 87% rename from distant-net/src/packet/response.rs rename to distant-net/src/common/packet/response.rs index eea2d66..d50e36f 100644 --- a/distant-net/src/packet/response.rs +++ b/distant-net/src/common/packet/response.rs @@ -1,5 +1,6 @@ -use super::{parse_msg_pack_str, Id}; -use crate::utils; +use super::{parse_msg_pack_str, write_str_msg_pack, Id}; +use crate::common::utils; +use derive_more::{Display, Error}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{borrow::Cow, io}; @@ -41,6 +42,15 @@ where pub fn to_payload_vec(&self) -> io::Result> { utils::serialize_to_vec(&self.payload) } + + /// Attempts to convert a typed response to an untyped response + pub fn 
to_untyped_response(&self) -> io::Result { + Ok(UntypedResponse { + id: Cow::Borrowed(&self.id), + origin_id: Cow::Borrowed(&self.origin_id), + payload: Cow::Owned(self.to_payload_vec()?), + }) + } } impl Response @@ -61,7 +71,7 @@ impl Response { } /// Error encountered when attempting to parse bytes as an untyped response -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq, Hash)] pub enum UntypedResponseParseError { /// When the bytes do not represent a response WrongType, @@ -88,7 +98,7 @@ pub struct UntypedResponse<'a> { impl<'a> UntypedResponse<'a> { /// Attempts to convert an untyped request to a typed request - pub fn to_typed_request(&self) -> io::Result> { + pub fn to_typed_response(&self) -> io::Result> { Ok(Response { id: self.id.to_string(), origin_id: self.origin_id.to_string(), @@ -132,9 +142,35 @@ impl<'a> UntypedResponse<'a> { } } + /// Updates the id of the response to the given `id`. + pub fn set_id(&mut self, id: impl Into) { + self.id = Cow::Owned(id.into()); + } + + /// Updates the origin id of the response to the given `origin_id`. + pub fn set_origin_id(&mut self, origin_id: impl Into) { + self.origin_id = Cow::Owned(origin_id.into()); + } + + /// Allocates a new collection of bytes representing the response. + pub fn to_bytes(&self) -> Vec { + let mut bytes = vec![0x83]; + + write_str_msg_pack("id", &mut bytes); + write_str_msg_pack(&self.id, &mut bytes); + + write_str_msg_pack("origin_id", &mut bytes); + write_str_msg_pack(&self.origin_id, &mut bytes); + + write_str_msg_pack("payload", &mut bytes); + bytes.extend_from_slice(&self.payload); + + bytes + } + /// Parses a collection of bytes, returning an untyped response if it can be potentially /// represented as a [`Response`] depending on the payload, or the original bytes if it does not - /// represent a [`Response`] + /// represent a [`Response`]. 
/// /// NOTE: This supports parsing an invalid response where the payload would not properly /// deserialize, but the bytes themselves represent a complete response of some kind. @@ -198,6 +234,7 @@ impl<'a> UntypedResponse<'a> { #[cfg(test)] mod tests { use super::*; + use test_log::test; const TRUE_BYTE: u8 = 0xc3; const NEVER_USED_BYTE: u8 = 0xc1; @@ -215,6 +252,20 @@ mod tests { /// fixstr of 4 bytes with str "test" const TEST_STR_BYTES: &[u8] = &[0xa4, 0x74, 0x65, 0x73, 0x74]; + #[test] + fn untyped_response_should_support_converting_to_bytes() { + let bytes = Response { + id: "some id".to_string(), + origin_id: "some origin id".to_string(), + payload: true, + } + .to_vec() + .unwrap(); + + let untyped_response = UntypedResponse::from_slice(&bytes).unwrap(); + assert_eq!(untyped_response.to_bytes(), bytes); + } + #[test] fn untyped_response_should_support_parsing_from_response_bytes_with_valid_payload() { let bytes = Response { diff --git a/distant-net/src/port.rs b/distant-net/src/common/port.rs similarity index 100% rename from distant-net/src/port.rs rename to distant-net/src/common/port.rs diff --git a/distant-net/src/common/transport.rs b/distant-net/src/common/transport.rs new file mode 100644 index 0000000..41432c8 --- /dev/null +++ b/distant-net/src/common/transport.rs @@ -0,0 +1,629 @@ +use async_trait::async_trait; +use std::{io, time::Duration}; + +mod framed; +pub use framed::*; + +mod inmemory; +pub use inmemory::*; + +mod tcp; +pub use tcp::*; + +#[cfg(test)] +mod test; + +#[cfg(test)] +pub use test::*; + +#[cfg(unix)] +mod unix; + +#[cfg(unix)] +pub use unix::*; + +#[cfg(windows)] +mod windows; + +#[cfg(windows)] +pub use windows::*; + +pub use tokio::io::{Interest, Ready}; + +/// Duration to wait after WouldBlock received during looping operations like `read_exact`. +const SLEEP_DURATION: Duration = Duration::from_millis(1); + +/// Interface representing a connection that is reconnectable. 
+#[async_trait] +pub trait Reconnectable { + /// Attempts to reconnect an already-established connection. + async fn reconnect(&mut self) -> io::Result<()>; +} + +/// Interface representing a transport of raw bytes into and out of the system. +#[async_trait] +pub trait Transport: Reconnectable + Send + Sync { + /// Tries to read data from the transport into the provided buffer, returning how many bytes + /// were read. + /// + /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport + /// is not ready to read data. + /// + /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock + fn try_read(&self, buf: &mut [u8]) -> io::Result; + + /// Try to write a buffer to the transport, returning how many bytes were written. + /// + /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport + /// is not ready to write data. + /// + /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock + fn try_write(&self, buf: &[u8]) -> io::Result; + + /// Waits for the transport to be ready based on the given interest, returning the ready + /// status. + async fn ready(&self, interest: Interest) -> io::Result; +} + +#[async_trait] +impl Transport for Box { + fn try_read(&self, buf: &mut [u8]) -> io::Result { + Transport::try_read(AsRef::as_ref(self), buf) + } + + fn try_write(&self, buf: &[u8]) -> io::Result { + Transport::try_write(AsRef::as_ref(self), buf) + } + + async fn ready(&self, interest: Interest) -> io::Result { + Transport::ready(AsRef::as_ref(self), interest).await + } +} + +#[async_trait] +impl Reconnectable for Box { + async fn reconnect(&mut self) -> io::Result<()> { + Reconnectable::reconnect(AsMut::as_mut(self)).await + } +} + +#[async_trait] +pub trait TransportExt { + /// Waits for the transport to be readable to follow up with `try_read`. + async fn readable(&self) -> io::Result<()>; + + /// Waits for the transport to be writeable to follow up with `try_write`. 
+ async fn writeable(&self) -> io::Result<()>; + + /// Waits for the transport to be either readable or writeable. + async fn readable_or_writeable(&self) -> io::Result<()>; + + /// Reads exactly `n` bytes where `n` is the length of `buf` by continuing to call [`try_read`] + /// until completed. Calls to [`readable`] are made to ensure the transport is ready. Returns + /// the total bytes read. + /// + /// [`try_read`]: Transport::try_read + /// [`readable`]: Transport::readable + async fn read_exact(&self, buf: &mut [u8]) -> io::Result; + + /// Reads all bytes until EOF in this source, placing them into `buf`. + /// + /// All bytes read from this source will be appended to the specified buffer `buf`. This + /// function will continuously call [`try_read`] to append more data to `buf` until + /// [`try_read`] returns either [`Ok(0)`] or an error that is neither [`Interrupted`] or + /// [`WouldBlock`]. + /// + /// If successful, this function will return the total number of bytes read. + /// + /// ### Errors + /// + /// If this function encounters an error of the kind [`Interrupted`] or [`WouldBlock`], then + /// the error is ignored and the operation will continue. + /// + /// If any other read error is encountered then this function immediately returns. Any bytes + /// which have already been read will be appended to `buf`. + /// + /// [`Ok(0)`]: Ok + /// [`try_read`]: Transport::try_read + /// [`readable`]: Transport::readable + async fn read_to_end(&self, buf: &mut Vec) -> io::Result; + + /// Reads all bytes until EOF in this source, placing them into `buf`. + /// + /// If successful, this function will return the total number of bytes read. + /// + /// ### Errors + /// + /// If the data in this stream is *not* valid UTF-8 then an error is returned and `buf` is + /// unchanged. + /// + /// See [`read_to_end`] for other error semantics. 
+ /// + /// [`Ok(0)`]: Ok + /// [`try_read`]: Transport::try_read + /// [`readable`]: Transport::readable + /// [`read_to_end`]: TransportExt::read_to_end + async fn read_to_string(&self, buf: &mut String) -> io::Result; + + /// Writes all of `buf` by continuing to call [`try_write`] until completed. Calls to + /// [`writeable`] are made to ensure the transport is ready. + /// + /// [`try_write`]: Transport::try_write + /// [`writable`]: Transport::writable + async fn write_all(&self, buf: &[u8]) -> io::Result<()>; +} + +#[async_trait] +impl TransportExt for T { + async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + async fn writeable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + async fn readable_or_writeable(&self) -> io::Result<()> { + self.ready(Interest::READABLE | Interest::WRITABLE).await?; + Ok(()) + } + + async fn read_exact(&self, buf: &mut [u8]) -> io::Result { + let mut i = 0; + + while i < buf.len() { + self.readable().await?; + + match self.try_read(&mut buf[i..]) { + // If we get 0 bytes read, this usually means that the underlying reader + // has closed, so we will return an EOF error to reflect that + // + // NOTE: `try_read` can also return 0 if the buf len is zero, but because we check + // that our index is < len, the situation where we call try_read with a buf + // of len 0 will never happen + Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + + Ok(n) => i += n, + + // Because we are using `try_read`, it can be possible for it to return + // WouldBlock; so, if we encounter that then we just wait for next readable + Err(x) if x.kind() == io::ErrorKind::WouldBlock => { + // NOTE: We sleep for a little bit before trying again to avoid pegging CPU + tokio::time::sleep(SLEEP_DURATION).await + } + + Err(x) => return Err(x), + } + } + + Ok(i) + } + + async fn read_to_end(&self, buf: &mut Vec) -> io::Result { + let mut i = 0; + let mut tmp = 
[0u8; 1024]; + + loop { + self.readable().await?; + + match self.try_read(&mut tmp) { + Ok(0) => return Ok(i), + Ok(n) => { + buf.extend_from_slice(&tmp[..n]); + i += n; + } + Err(x) + if x.kind() == io::ErrorKind::WouldBlock + || x.kind() == io::ErrorKind::Interrupted => + { + // NOTE: We sleep for a little bit before trying again to avoid pegging CPU + tokio::time::sleep(SLEEP_DURATION).await + } + + Err(x) => return Err(x), + } + } + } + + async fn read_to_string(&self, buf: &mut String) -> io::Result { + let mut tmp = Vec::new(); + let n = self.read_to_end(&mut tmp).await?; + buf.push_str( + &String::from_utf8(tmp).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?, + ); + Ok(n) + } + + async fn write_all(&self, buf: &[u8]) -> io::Result<()> { + let mut i = 0; + + while i < buf.len() { + self.writeable().await?; + + match self.try_write(&buf[i..]) { + // If we get 0 bytes written, this usually means that the underlying writer + // has closed, so we will return a write zero error to reflect that + // + // NOTE: `try_write` can also return 0 if the buf len is zero, but because we check + // that our index is < len, the situation where we call try_write with a buf + // of len 0 will never happen + Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)), + + Ok(n) => i += n, + + // Because we are using `try_write`, it can be possible for it to return + // WouldBlock; so, if we encounter that then we just wait for next writeable + Err(x) if x.kind() == io::ErrorKind::WouldBlock => { + // NOTE: We sleep for a little bit before trying again to avoid pegging CPU + tokio::time::sleep(SLEEP_DURATION).await + } + + Err(x) => return Err(x), + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + #[test(tokio::test)] + async fn read_exact_should_fail_if_try_read_encounters_error_other_than_would_block() { + let transport = TestTransport { + f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + 
f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = [0; 1]; + assert_eq!( + transport.read_exact(&mut buf).await.unwrap_err().kind(), + io::ErrorKind::NotConnected + ); + } + + #[test(tokio::test)] + async fn read_exact_should_fail_if_try_read_returns_0_before_necessary_bytes_read() { + let transport = TestTransport { + f_try_read: Box::new(|_| Ok(0)), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = [0; 1]; + assert_eq!( + transport.read_exact(&mut buf).await.unwrap_err().kind(), + io::ErrorKind::UnexpectedEof + ); + } + + #[test(tokio::test)] + async fn read_exact_should_continue_to_call_try_read_until_buffer_is_filled() { + let transport = TestTransport { + f_try_read: Box::new(|buf| { + static mut CNT: u8 = 0; + unsafe { + buf[0] = b'a' + CNT; + CNT += 1; + } + Ok(1) + }), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = [0; 3]; + assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3); + assert_eq!(&buf, b"abc"); + } + + #[test(tokio::test)] + async fn read_exact_should_continue_to_call_try_read_while_it_returns_would_block() { + // Configure `try_read` to alternate between reading a byte and WouldBlock + let transport = TestTransport { + f_try_read: Box::new(|buf| { + static mut CNT: u8 = 0; + unsafe { + buf[0] = b'a' + CNT; + CNT += 1; + if CNT % 2 == 1 { + Ok(1) + } else { + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } + } + }), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = [0; 3]; + assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3); + assert_eq!(&buf, b"ace"); + } + + #[test(tokio::test)] + async fn read_exact_should_return_0_if_given_a_buffer_of_0_len() { + let transport = TestTransport { + f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = [0; 0]; + 
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 0); + } + + #[test(tokio::test)] + async fn read_to_end_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt( + ) { + let transport = TestTransport { + f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + assert_eq!( + transport + .read_to_end(&mut Vec::new()) + .await + .unwrap_err() + .kind(), + io::ErrorKind::NotConnected + ); + } + + #[test(tokio::test)] + async fn read_to_end_should_read_until_0_bytes_returned_from_try_read() { + let transport = TestTransport { + f_try_read: Box::new(|buf| { + static mut CNT: u8 = 0; + unsafe { + if CNT == 0 { + buf[..5].copy_from_slice(b"hello"); + CNT += 1; + Ok(5) + } else { + Ok(0) + } + } + }), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = Vec::new(); + assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 5); + assert_eq!(buf, b"hello"); + } + + #[test(tokio::test)] + async fn read_to_end_should_continue_reading_when_interrupt_or_would_block_encountered() { + let transport = TestTransport { + f_try_read: Box::new(|buf| { + static mut CNT: u8 = 0; + unsafe { + CNT += 1; + if CNT == 1 { + buf[..6].copy_from_slice(b"hello "); + Ok(6) + } else if CNT == 2 { + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } else if CNT == 3 { + buf[..5].copy_from_slice(b"world"); + Ok(5) + } else if CNT == 4 { + Err(io::Error::from(io::ErrorKind::Interrupted)) + } else if CNT == 5 { + buf[..6].copy_from_slice(b", test"); + Ok(6) + } else { + Ok(0) + } + } + }), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = Vec::new(); + assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 17); + assert_eq!(buf, b"hello world, test"); + } + + #[test(tokio::test)] + async fn read_to_string_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt( + ) { + 
let transport = TestTransport { + f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + assert_eq!( + transport + .read_to_string(&mut String::new()) + .await + .unwrap_err() + .kind(), + io::ErrorKind::NotConnected + ); + } + + #[test(tokio::test)] + async fn read_to_string_should_fail_if_non_utf8_characters_read() { + let transport = TestTransport { + f_try_read: Box::new(|buf| { + static mut CNT: u8 = 0; + unsafe { + if CNT == 0 { + buf[0] = 0; + buf[1] = 159; + buf[2] = 146; + buf[3] = 150; + CNT += 1; + Ok(4) + } else { + Ok(0) + } + } + }), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = String::new(); + assert_eq!( + transport.read_to_string(&mut buf).await.unwrap_err().kind(), + io::ErrorKind::InvalidData + ); + } + + #[test(tokio::test)] + async fn read_to_string_should_read_until_0_bytes_returned_from_try_read() { + let transport = TestTransport { + f_try_read: Box::new(|buf| { + static mut CNT: u8 = 0; + unsafe { + if CNT == 0 { + buf[..5].copy_from_slice(b"hello"); + CNT += 1; + Ok(5) + } else { + Ok(0) + } + } + }), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = String::new(); + assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 5); + assert_eq!(buf, "hello"); + } + + #[test(tokio::test)] + async fn read_to_string_should_continue_reading_when_interrupt_or_would_block_encountered() { + let transport = TestTransport { + f_try_read: Box::new(|buf| { + static mut CNT: u8 = 0; + unsafe { + CNT += 1; + if CNT == 1 { + buf[..6].copy_from_slice(b"hello "); + Ok(6) + } else if CNT == 2 { + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } else if CNT == 3 { + buf[..5].copy_from_slice(b"world"); + Ok(5) + } else if CNT == 4 { + Err(io::Error::from(io::ErrorKind::Interrupted)) + } else if CNT == 5 { + buf[..6].copy_from_slice(b", test"); + Ok(6) + } else { + Ok(0) 
+ } + } + }), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }; + + let mut buf = String::new(); + assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 17); + assert_eq!(buf, "hello world, test"); + } + + #[test(tokio::test)] + async fn write_all_should_fail_if_try_write_encounters_error_other_than_would_block() { + let transport = TestTransport { + f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }; + + assert_eq!( + transport.write_all(b"abc").await.unwrap_err().kind(), + io::ErrorKind::NotConnected + ); + } + + #[test(tokio::test)] + async fn write_all_should_fail_if_try_write_returns_0_before_all_bytes_written() { + let transport = TestTransport { + f_try_write: Box::new(|_| Ok(0)), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }; + + assert_eq!( + transport.write_all(b"abc").await.unwrap_err().kind(), + io::ErrorKind::WriteZero + ); + } + + #[test(tokio::test)] + async fn write_all_should_continue_to_call_try_write_until_all_bytes_written() { + // Configure `try_write` to alternate between writing a byte and WouldBlock + let transport = TestTransport { + f_try_write: Box::new(|buf| { + static mut CNT: u8 = 0; + unsafe { + assert_eq!(buf[0], b'a' + CNT); + CNT += 1; + Ok(1) + } + }), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }; + + transport.write_all(b"abc").await.unwrap(); + } + + #[test(tokio::test)] + async fn write_all_should_continue_to_call_try_write_while_it_returns_would_block() { + // Configure `try_write` to alternate between writing a byte and WouldBlock + let transport = TestTransport { + f_try_write: Box::new(|buf| { + static mut CNT: u8 = 0; + unsafe { + if CNT % 2 == 0 { + assert_eq!(buf[0], b'a' + CNT); + CNT += 1; + Ok(1) + } else { + CNT += 1; + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } + } + }), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + 
..Default::default() + }; + + transport.write_all(b"ace").await.unwrap(); + } + + #[test(tokio::test)] + async fn write_all_should_return_immediately_if_given_buffer_of_0_len() { + let transport = TestTransport { + f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }; + + // No error takes place as we never call try_write + let buf = [0; 0]; + transport.write_all(&buf).await.unwrap(); + } +} diff --git a/distant-net/src/common/transport/framed.rs b/distant-net/src/common/transport/framed.rs new file mode 100644 index 0000000..fb57710 --- /dev/null +++ b/distant-net/src/common/transport/framed.rs @@ -0,0 +1,2237 @@ +use super::{InmemoryTransport, Interest, Ready, Reconnectable, Transport}; +use crate::common::utils; +use async_trait::async_trait; +use bytes::{Buf, BytesMut}; +use log::*; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::{fmt, future::Future, io, time::Duration}; + +mod backup; +mod codec; +mod exchange; +mod frame; +mod handshake; + +pub use backup::*; +pub use codec::*; +pub use exchange::*; +pub use frame::*; +pub use handshake::*; + +/// Size of the read buffer when reading bytes to construct a frame +const READ_BUF_SIZE: usize = 8 * 1024; + +/// Duration to wait after WouldBlock received during looping operations like `read_frame` +const SLEEP_DURATION: Duration = Duration::from_millis(1); + +/// Represents a wrapper around a [`Transport`] that reads and writes using frames defined by a +/// [`Codec`]. 
+/// +/// [`try_read`]: Transport::try_read +#[derive(Clone)] +pub struct FramedTransport { + /// Inner transport wrapped to support frames of data + inner: T, + + /// Codec used to encoding outgoing bytes and decode incoming bytes + codec: BoxedCodec, + + /// Bytes in queue to be read + incoming: BytesMut, + + /// Bytes in queue to be written + outgoing: BytesMut, + + /// Stores outgoing frames in case of transmission issues + pub backup: Backup, +} + +impl FramedTransport { + pub fn new(inner: T, codec: BoxedCodec) -> Self { + Self { + inner, + codec, + incoming: BytesMut::with_capacity(READ_BUF_SIZE * 2), + outgoing: BytesMut::with_capacity(READ_BUF_SIZE * 2), + backup: Backup::new(), + } + } + + /// Creates a new [`FramedTransport`] using the [`PlainCodec`] + pub fn plain(inner: T) -> Self { + Self::new(inner, Box::new(PlainCodec::new())) + } + + /// Replaces the current codec with the provided codec. Note that any bytes in the incoming or + /// outgoing buffers will remain in the transport, meaning that this can cause corruption if + /// the bytes in the buffers do not match the new codec. + /// + /// For safety, use [`clear`] to wipe the buffers before further use. + /// + /// [`clear`]: FramedTransport::clear + pub fn set_codec(&mut self, codec: BoxedCodec) { + self.codec = codec; + } + + /// Returns a reference to the codec used by the transport. + /// + /// ### Note + /// + /// Be careful when accessing the codec to avoid corrupting it through unexpected modifications + /// as this will place the transport in an undefined state. + pub fn codec(&self) -> &dyn Codec { + self.codec.as_ref() + } + + /// Returns a mutable reference to the codec used by the transport. + /// + /// ### Note + /// + /// Be careful when accessing the codec to avoid corrupting it through unexpected modifications + /// as this will place the transport in an undefined state. 
+ pub fn mut_codec(&mut self) -> &mut dyn Codec { + self.codec.as_mut() + } + + /// Clears the internal transport buffers. + pub fn clear(&mut self) { + self.incoming.clear(); + self.outgoing.clear(); + } + + /// Returns a reference to the inner value this transport wraps. + pub fn as_inner(&self) -> &T { + &self.inner + } + + /// Returns a mutable reference to the inner value this transport wraps. + pub fn as_mut_inner(&mut self) -> &mut T { + &mut self.inner + } + + /// Consumes this transport, returning the inner value that it wraps. + pub fn into_inner(self) -> T { + self.inner + } +} + +impl fmt::Debug for FramedTransport { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedTransport") + .field("incoming", &self.incoming) + .field("outgoing", &self.outgoing) + .field("backup", &self.backup) + .finish() + } +} + +impl FramedTransport { + /// Converts this instance to a [`FramedTransport`] whose inner [`Transport`] is [`Box`]ed. + pub fn into_boxed(self) -> FramedTransport> { + FramedTransport { + inner: Box::new(self.inner), + codec: self.codec, + incoming: self.incoming, + outgoing: self.outgoing, + backup: self.backup, + } + } +} + +impl FramedTransport { + /// Waits for the transport to be ready based on the given interest, returning the ready status + pub async fn ready(&self, interest: Interest) -> io::Result { + // If interest includes reading, we check if we already have a frame in our queue, + // as there can be a scenario where a frame was received and then the connection + // was closed, and we still want to be able to read the next frame is if it is + // available in the connection. 
+ let ready = if interest.is_readable() && Frame::available(&self.incoming) { + Ready::READABLE + } else { + Ready::EMPTY + }; + + // If we know that we are readable and not checking for write status, we can short-circuit + // to avoid an async call by returning immediately that we are readable + if !interest.is_writable() && ready.is_readable() { + return Ok(ready); + } + + // Otherwise, we need to check the status using the underlying transport and merge it with + // our current understanding based on internal state + Transport::ready(&self.inner, interest) + .await + .map(|r| r | ready) + } + + /// Waits for the transport to be readable to follow up with [`try_read_frame`]. + /// + /// [`try_read_frame`]: FramedTransport::try_read_frame + pub async fn readable(&self) -> io::Result<()> { + let _ = self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Waits for the transport to be writeable to follow up with [`try_write_frame`]. + /// + /// [`try_write_frame`]: FramedTransport::try_write_frame + pub async fn writeable(&self) -> io::Result<()> { + let _ = self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Waits for the transport to be readable or writeable, returning the [`Ready`] status. + pub async fn readable_or_writeable(&self) -> io::Result { + self.ready(Interest::READABLE | Interest::WRITABLE).await + } + + /// Attempts to flush any remaining bytes in the outgoing queue, returning the total bytes + /// written as a result of the flush. Note that a return of 0 bytes does not indicate that the + /// underlying transport has closed, but rather that no bytes were flushed such as when the + /// outgoing queue is empty. + /// + /// This is accomplished by continually calling the inner transport's `try_write`. If 0 is + /// returned from a call to `try_write`, this will fail with [`ErrorKind::WriteZero`]. + /// + /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport + /// is not ready to write data. 
+ /// + /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock + pub fn try_flush(&mut self) -> io::Result { + let mut bytes_written = 0; + + // Continue to send from the outgoing buffer until we either finish or fail + while !self.outgoing.is_empty() { + match self.inner.try_write(self.outgoing.as_ref()) { + // Getting 0 bytes on write indicates the channel has closed + Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)), + + // Successful write will advance the outgoing buffer + Ok(n) => { + self.outgoing.advance(n); + bytes_written += n; + } + + // Any error (including WouldBlock) will get bubbled up + Err(x) => return Err(x), + } + } + + Ok(bytes_written) + } + + /// Flushes all buffered, outgoing bytes using repeated calls to [`try_flush`]. + /// + /// [`try_flush`]: FramedTransport::try_flush + pub async fn flush(&mut self) -> io::Result<()> { + while !self.outgoing.is_empty() { + self.writeable().await?; + match self.try_flush() { + Err(x) if x.kind() == io::ErrorKind::WouldBlock => { + // NOTE: We sleep for a little bit before trying again to avoid pegging CPU + tokio::time::sleep(SLEEP_DURATION).await + } + Err(x) => return Err(x), + Ok(_) => return Ok(()), + } + } + + Ok(()) + } + + /// Reads a frame of bytes by using the [`Codec`] tied to this transport. Returns + /// `Ok(Some(frame))` upon reading a frame, or `Ok(None)` if the underlying transport has + /// closed. + /// + /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport + /// is not ready to read data or has not received a full frame before waiting. + /// + /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock + pub fn try_read_frame(&mut self) -> io::Result> { + // Attempt to read a frame, returning the decoded frame if we get one, returning any error + // that is encountered from reading frames or failing to decode, or otherwise doing nothing + // and continuing forward. + macro_rules! 
read_next_frame { + () => {{ + match Frame::read(&mut self.incoming) { + Ok(None) => (), + Ok(Some(frame)) => { + self.backup.increment_received_cnt(); + return Ok(Some(self.codec.decode(frame)?.into_owned())); + } + Err(x) => return Err(x), + } + }}; + } + + // If we have data remaining in the buffer, we first try to parse it in case we received + // multiple frames from a previous call. + // + // NOTE: This exists to avoid the situation where there is a valid frame remaining in the + // incoming buffer, but it is never evaluated because a call to `try_read` returns + // `WouldBlock`, 0 bytes, or some other error. + if !self.incoming.is_empty() { + read_next_frame!(); + } + + // Continually read bytes into the incoming queue and then attempt to tease out a frame + let mut buf = [0; READ_BUF_SIZE]; + + loop { + match self.inner.try_read(&mut buf) { + // Getting 0 bytes on read indicates the channel has closed. If we were still + // expecting more bytes for our frame, then this is an error, otherwise if we + // have nothing remaining if our queue then this is an expected end and we + // return None + Ok(0) if self.incoming.is_empty() => return Ok(None), + Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + + // Got some additional bytes, which we will add to our queue and then attempt to + // decode into a frame + Ok(n) => { + self.incoming.extend_from_slice(&buf[..n]); + read_next_frame!(); + } + + // Any error (including WouldBlock) will get bubbled up + Err(x) => return Err(x), + } + } + } + + /// Reads a frame using [`try_read_frame`] and then deserializes the bytes into `D`. 
+ /// + /// [`try_read_frame`]: FramedTransport::try_read_frame + pub fn try_read_frame_as(&mut self) -> io::Result> { + match self.try_read_frame() { + Ok(Some(frame)) => Ok(Some(utils::deserialize_from_slice(frame.as_item())?)), + Ok(None) => Ok(None), + Err(x) => Err(x), + } + } + + /// Continues to invoke [`try_read_frame`] until a frame is successfully read, an error is + /// encountered that is not [`ErrorKind::WouldBlock`], or the underlying transport has closed. + /// + /// [`try_read_frame`]: FramedTransport::try_read_frame + /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock + pub async fn read_frame(&mut self) -> io::Result> { + loop { + self.readable().await?; + + match self.try_read_frame() { + Err(x) if x.kind() == io::ErrorKind::WouldBlock => { + // NOTE: We sleep for a little bit before trying again to avoid pegging CPU + tokio::time::sleep(SLEEP_DURATION).await + } + x => return x, + } + } + } + + /// Reads a frame using [`read_frame`] and then deserializes the bytes into `D`. + /// + /// [`read_frame`]: FramedTransport::read_frame + pub async fn read_frame_as(&mut self) -> io::Result> { + match self.read_frame().await { + Ok(Some(frame)) => Ok(Some(utils::deserialize_from_slice(frame.as_item())?)), + Ok(None) => Ok(None), + Err(x) => Err(x), + } + } + + /// Writes a `frame` of bytes by using the [`Codec`] tied to this transport. + /// + /// This is accomplished by continually calling the inner transport's `try_write`. If 0 is + /// returned from a call to `try_write`, this will fail with [`ErrorKind::WriteZero`]. + /// + /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport + /// is not ready to write data or has not written the entire frame before waiting. 
+ /// + /// [`ErrorKind::WriteZero`]: io::ErrorKind::WriteZero + /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock + pub fn try_write_frame<'a, F>(&mut self, frame: F) -> io::Result<()> + where + F: TryInto>, + F::Error: Into>, + { + // Grab the frame to send + let frame = frame + .try_into() + .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?; + + // Encode the frame and store it in our outgoing queue + self.codec + .encode(frame.as_borrowed())? + .write(&mut self.outgoing)?; + + // Once the frame enters our queue, we count it as written, even if it isn't fully flushed + self.backup.increment_sent_cnt(); + + // Then we store the raw frame (non-encoded) for the future in case we need to retry + // sending it later (possibly with a different codec) + self.backup.push_frame(frame); + + // Attempt to write everything in our queue + self.try_flush()?; + + Ok(()) + } + + /// Serializes `value` into bytes and passes them to [`try_write_frame`]. + /// + /// [`try_write_frame`]: FramedTransport::try_write_frame + pub fn try_write_frame_for(&mut self, value: &D) -> io::Result<()> { + let data = utils::serialize_to_vec(value)?; + self.try_write_frame(data) + } + + /// Invokes [`try_write_frame`] followed by a continuous calls to [`try_flush`] until a frame + /// is successfully written, an error is encountered that is not [`ErrorKind::WouldBlock`], or + /// the underlying transport has closed. 
+ /// + /// [`try_write_frame`]: FramedTransport::try_write_frame + /// [`try_flush`]: FramedTransport::try_flush + /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock + pub async fn write_frame<'a, F>(&mut self, frame: F) -> io::Result<()> + where + F: TryInto>, + F::Error: Into>, + { + self.writeable().await?; + + match self.try_write_frame(frame) { + // Would block, so continually try to flush until good to go + Err(x) if x.kind() == io::ErrorKind::WouldBlock => loop { + self.writeable().await?; + match self.try_flush() { + Err(x) if x.kind() == io::ErrorKind::WouldBlock => { + // NOTE: We sleep for a little bit before trying again to avoid pegging CPU + tokio::time::sleep(SLEEP_DURATION).await + } + Err(x) => return Err(x), + Ok(_) => return Ok(()), + } + }, + + // Already fully succeeded or failed + x => x, + } + } + + /// Serializes `value` into bytes and passes them to [`write_frame`]. + /// + /// [`write_frame`]: FramedTransport::write_frame + pub async fn write_frame_for(&mut self, value: &D) -> io::Result<()> { + let data = utils::serialize_to_vec(value)?; + self.write_frame(data).await + } + + /// Executes the async function while the [`Backup`] of this transport is frozen. + pub async fn do_frozen(&mut self, mut f: F) -> io::Result<()> + where + F: FnMut(&mut Self) -> X, + X: Future>, + { + let is_frozen = self.backup.is_frozen(); + self.backup.freeze(); + let result = f(self).await; + self.backup.set_frozen(is_frozen); + result + } + + /// Places the transport in **synchronize mode** where it communicates with the other side how + /// many frames have been sent and received. From there, any frames not received by the other + /// side are sent again and then this transport waits for any missing frames that it did not + /// receive from the other side. + /// + /// ### Note + /// + /// This will clear the internal incoming and outgoing buffers, so any frame that was in + /// transit in either direction will be dropped. 
+ pub async fn synchronize(&mut self) -> io::Result<()> { + async fn synchronize_impl( + this: &mut FramedTransport, + backup: &mut Backup, + ) -> io::Result<()> { + type Stats = (u64, u64, u64); + + // Stats in the form of (sent, received, available) + let sent_cnt: u64 = backup.sent_cnt(); + let received_cnt: u64 = backup.received_cnt(); + let available_cnt: u64 = backup + .frame_cnt() + .try_into() + .expect("Cannot case usize to u64"); + + // Clear our internal buffers + this.clear(); + + // Communicate frame counters with other side so we can determine how many frames to send + // and how many to receive. Wait until we get the stats from the other side, and then send + // over any missing frames. + trace!( + "Stats: sent = {sent_cnt}, received = {received_cnt}, available = {available_cnt}" + ); + this.write_frame_for(&(sent_cnt, received_cnt, available_cnt)) + .await?; + let (other_sent_cnt, other_received_cnt, other_available_cnt) = + this.read_frame_as::().await?.ok_or_else(|| { + io::Error::new( + io::ErrorKind::UnexpectedEof, + "Transport terminated before getting replay stats", + ) + })?; + trace!("Other stats: sent = {other_sent_cnt}, received = {other_received_cnt}, available = {other_available_cnt}"); + + // Determine how many frames we need to resend. This will either be (sent - received) or + // available frames, whichever is smaller. + let resend_cnt = std::cmp::min( + if sent_cnt > other_received_cnt { + sent_cnt - other_received_cnt + } else { + 0 + }, + available_cnt, + ); + + // Determine how many frames we expect to receive. This will either be (received - sent) or + // available frames, whichever is smaller. 
+ let expected_cnt = std::cmp::min( + if received_cnt < other_sent_cnt { + other_sent_cnt - received_cnt + } else { + 0 + }, + other_available_cnt, + ); + + // Send all missing frames, removing any frames that we know have been received + trace!("Reducing internal replay frames to {resend_cnt}"); + backup.truncate_front(resend_cnt.try_into().expect("Cannot cast usize to u64")); + + debug!("Sending {resend_cnt} frames"); + for frame in backup.frames() { + this.try_write_frame(frame.as_borrowed())?; + } + this.flush().await?; + + // Receive all expected frames, placing their contents into our incoming queue + // + // NOTE: We do not increment our counter as this is done during `try_read_frame`, even + // when the frame comes from our internal queue. To avoid duplicating the increment, + // we do not increment the counter here. + debug!("Waiting for {expected_cnt} frames"); + for i in 0..expected_cnt { + let frame = this.read_frame().await?.ok_or_else(|| { + io::Error::new( + io::ErrorKind::UnexpectedEof, + format!( + "Transport terminated before getting frame {}/{expected_cnt}", + i + 1 + ), + ) + })?; + + // Encode our frame and write it to be queued in our incoming data + // NOTE: We have to do encoding here as incoming bytes are expected to be encoded + this.codec.encode(frame)?.write(&mut this.incoming)?; + } + + // Catch up our read count as we can have the case where the other side has a higher + // count than frames sent if some frames were fully dropped due to size limits + if backup.received_cnt() != other_sent_cnt { + warn!( + "Backup received count ({}) != other sent count ({}), so resetting to match", + backup.received_cnt(), + other_sent_cnt + ); + backup.set_received_cnt(other_sent_cnt); + } + + Ok(()) + } + + // Swap out our backup so we don't mutate it from synchronization efforts + let mut backup = std::mem::take(&mut self.backup); + + // Perform our operation, but don't return immediately so we can restore our backup + let result = 
synchronize_impl(self, &mut backup).await; + + // Reset our backup to the real version + self.backup = backup; + + result + } + + /// Shorthand for creating a [`FramedTransport`] with a [`PlainCodec`] and then immediately + /// performing a [`client_handshake`], returning the updated [`FramedTransport`] on success. + /// + /// [`client_handshake`]: FramedTransport::client_handshake + #[inline] + pub async fn from_client_handshake(transport: T) -> io::Result { + let mut transport = Self::plain(transport); + transport.client_handshake().await?; + Ok(transport) + } + + /// Perform the client-side of a handshake. See [`handshake`] for more details. + /// + /// [`handshake`]: FramedTransport::handshake + pub async fn client_handshake(&mut self) -> io::Result<()> { + self.handshake(Handshake::client()).await + } + /// Shorthand for creating a [`FramedTransport`] with a [`PlainCodec`] and then immediately + /// performing a [`server_handshake`], returning the updated [`FramedTransport`] on success. + /// + /// [`client_handshake`]: FramedTransport::client_handshake + #[inline] + pub async fn from_server_handshake(transport: T) -> io::Result { + let mut transport = Self::plain(transport); + transport.server_handshake().await?; + Ok(transport) + } + + /// Perform the server-side of a handshake. See [`handshake`] for more details. + /// + /// [`handshake`]: FramedTransport::handshake + pub async fn server_handshake(&mut self) -> io::Result<()> { + self.handshake(Handshake::server()).await + } + + /// Performs a handshake in order to establish a new codec to use between this transport and + /// the other side. The parameter `handshake` defines how the transport will handle the + /// handshake with `Client` being used to pick the compression and encryption used while + /// `Server` defines what the choices are for compression and encryption. 
+ /// + /// This will reset the framed transport's codec to [`PlainCodec`] in order to communicate + /// which compression and encryption to use. Upon selecting an encryption type, a shared secret + /// key will be derived on both sides and used to establish the [`EncryptionCodec`], which in + /// combination with the [`CompressionCodec`] (if any) will replace this transport's codec. + /// + /// ### Client + /// + /// 1. Wait for options from server + /// 2. Send to server a compression and encryption choice + /// 3. Configure framed transport using selected choices + /// 4. Invoke on_handshake function + /// + /// ### Server + /// + /// 1. Send options to client + /// 2. Receive choices from client + /// 3. Configure framed transport using client's choices + /// 4. Invoke on_handshake function + /// + /// ### Failure + /// + /// The handshake will fail in several cases: + /// + /// * If any frame during the handshake fails to be serialized + /// * If any unexpected frame is received during the handshake + /// * If using encryption and unable to derive a shared secret key + /// + /// If a failure happens, the codec will be reset to what it was prior to the handshake + /// request, and all internal buffers will be cleared to avoid corruption. + /// + pub async fn handshake(&mut self, handshake: Handshake) -> io::Result<()> { + // Place transport in plain text communication mode for start of handshake, and clear any + // data that is lingering within internal buffers + // + // NOTE: We grab the old codec in case we encounter an error and need to reset it + let old_codec = std::mem::replace(&mut self.codec, Box::new(PlainCodec::new())); + self.clear(); + + // Swap out our backup so we don't mutate it from synchronization efforts + let backup = std::mem::take(&mut self.backup); + + // Transform the transport's codec to abide by the choice. 
In the case of an error, we + // reset the codec back to what it was prior to attempting the handshake and clear the + // internal buffers as they may be corrupt. + match self.handshake_impl(handshake).await { + Ok(codec) => { + self.set_codec(codec); + self.backup = backup; + Ok(()) + } + Err(x) => { + self.set_codec(old_codec); + self.clear(); + self.backup = backup; + Err(x) + } + } + } + + async fn handshake_impl(&mut self, handshake: Handshake) -> io::Result { + #[derive(Debug, Serialize, Deserialize)] + struct Choice { + compression_level: Option, + compression_type: Option, + encryption_type: Option, + } + + #[derive(Debug, Serialize, Deserialize)] + struct Options { + compression_types: Vec, + encryption_types: Vec, + } + + // Define a label to distinguish log output for client and server + let log_label = if handshake.is_client() { + "Handshake | Client" + } else { + "Handshake | Server" + }; + + // Determine compression and encryption to apply to framed transport + let choice = match handshake { + Handshake::Client { + preferred_compression_type, + preferred_compression_level, + preferred_encryption_type, + } => { + // Receive options from the server and pick one + debug!("[{log_label}] Waiting on options"); + let options = self.read_frame_as::().await?.ok_or_else(|| { + io::Error::new( + io::ErrorKind::UnexpectedEof, + "Transport closed early while waiting for options", + ) + })?; + + // Choose a compression and encryption option from the options + debug!("[{log_label}] Selecting from options: {options:#?}"); + let choice = Choice { + // Use preferred compression if available, otherwise default to no compression + // to avoid choosing something poor + compression_type: preferred_compression_type + .filter(|ty| options.compression_types.contains(ty)), + + // Use preferred compression level, otherwise allowing the server to pick + compression_level: preferred_compression_level, + + // Use preferred encryption, otherwise pick first non-unknown encryption 
type
+                    // that is available instead
+                    encryption_type: preferred_encryption_type
+                        .filter(|ty| options.encryption_types.contains(ty))
+                        .or_else(|| {
+                            options
+                                .encryption_types
+                                .iter()
+                                .find(|ty| !ty.is_unknown())
+                                .copied()
+                        }),
+                };
+
+                // Report back to the server the choice
+                debug!("[{log_label}] Reporting choice: {choice:#?}");
+                self.write_frame_for(&choice).await?;
+
+                choice
+            }
+            Handshake::Server {
+                compression_types,
+                encryption_types,
+            } => {
+                let options = Options {
+                    compression_types: compression_types.to_vec(),
+                    encryption_types: encryption_types.to_vec(),
+                };
+
+                // Send options to the client
+                debug!("[{log_label}] Sending options: {options:#?}");
+                self.write_frame_for(&options).await?;
+
+                // Get client's response with selected compression and encryption
+                debug!("[{log_label}] Waiting on choice");
+                self.read_frame_as::<Choice>().await?.ok_or_else(|| {
+                    io::Error::new(
+                        io::ErrorKind::UnexpectedEof,
+                        "Transport closed early while waiting for choice",
+                    )
+                })?
+            }
+        };
+
+        debug!("[{log_label}] Building compression & encryption codecs based on {choice:#?}");
+        let compression_level = choice.compression_level.unwrap_or_default();
+
+        // Acquire a codec for the compression type
+        let compression_codec = choice
+            .compression_type
+            .map(|ty| ty.new_codec(compression_level))
+            .transpose()?;
+
+        // In the case that we are using encryption, we derive a shared secret key to use with the
+        // encryption type
+        let encryption_codec = match choice.encryption_type {
+            // Fail early if we got an unknown encryption type
+            Some(EncryptionType::Unknown) => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    "Unknown encryption type",
+                ))
+            }
+            Some(ty) => {
+                let key = self.exchange_keys_impl(log_label).await?;
+                Some(ty.new_codec(key.unprotected_as_bytes())?)
+ } + None => None, + }; + + // Bundle our compression and encryption codecs into a single, chained codec + trace!("[{log_label}] Bundling codecs"); + let codec: BoxedCodec = match (compression_codec, encryption_codec) { + // If we have both encryption and compression, do the encryption first and then + // compress in order to get smallest result + (Some(c), Some(e)) => Box::new(ChainCodec::new(e, c)), + + // If we just have compression, pass along the compression codec + (Some(c), None) => Box::new(c), + + // If we just have encryption, pass along the encryption codec + (None, Some(e)) => Box::new(e), + + // If we have neither compression nor encryption, use a plaintext codec + (None, None) => Box::new(PlainCodec::new()), + }; + + Ok(codec) + } + + /// Places the transport into key-exchange mode where it attempts to derive a shared secret key + /// with the other transport. + pub async fn exchange_keys(&mut self) -> io::Result { + self.exchange_keys_impl("").await + } + + async fn exchange_keys_impl(&mut self, label: &str) -> io::Result { + let log_label = if label.is_empty() { + String::new() + } else { + format!("[{label}] ") + }; + + #[derive(Serialize, Deserialize)] + struct KeyExchangeData { + /// Bytes of the public key + #[serde(with = "serde_bytes")] + public_key: PublicKeyBytes, + + /// Randomly generated salt + #[serde(with = "serde_bytes")] + salt: Salt, + } + + debug!("{log_label}Exchanging public key and salt"); + let exchange = KeyExchange::default(); + self.write_frame_for(&KeyExchangeData { + public_key: exchange.pk_bytes(), + salt: *exchange.salt(), + }) + .await?; + + // TODO: This key only works because it happens to be 32 bytes and our encryption + // also wants a 32-byte key. Once we introduce new encryption algorithms that + // are not using 32-byte keys, the key exchange will need to support deriving + // other length keys. + trace!("{log_label}Waiting on public key and salt from other side"); + let data = self + .read_frame_as::() + .await? 
+ .ok_or_else(|| { + io::Error::new( + io::ErrorKind::UnexpectedEof, + "Transport closed early while waiting for key data", + ) + })?; + + trace!("{log_label}Deriving shared secret key"); + let key = exchange.derive_shared_secret(data.public_key, data.salt)?; + Ok(key) + } +} + +#[async_trait] +impl Reconnectable for FramedTransport +where + T: Transport, +{ + async fn reconnect(&mut self) -> io::Result<()> { + Reconnectable::reconnect(&mut self.inner).await + } +} + +impl FramedTransport { + /// Produces a pair of inmemory transports that are connected to each other using a + /// [`PlainCodec`]. + /// + /// Sets the buffer for message passing for each underlying transport to the given buffer size. + pub fn pair( + buffer: usize, + ) -> ( + FramedTransport, + FramedTransport, + ) { + let (a, b) = InmemoryTransport::pair(buffer); + let a = FramedTransport::new(a, Box::new(PlainCodec::new())); + let b = FramedTransport::new(b, Box::new(PlainCodec::new())); + (a, b) + } + + /// Links the underlying transports together using [`InmemoryTransport::link`]. 
+ pub fn link(&mut self, other: &mut Self, buffer: usize) { + self.inner.link(&mut other.inner, buffer) + } +} + +#[cfg(test)] +impl FramedTransport { + /// Generates a test pair with default capacity + pub fn test_pair( + buffer: usize, + ) -> ( + FramedTransport, + FramedTransport, + ) { + Self::pair(buffer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::TestTransport; + use bytes::BufMut; + use test_log::test; + + /// Codec that always succeeds without altering the frame + #[derive(Clone, Debug, PartialEq, Eq)] + struct OkCodec; + + impl Codec for OkCodec { + fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + Ok(frame) + } + + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + Ok(frame) + } + } + + /// Codec that always fails + #[derive(Clone, Debug, PartialEq, Eq)] + struct ErrCodec; + + impl Codec for ErrCodec { + fn encode<'a>(&mut self, _frame: Frame<'a>) -> io::Result> { + Err(io::Error::from(io::ErrorKind::Other)) + } + + fn decode<'a>(&mut self, _frame: Frame<'a>) -> io::Result> { + Err(io::Error::from(io::ErrorKind::Other)) + } + } + + // Hardcoded custom codec so we can verify it works differently than plain codec + #[derive(Clone)] + struct CustomCodec; + + impl Codec for CustomCodec { + fn encode<'a>(&mut self, _: Frame<'a>) -> io::Result> { + Ok(Frame::new(b"encode")) + } + + fn decode<'a>(&mut self, _: Frame<'a>) -> io::Result> { + Ok(Frame::new(b"decode")) + } + } + + type SimulateTryReadFn = Box io::Result + Send + Sync>; + + /// Simulate calls to try_read by feeding back `data` in `step` increments, triggering a block + /// if `block_on` returns true where `block_on` is provided a counter value that is incremented + /// every time the simulated `try_read` function is called + /// + /// NOTE: This will inject the frame len in front of the provided data to properly simulate + /// receiving a frame of data + fn simulate_try_read( + frames: Vec, + step: usize, + block_on: impl Fn(usize) -> bool + Send + 
Sync + 'static, + ) -> SimulateTryReadFn { + use std::sync::atomic::{AtomicUsize, Ordering}; + + // Stuff all of our frames into a single byte collection + let data = { + let mut buf = BytesMut::new(); + + for frame in frames { + frame.write(&mut buf).unwrap(); + } + + buf.to_vec() + }; + + let idx = AtomicUsize::new(0); + let cnt = AtomicUsize::new(0); + + Box::new(move |buf| { + if block_on(cnt.fetch_add(1, Ordering::Relaxed)) { + return Err(io::Error::from(io::ErrorKind::WouldBlock)); + } + + let start = idx.fetch_add(step, Ordering::Relaxed); + let end = start + step; + let end = if end > data.len() { data.len() } else { end }; + let len = if start > end { 0 } else { end - start }; + + buf[..len].copy_from_slice(&data[start..end]); + Ok(len) + }) + } + + #[test] + fn try_read_frame_should_return_would_block_if_fails_to_read_frame_before_blocking() { + // Should fail if immediately blocks + let mut transport = FramedTransport::new( + TestTransport { + f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::WouldBlock))), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + assert_eq!( + transport.try_read_frame().unwrap_err().kind(), + io::ErrorKind::WouldBlock + ); + + // Should fail if not read enough bytes before blocking + let mut transport = FramedTransport::new( + TestTransport { + f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |cnt| cnt == 1), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + assert_eq!( + transport.try_read_frame().unwrap_err().kind(), + io::ErrorKind::WouldBlock + ); + } + + #[test] + fn try_read_frame_should_return_error_if_encountered_error_with_reading_bytes() { + let mut transport = FramedTransport::new( + TestTransport { + f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + assert_eq!( 
+ transport.try_read_frame().unwrap_err().kind(), + io::ErrorKind::NotConnected + ); + } + + #[test] + fn try_read_frame_should_return_error_if_encountered_error_during_decode() { + let mut transport = FramedTransport::new( + TestTransport { + f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |_| false), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }, + Box::new(ErrCodec), + ); + assert_eq!( + transport.try_read_frame().unwrap_err().kind(), + io::ErrorKind::Other + ); + } + + #[test] + fn try_read_frame_should_return_next_available_frame() { + let data = { + let mut data = BytesMut::new(); + Frame::new(b"hello world").write(&mut data).unwrap(); + data.freeze() + }; + + let mut transport = FramedTransport::new( + TestTransport { + f_try_read: Box::new(move |buf| { + buf[..data.len()].copy_from_slice(data.as_ref()); + Ok(data.len()) + }), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world"); + } + + #[test] + fn try_read_frame_should_return_next_available_frame_if_already_in_incoming_buffer() { + // Store two frames in our data to transmit + let data = { + let mut data = BytesMut::new(); + Frame::new(b"hello world").write(&mut data).unwrap(); + Frame::new(b"hello again").write(&mut data).unwrap(); + data.freeze() + }; + + // Configure transport to return both frames in single read such that we have another + // complete frame to parse (in the case that an underlying try_read would block, but we had + // data available before that) + let mut transport = FramedTransport::new( + TestTransport { + f_try_read: Box::new(move |buf| { + static mut CNT: usize = 0; + unsafe { + CNT += 1; + if CNT == 2 { + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } else { + let n = data.len(); + buf[..data.len()].copy_from_slice(data.as_ref()); + Ok(n) + } + } + }), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + 
..Default::default() + }, + Box::new(OkCodec), + ); + + // Read first frame + assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world"); + + // Read second frame + assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello again"); + } + + #[test] + fn try_read_frame_should_keep_reading_until_a_frame_is_found() { + const STEP_SIZE: usize = Frame::HEADER_SIZE + 7; + + let mut transport = FramedTransport::new( + TestTransport { + f_try_read: simulate_try_read( + vec![Frame::new(b"hello world"), Frame::new(b"test hello")], + STEP_SIZE, + |_| false, + ), + f_ready: Box::new(|_| Ok(Ready::READABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world"); + + // Should have leftover bytes from next frame + // where len = 10, "tes" + assert_eq!( + transport.incoming.to_vec(), + [0, 0, 0, 0, 0, 0, 0, 10, b't', b'e', b's'] + ); + } + + #[test] + fn try_write_frame_should_return_would_block_if_fails_to_write_frame_before_blocking() { + let mut transport = FramedTransport::new( + TestTransport { + f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::WouldBlock))), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + + // First call will only write part of the frame and then return WouldBlock + assert_eq!( + transport + .try_write_frame(b"hello world") + .unwrap_err() + .kind(), + io::ErrorKind::WouldBlock + ); + } + + #[test] + fn try_write_frame_should_return_error_if_encountered_error_with_writing_bytes() { + let mut transport = FramedTransport::new( + TestTransport { + f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + assert_eq!( + transport + .try_write_frame(b"hello world") + .unwrap_err() + .kind(), + io::ErrorKind::NotConnected + ); + } + + #[test] + fn 
try_write_frame_should_return_error_if_encountered_error_during_encode() { + let mut transport = FramedTransport::new( + TestTransport { + f_try_write: Box::new(|buf| Ok(buf.len())), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }, + Box::new(ErrCodec), + ); + assert_eq!( + transport + .try_write_frame(b"hello world") + .unwrap_err() + .kind(), + io::ErrorKind::Other + ); + } + + #[test] + fn try_write_frame_should_write_entire_frame_if_possible() { + let (tx, rx) = std::sync::mpsc::sync_channel(1); + let mut transport = FramedTransport::new( + TestTransport { + f_try_write: Box::new(move |buf| { + let len = buf.len(); + tx.send(buf.to_vec()).unwrap(); + Ok(len) + }), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + + transport.try_write_frame(b"hello world").unwrap(); + + // Transmitted data should be encoded using the framed transport's codec + assert_eq!( + rx.try_recv().unwrap(), + [11u64.to_be_bytes().as_slice(), b"hello world".as_slice()].concat() + ); + } + + #[test] + fn try_write_frame_should_write_any_prior_queued_bytes_before_writing_next_frame() { + const STEP_SIZE: usize = Frame::HEADER_SIZE + 5; + let (tx, rx) = std::sync::mpsc::sync_channel(10); + let mut transport = FramedTransport::new( + TestTransport { + f_try_write: Box::new(move |buf| { + static mut CNT: usize = 0; + unsafe { + CNT += 1; + if CNT == 2 { + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } else { + let len = std::cmp::min(STEP_SIZE, buf.len()); + tx.send(buf[..len].to_vec()).unwrap(); + Ok(len) + } + } + }), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + + // First call will only write part of the frame and then return WouldBlock + assert_eq!( + transport + .try_write_frame(b"hello world") + .unwrap_err() + .kind(), + io::ErrorKind::WouldBlock + ); + + // Transmitted data should be encoded using the framed transport's codec + assert_eq!( + 
rx.try_recv().unwrap(), + [11u64.to_be_bytes().as_slice(), b"hello".as_slice()].concat() + ); + assert_eq!( + rx.try_recv().unwrap_err(), + std::sync::mpsc::TryRecvError::Empty + ); + + // Next call will keep writing successfully until done + transport.try_write_frame(b"test").unwrap(); + assert_eq!( + rx.try_recv().unwrap(), + [b' ', b'w', b'o', b'r', b'l', b'd', 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!(rx.try_recv().unwrap(), [4, b't', b'e', b's', b't']); + assert_eq!( + rx.try_recv().unwrap_err(), + std::sync::mpsc::TryRecvError::Empty + ); + } + + #[test] + fn try_flush_should_return_error_if_try_write_fails() { + let mut transport = FramedTransport::new( + TestTransport { + f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + + // Set our outgoing buffer to flush + transport.outgoing.put_slice(b"hello world"); + + // Perform flush and verify error happens + assert_eq!( + transport.try_flush().unwrap_err().kind(), + io::ErrorKind::NotConnected + ); + } + + #[test] + fn try_flush_should_return_error_if_try_write_returns_0_bytes_written() { + let mut transport = FramedTransport::new( + TestTransport { + f_try_write: Box::new(|_| Ok(0)), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + + // Set our outgoing buffer to flush + transport.outgoing.put_slice(b"hello world"); + + // Perform flush and verify error happens + assert_eq!( + transport.try_flush().unwrap_err().kind(), + io::ErrorKind::WriteZero + ); + } + + #[test] + fn try_flush_should_be_noop_if_nothing_to_flush() { + let mut transport = FramedTransport::new( + TestTransport { + f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + + // Perform flush and verify nothing happens + 
transport.try_flush().unwrap(); + } + + #[test] + fn try_flush_should_continually_call_try_write_until_outgoing_buffer_is_empty() { + const STEP_SIZE: usize = 5; + let (tx, rx) = std::sync::mpsc::sync_channel(10); + let mut transport = FramedTransport::new( + TestTransport { + f_try_write: Box::new(move |buf| { + let len = std::cmp::min(STEP_SIZE, buf.len()); + tx.send(buf[..len].to_vec()).unwrap(); + Ok(len) + }), + f_ready: Box::new(|_| Ok(Ready::WRITABLE)), + ..Default::default() + }, + Box::new(OkCodec), + ); + + // Set our outgoing buffer to flush + transport.outgoing.put_slice(b"hello world"); + + // Perform flush + transport.try_flush().unwrap(); + + // Verify outgoing data flushed with N calls to try_write + assert_eq!(rx.try_recv().unwrap(), b"hello".as_slice()); + assert_eq!(rx.try_recv().unwrap(), b" worl".as_slice()); + assert_eq!(rx.try_recv().unwrap(), b"d".as_slice()); + assert_eq!( + rx.try_recv().unwrap_err(), + std::sync::mpsc::TryRecvError::Empty + ); + } + + #[inline] + async fn test_synchronize_stats( + transport: &mut FramedTransport, + sent_cnt: u64, + received_cnt: u64, + available_cnt: u64, + expected_sent_cnt: u64, + expected_received_cnt: u64, + expected_available_cnt: u64, + ) { + // From the other side, claim that we have received 2 frames + // (sent, received, available) + transport + .write_frame_for(&(sent_cnt, received_cnt, available_cnt)) + .await + .unwrap(); + + // Receive stats from the other side + let (sent, received, available) = transport + .read_frame_as::<(u64, u64, u64)>() + .await + .unwrap() + .unwrap(); + assert_eq!(sent, expected_sent_cnt, "Wrong sent cnt"); + assert_eq!(received, expected_received_cnt, "Wrong received cnt"); + assert_eq!(available, expected_available_cnt, "Wrong available cnt"); + } + + #[test(tokio::test)] + async fn synchronize_should_resend_no_frames_if_other_side_claims_it_has_more_than_us() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Configure the backup such that we have sent 
one frame + t2.backup.push_frame(Frame::new(b"hello world")); + t2.backup.increment_sent_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let _task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 0, 2, 0 + // expected (sent, received, available) = 1, 0, 1 + test_synchronize_stats(&mut t1, 0, 2, 0, 1, 0, 1).await; + + // Should not receive anything before our done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + } + + #[test(tokio::test)] + async fn synchronize_should_resend_no_frames_if_none_missing_on_other_side() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Configure the backup such that we have sent one frame + t2.backup.push_frame(Frame::new(b"hello world")); + t2.backup.increment_sent_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let _task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 0, 1, 0 + // expected (sent, received, available) = 1, 0, 1 + test_synchronize_stats(&mut t1, 0, 1, 0, 1, 0, 1).await; + + // Should not receive anything before our done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + } + + #[test(tokio::test)] + async fn synchronize_should_resend_some_frames_if_some_missing_on_other_side() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Configure the backup such that we have sent two frames + t2.backup.push_frame(Frame::new(b"hello")); + t2.backup.push_frame(Frame::new(b"world")); + 
t2.backup.increment_sent_cnt(); + t2.backup.increment_sent_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let _task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 0, 1, 0 + // expected (sent, received, available) = 2, 0, 2 + test_synchronize_stats(&mut t1, 0, 1, 0, 2, 0, 2).await; + + // Recieve both frames and then the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + } + + #[test(tokio::test)] + async fn synchronize_should_resend_all_frames_if_all_missing_on_other_side() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Configure the backup such that we have sent two frames + t2.backup.push_frame(Frame::new(b"hello")); + t2.backup.push_frame(Frame::new(b"world")); + t2.backup.increment_sent_cnt(); + t2.backup.increment_sent_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let _task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 0, 0, 0 + // expected (sent, received, available) = 2, 0, 2 + test_synchronize_stats(&mut t1, 0, 0, 0, 2, 0, 2).await; + + // Recieve both frames and then the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + } + + #[test(tokio::test)] + async fn 
synchronize_should_resend_available_frames_if_more_than_available_missing_on_other_side( + ) { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Configure the backup such that we have sent two frames, and believe that we have + // sent 3 in total, a situation that happens once we reach the peak possible size of + // old frames to store + t2.backup.push_frame(Frame::new(b"hello")); + t2.backup.push_frame(Frame::new(b"world")); + t2.backup.increment_sent_cnt(); + t2.backup.increment_sent_cnt(); + t2.backup.increment_sent_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let _task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 0, 0, 0 + // expected (sent, received, available) = 3, 0, 2 + test_synchronize_stats(&mut t1, 0, 0, 0, 3, 0, 2).await; + + // Recieve both frames and then the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + } + + #[test(tokio::test)] + async fn synchronize_should_receive_no_frames_if_other_side_claims_it_has_more_than_us() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Mark other side as having received a frame + t2.backup.increment_received_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let _task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 0, 0, 0 + // expected (sent, received, available) = 0, 1, 
0 + test_synchronize_stats(&mut t1, 0, 0, 0, 0, 1, 0).await; + + // Recieve the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + } + + #[test(tokio::test)] + async fn synchronize_should_receive_no_frames_if_none_missing_from_other_side() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Mark other side as having received a frame + t2.backup.increment_received_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let _task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 1, 0, 1 + // expected (sent, received, available) = 0, 1, 0 + test_synchronize_stats(&mut t1, 1, 0, 1, 0, 1, 0).await; + + // Recieve the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + } + + #[test(tokio::test)] + async fn synchronize_should_receive_some_frames_if_some_missing_from_other_side() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Mark other side as having received a frame + t2.backup.increment_received_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 2, 0, 2 + // expected (sent, received, available) = 0, 1, 0 + test_synchronize_stats(&mut t1, 2, 0, 2, 0, 1, 0).await; + + // Send a frame to fill the gap + t1.write_frame(Frame::new(b"hello")).await.unwrap(); + + // Recieve the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + + // Drop the transport such that 
the other side will get a definite termination + drop(t1); + + // Verify that the frame was captured on the other side + let mut t2 = task.await.unwrap(); + assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello"); + assert_eq!(t2.read_frame().await.unwrap(), None); + } + + #[test(tokio::test)] + async fn synchronize_should_receive_all_frames_if_all_missing_from_other_side() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 2, 0, 2 + // expected (sent, received, available) = 0, 0, 0 + test_synchronize_stats(&mut t1, 2, 0, 2, 0, 0, 0).await; + + // Send frames to fill the gap + t1.write_frame(Frame::new(b"hello")).await.unwrap(); + t1.write_frame(Frame::new(b"world")).await.unwrap(); + + // Recieve the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + + // Drop the transport such that the other side will get a definite termination + drop(t1); + + // Verify that the frame was captured on the other side + let mut t2 = task.await.unwrap(); + assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello"); + assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world"); + assert_eq!(t2.read_frame().await.unwrap(), None); + } + + #[test(tokio::test)] + async fn synchronize_should_receive_all_frames_if_more_than_all_missing_from_other_side() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let task = tokio::spawn(async move { + 
t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 3, 0, 2 + // expected (sent, received, available) = 0, 0, 0 + test_synchronize_stats(&mut t1, 2, 0, 2, 0, 0, 0).await; + + // Send frames to fill the gap + t1.write_frame(Frame::new(b"hello")).await.unwrap(); + t1.write_frame(Frame::new(b"world")).await.unwrap(); + + // Recieve the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + + // Drop the transport such that the other side will get a definite termination + drop(t1); + + // Verify that the frame was captured on the other side + let mut t2 = task.await.unwrap(); + assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello"); + assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world"); + assert_eq!(t2.read_frame().await.unwrap(), None); + } + + #[test(tokio::test)] + async fn synchronize_should_fail_if_connection_terminated_before_receiving_missing_frames() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 2, 0, 2 + // expected (sent, received, available) = 0, 0, 0 + test_synchronize_stats(&mut t1, 2, 0, 2, 0, 0, 0).await; + + // Send one frame to fill the gap + t1.write_frame(Frame::new(b"hello")).await.unwrap(); + + // Drop the transport to cause a failure + drop(t1); + + // Verify that the other side's synchronization failed + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn synchronize_should_fail_if_connection_terminated_while_waiting_for_frame_stats() { + let (t1, mut t2) = FramedTransport::pair(100); + + // Spawn a separate task to do 
synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // Drop the transport to cause a failure + drop(t1); + + // Verify that the other side's synchronization failed + task.await.unwrap_err(); + } + + #[test(tokio::test)] + async fn synchronize_should_clear_any_prexisting_incoming_and_outgoing_data() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Put some frames into the incoming and outgoing of our transport + Frame::new(b"bad incoming").write(&mut t2.incoming).unwrap(); + Frame::new(b"bad outgoing").write(&mut t2.outgoing).unwrap(); + + // Configure the backup such that we have sent two frames + t2.backup.push_frame(Frame::new(b"hello")); + t2.backup.push_frame(Frame::new(b"world")); + t2.backup.increment_sent_cnt(); + t2.backup.increment_sent_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2 + }); + + // fake (sent, received, available) = 2, 0, 2 + // expected (sent, received, available) = 2, 0, 2 + test_synchronize_stats(&mut t1, 2, 0, 2, 2, 0, 2).await; + + // Send frames to fill the gap + t1.write_frame(Frame::new(b"one")).await.unwrap(); + t1.write_frame(Frame::new(b"two")).await.unwrap(); + + // Recieve both frames and then the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + + // Drop the transport such that the other side will get a definite 
termination + drop(t1); + + // Verify that the frame was captured on the other side + let mut t2 = task.await.unwrap(); + assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"one"); + assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"two"); + assert_eq!(t2.read_frame().await.unwrap(), None); + } + + #[test(tokio::test)] + async fn synchronize_should_not_increment_the_sent_frames_or_store_replayed_frames_in_the_backup( + ) { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Configure the backup such that we have sent two frames + t2.backup.push_frame(Frame::new(b"hello")); + t2.backup.push_frame(Frame::new(b"world")); + t2.backup.increment_sent_cnt(); + t2.backup.increment_sent_cnt(); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + + t2.backup.freeze(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2.backup.unfreeze(); + + t2 + }); + + // fake (sent, received, available) = 0, 0, 0 + // expected (sent, received, available) = 2, 0, 2 + test_synchronize_stats(&mut t1, 0, 0, 0, 2, 0, 2).await; + + // Recieve both frames and then the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + + // Drop the transport such that the other side will get a definite termination + drop(t1); + + // Verify that the backup on the other side was unaltered by the frames being sent + let t2 = task.await.unwrap(); + assert_eq!(t2.backup.sent_cnt(), 2, "Wrong sent cnt"); + assert_eq!(t2.backup.received_cnt(), 0, "Wrong received cnt"); + assert_eq!(t2.backup.frame_cnt(), 2, "Wrong frame cnt"); + } + + #[test(tokio::test)] + async fn 
synchronize_should_update_the_backup_received_cnt_to_match_other_side_sent() { + let (mut t1, mut t2) = FramedTransport::pair(100); + + // Spawn a separate task to do synchronization simulation so we don't deadlock, and also + // send a frame to indicate when finished so we can know when synchronization is done + // during our test + let task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + + t2.backup.freeze(); + t2.write_frame(Frame::new(b"done")).await.unwrap(); + t2.backup.unfreeze(); + + t2 + }); + + // fake (sent, received, available) = 2, 0, 1 + // expected (sent, received, available) = 0, 0, 0 + test_synchronize_stats(&mut t1, 2, 0, 1, 0, 0, 0).await; + + // Send frames to fill the gap + t1.write_frame(Frame::new(b"hello")).await.unwrap(); + + // Recieve both frames and then the done indicator + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done"); + + // Drop the transport such that the other side will get a definite termination + drop(t1); + + // Verify that the backup on the other side updated based on sent count and not available + let t2 = task.await.unwrap(); + assert_eq!(t2.backup.sent_cnt(), 0, "Wrong sent cnt"); + assert_eq!(t2.backup.received_cnt(), 2, "Wrong received cnt"); + assert_eq!(t2.backup.frame_cnt(), 0, "Wrong frame cnt"); + } + + #[test(tokio::test)] + async fn synchronize_should_work_even_if_codec_changes_between_attempts() { + let (mut t1, _t1_other) = FramedTransport::pair(100); + let (mut t2, _t2_other) = FramedTransport::pair(100); + + // Send some frames from each side + t1.write_frame(Frame::new(b"hello")).await.unwrap(); + t1.write_frame(Frame::new(b"world")).await.unwrap(); + t2.write_frame(Frame::new(b"foo")).await.unwrap(); + t2.write_frame(Frame::new(b"bar")).await.unwrap(); + + // Drop the other transports, link our real transports together, and change the codec + drop(_t1_other); + drop(_t2_other); + t1.link(&mut t2, 100); + let codec = EncryptionCodec::new_xchacha20poly1305(Default::default()); + 
t1.codec = Box::new(codec.clone()); + t2.codec = Box::new(codec); + + // Spawn a separate task to do synchronization so we don't deadlock + let task = tokio::spawn(async move { + t2.synchronize().await.unwrap(); + t2 + }); + + t1.synchronize().await.unwrap(); + + // Verify that we get the appropriate frames from both sides + let mut t2 = task.await.unwrap(); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"foo"); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"bar"); + assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello"); + assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world"); + } + + #[test(tokio::test)] + async fn handshake_should_configure_transports_with_matching_codec() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // NOTE: Spawn a separate task for one of our transports so we can communicate without + // deadlocking + let task = tokio::spawn(async move { + // Wait for handshake to complete + t2.server_handshake().await.unwrap(); + + // Receive one frame and echo it back + let frame = t2.read_frame().await.unwrap().unwrap(); + t2.write_frame(frame).await.unwrap(); + }); + + t1.client_handshake().await.unwrap(); + + // Verify that the transports can still communicate with one another + t1.write_frame(b"hello world").await.unwrap(); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello world"); + + // Ensure that the other transport did not error + task.await.unwrap(); + } + + #[test(tokio::test)] + async fn handshake_failing_should_ensure_existing_codec_remains() { + let (mut t1, t2) = FramedTransport::test_pair(100); + + // Set a different codec on our transport so we can verify it doesn't change + t1.set_codec(Box::new(CustomCodec)); + + // Drop our transport on the other side to cause an immediate failure + drop(t2); + + // Ensure we detect the failure on handshake + t1.client_handshake().await.unwrap_err(); + + // Verify that the codec did not reset to plain text by using the codec + 
assert_eq!(t1.codec.encode(Frame::new(b"test")).unwrap(), b"encode"); + assert_eq!(t1.codec.decode(Frame::new(b"test")).unwrap(), b"decode"); + } + + #[test(tokio::test)] + async fn handshake_should_clear_any_intermittent_buffer_contents_prior_to_handshake_failing() { + let (mut t1, t2) = FramedTransport::test_pair(100); + + // Set a different codec on our transport so we can verify it doesn't change + t1.set_codec(Box::new(CustomCodec)); + + // Drop our transport on the other side to cause an immediate failure + drop(t2); + + // Put some garbage in our buffers + t1.incoming.extend_from_slice(b"garbage in"); + t1.outgoing.extend_from_slice(b"garbage out"); + + // Ensure we detect the failure on handshake + t1.client_handshake().await.unwrap_err(); + + // Verify that the incoming and outgoing buffers are empty + assert!(t1.incoming.is_empty()); + assert!(t1.outgoing.is_empty()); + } + + #[test(tokio::test)] + async fn handshake_should_clear_any_intermittent_buffer_contents_prior_to_handshake_succeeding() + { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // NOTE: Spawn a separate task for one of our transports so we can communicate without + // deadlocking + let task = tokio::spawn(async move { + // Wait for handshake to complete + t2.server_handshake().await.unwrap(); + + // Receive one frame and echo it back + let frame = t2.read_frame().await.unwrap().unwrap(); + t2.write_frame(frame).await.unwrap(); + }); + + // Put some garbage in our buffers + t1.incoming.extend_from_slice(b"garbage in"); + t1.outgoing.extend_from_slice(b"garbage out"); + + t1.client_handshake().await.unwrap(); + + // Verify that the transports can still communicate with one another + t1.write_frame(b"hello world").await.unwrap(); + assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello world"); + + // Ensure that the other transport did not error + task.await.unwrap(); + + // Verify that the incoming and outgoing buffers are empty + assert!(t1.incoming.is_empty()); + 
assert!(t1.outgoing.is_empty()); + } + + #[test(tokio::test)] + async fn handshake_for_client_should_fail_if_receives_unexpected_frame_instead_of_options() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // NOTE: Spawn a separate task for one of our transports so we can communicate without + // deadlocking + let task = tokio::spawn(async move { + t2.write_frame(b"not a valid frame for handshake") + .await + .unwrap(); + }); + + // Ensure we detect the failure on handshake + let err = t1.client_handshake().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + + // Ensure that the other transport did not error + task.await.unwrap(); + } + + #[test(tokio::test)] + async fn handshake_for_client_should_fail_unable_to_send_codec_choice_to_other_side() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + #[derive(Debug, Serialize, Deserialize)] + struct Options { + compression_types: Vec, + encryption_types: Vec, + } + + // NOTE: Spawn a separate task for one of our transports so we can communicate without + // deadlocking + let task = tokio::spawn(async move { + // Send options, and then quit so the client side will fail + t2.write_frame_for(&Options { + compression_types: Vec::new(), + encryption_types: Vec::new(), + }) + .await + .unwrap(); + }); + + // Ensure we detect the failure on handshake + let err = t1.client_handshake().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::WriteZero); + + // Ensure that the other transport did not error + task.await.unwrap(); + } + + #[test(tokio::test)] + async fn handshake_for_client_should_fail_if_unable_to_receive_key_exchange_data_from_other_side( + ) { + #[derive(Debug, Serialize, Deserialize)] + struct Options { + compression_types: Vec, + encryption_types: Vec, + } + + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Go ahead and queue up a choice, and then queue up invalid key exchange data + t2.write_frame_for(&Options { + compression_types: 
CompressionType::known_variants().to_vec(), + encryption_types: EncryptionType::known_variants().to_vec(), + }) + .await + .unwrap(); + + t2.write_frame(b"not valid key exchange data") + .await + .unwrap(); + + // Ensure we detect the failure on handshake + let err = t1.client_handshake().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[test(tokio::test)] + async fn handshake_for_server_should_fail_if_receives_unexpected_frame_instead_of_choice() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // NOTE: Spawn a separate task for one of our transports so we can communicate without + // deadlocking + let task = tokio::spawn(async move { + t2.write_frame(b"not a valid frame for handshake") + .await + .unwrap(); + }); + + // Ensure we detect the failure on handshake + let err = t1.server_handshake().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + + // Ensure that the other transport did not error + task.await.unwrap(); + } + + #[test(tokio::test)] + async fn handshake_for_server_should_fail_unable_to_send_codec_options_to_other_side() { + let (mut t1, t2) = FramedTransport::test_pair(100); + + // Drop our other transport to ensure that nothing can be sent to it + drop(t2); + + // Ensure we detect the failure on handshake + let err = t1.server_handshake().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::WriteZero); + } + + #[test(tokio::test)] + async fn handshake_for_server_should_fail_if_selected_codec_choice_uses_an_unknown_compression_type( + ) { + #[derive(Debug, Serialize, Deserialize)] + struct Choice { + compression_level: Option, + compression_type: Option, + encryption_type: Option, + } + + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Go ahead and queue up an improper response + t2.write_frame_for(&Choice { + compression_level: None, + compression_type: Some(CompressionType::Unknown), + encryption_type: None, + }) + .await + .unwrap(); + + // Ensure we 
detect the failure on handshake + let err = t1.server_handshake().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + } + + #[test(tokio::test)] + async fn handshake_for_server_should_fail_if_selected_codec_choice_uses_an_unknown_encryption_type( + ) { + #[derive(Debug, Serialize, Deserialize)] + struct Choice { + compression_level: Option, + compression_type: Option, + encryption_type: Option, + } + + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Go ahead and queue up an improper response + t2.write_frame_for(&Choice { + compression_level: None, + compression_type: None, + encryption_type: Some(EncryptionType::Unknown), + }) + .await + .unwrap(); + + // Ensure we detect the failure on handshake + let err = t1.server_handshake().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + } + + #[test(tokio::test)] + async fn handshake_for_server_should_fail_if_unable_to_receive_key_exchange_data_from_other_side( + ) { + #[derive(Debug, Serialize, Deserialize)] + struct Choice { + compression_level: Option, + compression_type: Option, + encryption_type: Option, + } + + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Go ahead and queue up a choice, and then queue up invalid key exchange data + t2.write_frame_for(&Choice { + compression_level: None, + compression_type: None, + encryption_type: Some(EncryptionType::XChaCha20Poly1305), + }) + .await + .unwrap(); + + t2.write_frame(b"not valid key exchange data") + .await + .unwrap(); + + // Ensure we detect the failure on handshake + let err = t1.server_handshake().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[test(tokio::test)] + async fn exchange_keys_should_fail_if_unable_to_send_exchange_data_to_other_side() { + let (mut t1, t2) = FramedTransport::test_pair(100); + + // Drop the other side to ensure that the exchange fails at the beginning + drop(t2); + + // Perform key exchange and verify error is as 
expected + assert_eq!( + t1.exchange_keys().await.unwrap_err().kind(), + io::ErrorKind::WriteZero + ); + } + + #[test(tokio::test)] + async fn exchange_keys_should_fail_if_received_invalid_exchange_data() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Queue up an invalid exchange response + t2.write_frame(b"some invalid frame").await.unwrap(); + + // Perform key exchange and verify error is as expected + assert_eq!( + t1.exchange_keys().await.unwrap_err().kind(), + io::ErrorKind::InvalidData + ); + } + + #[test(tokio::test)] + async fn exchange_keys_should_return_shared_secret_key_if_successful() { + let (mut t1, mut t2) = FramedTransport::test_pair(100); + + // Spawn a task to avoid deadlocking + let task = tokio::spawn(async move { t2.exchange_keys().await.unwrap() }); + + // Perform key exchange + let key = t1.exchange_keys().await.unwrap(); + + // Validate that the keys on both sides match + assert_eq!(key, task.await.unwrap()); + } +} diff --git a/distant-net/src/common/transport/framed/backup.rs b/distant-net/src/common/transport/framed/backup.rs new file mode 100644 index 0000000..4a09068 --- /dev/null +++ b/distant-net/src/common/transport/framed/backup.rs @@ -0,0 +1,201 @@ +use super::{Frame, OwnedFrame}; +use std::collections::VecDeque; + +/// Maximum size (in bytes) for saved frames (256MiB) +const MAX_BACKUP_SIZE: usize = 256 * 1024 * 1024; + +/// Stores [`Frame`]s for reuse later. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Backup { + /// Maximum size (in bytes) to save frames in case we need to backup them + /// + /// NOTE: If 0, no frames will be stored. 
+ max_backup_size: usize, + + /// Tracker for the total size (in bytes) of stored frames + current_backup_size: usize, + + /// Storage used to hold outgoing frames in case they need to be reused + frames: VecDeque, + + /// Counter keeping track of total frames sent + sent_cnt: u64, + + /// Counter keeping track of total frames received + received_cnt: u64, + + /// Indicates whether the backup is frozen, which indicates that mutations are ignored + frozen: bool, +} + +impl Default for Backup { + fn default() -> Self { + Self::new() + } +} + +impl Backup { + /// Creates a new, unfrozen backup. + pub fn new() -> Self { + Self { + max_backup_size: MAX_BACKUP_SIZE, + current_backup_size: 0, + frames: VecDeque::new(), + sent_cnt: 0, + received_cnt: 0, + frozen: false, + } + } + + /// Clears the backup of any stored data and resets the state to being new. + /// + /// ### Note + /// + /// Like all other modifications, this will do nothing if the backup is frozen. + pub fn clear(&mut self) { + if !self.frozen { + self.current_backup_size = 0; + self.frames.clear(); + self.sent_cnt = 0; + self.received_cnt = 0; + } + } + + /// Returns true if the backup is frozen, meaning that modifications will be ignored. + #[inline] + pub fn is_frozen(&self) -> bool { + self.frozen + } + + /// Sets the frozen status. + #[inline] + pub fn set_frozen(&mut self, frozen: bool) { + self.frozen = frozen; + } + + /// Marks the backup as frozen. + #[inline] + pub fn freeze(&mut self) { + self.frozen = true; + } + + /// Marks the backup as no longer frozen. + #[inline] + pub fn unfreeze(&mut self) { + self.frozen = false; + } + + /// Sets the maximum size (in bytes) of collective frames stored in case a backup is needed + /// during reconnection. Setting the `size` to 0 will result in no frames being stored. + /// + /// ### Note + /// + /// Like all other modifications, this will do nothing if the backup is frozen. 
+ pub fn set_max_backup_size(&mut self, size: usize) { + if !self.frozen { + self.max_backup_size = size; + } + } + + /// Returns the maximum size (in bytes) of collective frames stored in case a backup is needed + /// during reconnection. + pub fn max_backup_size(&self) -> usize { + self.max_backup_size + } + + /// Increments (by 1) the total sent frames. + /// + /// ### Note + /// + /// Like all other modifications, this will do nothing if the backup is frozen. + pub(crate) fn increment_sent_cnt(&mut self) { + if !self.frozen { + self.sent_cnt += 1; + } + } + + /// Returns how many frames have been sent. + pub(crate) fn sent_cnt(&self) -> u64 { + self.sent_cnt + } + + /// Increments (by 1) the total received frames. + /// + /// ### Note + /// + /// Like all other modifications, this will do nothing if the backup is frozen. + pub(super) fn increment_received_cnt(&mut self) { + if !self.frozen { + self.received_cnt += 1; + } + } + + /// Returns how many frames have been received. + pub(crate) fn received_cnt(&self) -> u64 { + self.received_cnt + } + + /// Sets the total received frames to the specified `cnt`. + /// + /// ### Note + /// + /// Like all other modifications, this will do nothing if the backup is frozen. + pub(super) fn set_received_cnt(&mut self, cnt: u64) { + if !self.frozen { + self.received_cnt = cnt; + } + } + + /// Pushes a new frame to the end of the internal queue. + /// + /// ### Note + /// + /// Like all other modifications, this will do nothing if the backup is frozen. 
+ pub(crate) fn push_frame(&mut self, frame: Frame) { + if self.max_backup_size > 0 && !self.frozen { + self.current_backup_size += frame.len(); + self.frames.push_back(frame.into_owned()); + while self.current_backup_size > self.max_backup_size { + match self.frames.pop_front() { + Some(frame) => { + self.current_backup_size -= frame.len(); + } + + // If we have exhausted all frames, then we have reached + // an internal size of 0 and should exit the loop + None => { + self.current_backup_size = 0; + break; + } + } + } + } + } + + /// Returns the total frames being kept for potential reuse. + pub(super) fn frame_cnt(&self) -> usize { + self.frames.len() + } + + /// Returns an iterator over the frames contained in the backup. + pub(super) fn frames(&self) -> impl Iterator { + self.frames.iter() + } + + /// Truncates the stored frames to be no larger than `size` total frames by popping from the + /// front rather than the back of the list. + /// + /// ### Note + /// + /// Like all other modifications, this will do nothing if the backup is frozen. + pub(super) fn truncate_front(&mut self, size: usize) { + if !self.frozen { + while self.frames.len() > size { + if let Some(frame) = self.frames.pop_front() { + self.current_backup_size -= + std::cmp::min(frame.len(), self.current_backup_size); + } + } + } + } +} diff --git a/distant-net/src/common/transport/framed/codec.rs b/distant-net/src/common/transport/framed/codec.rs new file mode 100644 index 0000000..8aec343 --- /dev/null +++ b/distant-net/src/common/transport/framed/codec.rs @@ -0,0 +1,68 @@ +use super::Frame; +use dyn_clone::DynClone; +use std::io; + +mod chain; +mod compression; +mod encryption; +mod plain; +mod predicate; + +pub use chain::*; +pub use compression::*; +pub use encryption::*; +pub use plain::*; +pub use predicate::*; + +/// Represents abstraction that implements specific encoder and decoder logic to transform an +/// arbitrary collection of bytes. 
This can be used to encrypt and authenticate bytes sent and +/// received by transports. +pub trait Codec: DynClone { + /// Encodes a frame's item + fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result>; + + /// Decodes a frame's item + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result>; +} + +/// Represents a [`Box`]ed version of [`Codec`] +pub type BoxedCodec = Box; + +macro_rules! impl_traits { + ($($x:tt)+) => { + impl Clone for Box { + fn clone(&self) -> Self { + dyn_clone::clone_box(&**self) + } + } + + impl Codec for Box { + fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + Codec::encode(self.as_mut(), frame) + } + + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + Codec::decode(self.as_mut(), frame) + } + } + }; +} + +impl_traits!(Codec); +impl_traits!(Codec + Send); +impl_traits!(Codec + Sync); +impl_traits!(Codec + Send + Sync); + +/// Interface that provides extensions to the codec interface +pub trait CodecExt { + /// Chains this codec with another codec + fn chain(self, codec: T) -> ChainCodec + where + Self: Sized; +} + +impl CodecExt for C { + fn chain(self, codec: T) -> ChainCodec { + ChainCodec::new(self, codec) + } +} diff --git a/distant-net/src/common/transport/framed/codec/chain.rs b/distant-net/src/common/transport/framed/codec/chain.rs new file mode 100644 index 0000000..1148bd2 --- /dev/null +++ b/distant-net/src/common/transport/framed/codec/chain.rs @@ -0,0 +1,160 @@ +use super::{Codec, Frame}; +use std::io; + +/// Represents a codec that chains together other codecs such that encoding will call the encode +/// methods of the underlying, chained codecs from left-to-right and decoding will call the decode +/// methods in reverse order +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub struct ChainCodec { + left: T, + right: U, +} + +impl ChainCodec { + /// Chains two codecs together such that `left` will be invoked first during encoding and + /// `right` will be invoked first during decoding + 
pub fn new(left: T, right: U) -> Self { + Self { left, right } + } + + /// Returns reference to left codec + pub fn as_left(&self) -> &T { + &self.left + } + + /// Consumes the chain and returns the left codec + pub fn into_left(self) -> T { + self.left + } + + /// Returns reference to right codec + pub fn as_right(&self) -> &U { + &self.right + } + + /// Consumes the chain and returns the right codec + pub fn into_right(self) -> U { + self.right + } + + /// Consumes the chain and returns the left and right codecs + pub fn into_left_right(self) -> (T, U) { + (self.left, self.right) + } +} + +impl Codec for ChainCodec +where + T: Codec + Clone, + U: Codec + Clone, +{ + fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + Codec::encode(&mut self.left, frame).and_then(|frame| Codec::encode(&mut self.right, frame)) + } + + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + Codec::decode(&mut self.right, frame).and_then(|frame| Codec::decode(&mut self.left, frame)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + #[derive(Copy, Clone)] + struct TestCodec<'a> { + msg: &'a str, + } + + impl<'a> TestCodec<'a> { + pub fn new(msg: &'a str) -> Self { + Self { msg } + } + } + + impl Codec for TestCodec<'_> { + fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + let mut item = frame.into_item().to_vec(); + item.extend_from_slice(self.msg.as_bytes()); + Ok(Frame::from(item)) + } + + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + let item = frame.into_item().to_vec(); + let frame = Frame::new(item.strip_suffix(self.msg.as_bytes()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + format!( + "Decode failed because did not end with suffix: {}", + self.msg + ), + ) + })?); + Ok(frame.into_owned()) + } + } + + #[derive(Copy, Clone)] + struct ErrCodec; + + impl Codec for ErrCodec { + fn encode<'a>(&mut self, _frame: Frame<'a>) -> io::Result> { + Err(io::Error::from(io::ErrorKind::InvalidData)) + 
} + + fn decode<'a>(&mut self, _frame: Frame<'a>) -> io::Result> { + Err(io::Error::from(io::ErrorKind::InvalidData)) + } + } + + #[test] + fn encode_should_invoke_left_codec_followed_by_right_codec() { + let mut codec = ChainCodec::new(TestCodec::new("hello"), TestCodec::new("world")); + let frame = codec.encode(Frame::new(b"some bytes")).unwrap(); + assert_eq!(frame, b"some byteshelloworld"); + } + + #[test] + fn encode_should_fail_if_left_codec_fails_to_encode() { + let mut codec = ChainCodec::new(ErrCodec, TestCodec::new("world")); + assert_eq!( + codec.encode(Frame::new(b"some bytes")).unwrap_err().kind(), + io::ErrorKind::InvalidData + ); + } + + #[test] + fn encode_should_fail_if_right_codec_fails_to_encode() { + let mut codec = ChainCodec::new(TestCodec::new("hello"), ErrCodec); + assert_eq!( + codec.encode(Frame::new(b"some bytes")).unwrap_err().kind(), + io::ErrorKind::InvalidData + ); + } + + #[test] + fn decode_should_invoke_right_codec_followed_by_left_codec() { + let mut codec = ChainCodec::new(TestCodec::new("hello"), TestCodec::new("world")); + let frame = codec.decode(Frame::new(b"some byteshelloworld")).unwrap(); + assert_eq!(frame, b"some bytes"); + } + + #[test] + fn decode_should_fail_if_left_codec_fails_to_decode() { + let mut codec = ChainCodec::new(ErrCodec, TestCodec::new("world")); + assert_eq!( + codec.decode(Frame::new(b"some bytes")).unwrap_err().kind(), + io::ErrorKind::InvalidData + ); + } + + #[test] + fn decode_should_fail_if_right_codec_fails_to_decode() { + let mut codec = ChainCodec::new(TestCodec::new("hello"), ErrCodec); + assert_eq!( + codec.decode(Frame::new(b"some bytes")).unwrap_err().kind(), + io::ErrorKind::InvalidData + ); + } +} diff --git a/distant-net/src/common/transport/framed/codec/compression.rs b/distant-net/src/common/transport/framed/codec/compression.rs new file mode 100644 index 0000000..fbb40ea --- /dev/null +++ b/distant-net/src/common/transport/framed/codec/compression.rs @@ -0,0 +1,263 @@ +use 
super::{Codec, Frame}; +use flate2::{ + bufread::{DeflateDecoder, DeflateEncoder, GzDecoder, GzEncoder, ZlibDecoder, ZlibEncoder}, + Compression, +}; +use serde::{Deserialize, Serialize}; +use std::io::{self, Read}; + +/// Represents the level of compression to apply to data +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum CompressionLevel { + /// Use no compression (can potentially inflate data) + Zero = 0, + + /// Optimize for the speed of encoding + One = 1, + + Two = 2, + Three = 3, + Four = 4, + Five = 5, + Six = 6, + Seven = 7, + Eight = 8, + + /// Optimize for the size of data being encoded + Nine = 9, +} + +impl CompressionLevel { + /// Applies no compression + pub const NONE: Self = Self::Zero; + + /// Applies fastest compression + pub const FAST: Self = Self::One; + + /// Applies best compression to reduce size (slowest) + pub const BEST: Self = Self::Nine; +} + +impl Default for CompressionLevel { + /// Standard compression level used in zlib library is 6, which is also used here + fn default() -> Self { + Self::Six + } +} + +/// Represents the type of compression for a [`CompressionCodec`] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum CompressionType { + Deflate, + Gzip, + Zlib, + + /// Indicates an unknown compression type for use in handshakes + #[serde(other)] + Unknown, +} + +impl CompressionType { + /// Returns a list of all variants of the type *except* unknown. 
+ pub const fn known_variants() -> &'static [CompressionType] { + &[ + CompressionType::Deflate, + CompressionType::Gzip, + CompressionType::Zlib, + ] + } + + /// Returns true if type is unknown + pub fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } + + /// Creates a new [`CompressionCodec`] for this type, failing if this type is unknown + pub fn new_codec(&self, level: CompressionLevel) -> io::Result { + CompressionCodec::from_type_and_level(*self, level) + } +} + +/// Represents a codec that applies compression during encoding and decompression during decoding +/// of a frame's item +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum CompressionCodec { + /// Apply DEFLATE compression/decompression using compression `level` + Deflate { level: CompressionLevel }, + + /// Apply gzip compression/decompression using compression `level` + Gzip { level: CompressionLevel }, + + /// Apply zlib compression/decompression using compression `level` + Zlib { level: CompressionLevel }, +} + +impl CompressionCodec { + /// Makes a new [`CompressionCodec`] based on the [`CompressionType`] and [`CompressionLevel`], + /// returning error if the type is unknown + pub fn from_type_and_level( + ty: CompressionType, + level: CompressionLevel, + ) -> io::Result { + match ty { + CompressionType::Deflate => Ok(Self::Deflate { level }), + CompressionType::Gzip => Ok(Self::Gzip { level }), + CompressionType::Zlib => Ok(Self::Zlib { level }), + CompressionType::Unknown => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Unknown compression type", + )), + } + } + + /// Create a new deflate compression codec with the specified `level` + pub fn deflate(level: impl Into) -> Self { + Self::Deflate { + level: level.into(), + } + } + + /// Create a new gzip compression codec with the specified `level` + pub fn gzip(level: impl Into) -> Self { + Self::Gzip { + level: level.into(), + } + } + + /// Create a new zlib compression codec with the specified `level` + pub fn 
zlib(level: impl Into) -> Self { + Self::Zlib { + level: level.into(), + } + } + + /// Returns the compression level associated with the codec + pub fn level(&self) -> CompressionLevel { + match self { + Self::Deflate { level } => *level, + Self::Gzip { level } => *level, + Self::Zlib { level } => *level, + } + } + + /// Returns the compression type associated with the codec + pub fn ty(&self) -> CompressionType { + match self { + Self::Deflate { .. } => CompressionType::Deflate, + Self::Gzip { .. } => CompressionType::Gzip, + Self::Zlib { .. } => CompressionType::Zlib, + } + } +} + +impl Codec for CompressionCodec { + fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + let item = frame.as_item(); + + let mut buf = Vec::new(); + match *self { + Self::Deflate { level } => { + DeflateEncoder::new(item, Compression::new(level as u32)).read_to_end(&mut buf)? + } + Self::Gzip { level } => { + GzEncoder::new(item, Compression::new(level as u32)).read_to_end(&mut buf)? + } + Self::Zlib { level } => { + ZlibEncoder::new(item, Compression::new(level as u32)).read_to_end(&mut buf)? + } + }; + + Ok(Frame::from(buf)) + } + + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + let item = frame.as_item(); + + let mut buf = Vec::new(); + match *self { + Self::Deflate { .. } => DeflateDecoder::new(item).read_to_end(&mut buf)?, + Self::Gzip { .. } => GzDecoder::new(item).read_to_end(&mut buf)?, + Self::Zlib { .. 
} => ZlibDecoder::new(item).read_to_end(&mut buf)?, + }; + + Ok(Frame::from(buf)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + #[test] + fn encode_should_apply_appropriate_compression_algorithm() { + // Encode using DEFLATE and verify that the compression was as expected by decompressing + let mut codec = CompressionCodec::deflate(CompressionLevel::BEST); + let frame = codec.encode(Frame::new(b"some bytes")).unwrap(); + + let mut item = Vec::new(); + DeflateDecoder::new(frame.as_item()) + .read_to_end(&mut item) + .unwrap(); + assert_eq!(item, b"some bytes"); + + // Encode using gzip and verify that the compression was as expected by decompressing + let mut codec = CompressionCodec::gzip(CompressionLevel::BEST); + let frame = codec.encode(Frame::new(b"some bytes")).unwrap(); + + let mut item = Vec::new(); + GzDecoder::new(frame.as_item()) + .read_to_end(&mut item) + .unwrap(); + assert_eq!(item, b"some bytes"); + + // Encode using zlib and verify that the compression was as expected by decompressing + let mut codec = CompressionCodec::zlib(CompressionLevel::BEST); + let frame = codec.encode(Frame::new(b"some bytes")).unwrap(); + + let mut item = Vec::new(); + ZlibDecoder::new(frame.as_item()) + .read_to_end(&mut item) + .unwrap(); + assert_eq!(item, b"some bytes"); + } + + #[test] + fn decode_should_apply_appropriate_decompression_algorithm() { + // Decode using DEFLATE + let frame = { + let mut item = Vec::new(); + DeflateEncoder::new(b"some bytes".as_slice(), Compression::best()) + .read_to_end(&mut item) + .unwrap(); + Frame::from(item) + }; + let mut codec = CompressionCodec::deflate(CompressionLevel::BEST); + let frame = codec.decode(frame).unwrap(); + assert_eq!(frame, b"some bytes"); + + // Decode using gzip + let frame = { + let mut item = Vec::new(); + GzEncoder::new(b"some bytes".as_slice(), Compression::best()) + .read_to_end(&mut item) + .unwrap(); + Frame::from(item) + }; + let mut codec = 
CompressionCodec::gzip(CompressionLevel::BEST); + let frame = codec.decode(frame).unwrap(); + assert_eq!(frame, b"some bytes"); + + // Decode using zlib + let frame = { + let mut item = Vec::new(); + ZlibEncoder::new(b"some bytes".as_slice(), Compression::best()) + .read_to_end(&mut item) + .unwrap(); + Frame::from(item) + }; + let mut codec = CompressionCodec::zlib(CompressionLevel::BEST); + let frame = codec.decode(frame).unwrap(); + assert_eq!(frame, b"some bytes"); + } +} diff --git a/distant-net/src/common/transport/framed/codec/encryption.rs b/distant-net/src/common/transport/framed/codec/encryption.rs new file mode 100644 index 0000000..8d5eac1 --- /dev/null +++ b/distant-net/src/common/transport/framed/codec/encryption.rs @@ -0,0 +1,255 @@ +use super::{Codec, Frame}; +use derive_more::Display; +use std::{fmt, io}; + +mod key; +pub use key::*; + +/// Represents the type of encryption for a [`EncryptionCodec`] +#[derive( + Copy, Clone, Debug, Display, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, +)] +pub enum EncryptionType { + /// ChaCha20Poly1305 variant with an extended 192-bit (24-byte) nonce + #[display(fmt = "xchacha20poly1305")] + XChaCha20Poly1305, + + /// Indicates an unknown encryption type for use in handshakes + #[display(fmt = "unknown")] + #[serde(other)] + Unknown, +} + +impl EncryptionType { + /// Generates bytes for a secret key based on the encryption type + pub fn generate_secret_key_bytes(&self) -> io::Result> { + match self { + Self::XChaCha20Poly1305 => Ok(SecretKey::<32>::generate() + .unwrap() + .into_heap_secret_key() + .unprotected_into_bytes()), + Self::Unknown => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Unknown encryption type", + )), + } + } + + /// Returns a list of all variants of the type *except* unknown. 
+ pub const fn known_variants() -> &'static [EncryptionType] { + &[EncryptionType::XChaCha20Poly1305] + } + + /// Returns true if type is unknown + pub fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } + + /// Creates a new [`EncryptionCodec`] for this type, failing if this type is unknown or the key + /// is an invalid length + pub fn new_codec(&self, key: &[u8]) -> io::Result { + EncryptionCodec::from_type_and_key(*self, key) + } +} + +/// Represents the codec that encodes & decodes frames by encrypting/decrypting them +#[derive(Clone)] +pub enum EncryptionCodec { + /// ChaCha20Poly1305 variant with an extended 192-bit (24-byte) nonce, using + /// [`XChaCha20Poly1305`] underneath + XChaCha20Poly1305 { + cipher: chacha20poly1305::XChaCha20Poly1305, + }, +} + +impl EncryptionCodec { + /// Makes a new [`EncryptionCodec`] based on the [`EncryptionType`] and `key`, returning an + /// error if the key is invalid for the encryption type or the type is unknown + pub fn from_type_and_key(ty: EncryptionType, key: &[u8]) -> io::Result { + match ty { + EncryptionType::XChaCha20Poly1305 => { + use chacha20poly1305::{KeyInit, XChaCha20Poly1305}; + let cipher = XChaCha20Poly1305::new_from_slice(key) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?; + Ok(Self::XChaCha20Poly1305 { cipher }) + } + EncryptionType::Unknown => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Encryption type is unknown", + )), + } + } + + pub fn new_xchacha20poly1305(secret_key: SecretKey32) -> EncryptionCodec { + // NOTE: This should never fail as we are enforcing the key size at compile time + Self::from_type_and_key( + EncryptionType::XChaCha20Poly1305, + secret_key.unprotected_as_bytes(), + ) + .unwrap() + } + + /// Returns the encryption type associated with the codec + pub fn ty(&self) -> EncryptionType { + match self { + Self::XChaCha20Poly1305 { ..
} => EncryptionType::XChaCha20Poly1305, + } + } + + /// Size of nonce (in bytes) associated with the encryption algorithm + pub const fn nonce_size(&self) -> usize { + match self { + // XChaCha20Poly1305 uses a 192-bit (24-byte) nonce + Self::XChaCha20Poly1305 { .. } => 24, + } + } + + /// Generates a new nonce for the encryption algorithm + fn generate_nonce_bytes(&self) -> Vec { + // NOTE: As seen in orion, with a 24-byte nonce, it's safe to generate instead of + // maintaining a stateful counter due to its size (24-byte secret key generation + // will never panic) + match self { + Self::XChaCha20Poly1305 { .. } => SecretKey::<24>::generate() + .unwrap() + .into_heap_secret_key() + .unprotected_into_bytes(), + } + } +} + +impl fmt::Debug for EncryptionCodec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EncryptionCodec") + .field("cipher", &"**OMITTED**".to_string()) + .field("nonce_size", &self.nonce_size()) + .field("ty", &self.ty().to_string()) + .finish() + } +} + +impl Codec for EncryptionCodec { + fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + let nonce_bytes = self.generate_nonce_bytes(); + + Ok(match self { + Self::XChaCha20Poly1305 { cipher } => { + use chacha20poly1305::{aead::Aead, XNonce}; + let item = frame.into_item(); + let nonce = XNonce::from_slice(&nonce_bytes); + + // Encrypt the frame's item as our ciphertext + let ciphertext = cipher + .encrypt(nonce, item.as_ref()) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?; + + // Start our frame with the nonce at the beginning + let mut frame = Frame::from(nonce_bytes); + frame.extend(ciphertext); + + frame + } + }) + } + + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + let nonce_size = self.nonce_size(); + if frame.len() <= nonce_size { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Frame cannot have length less than {}", nonce_size + 1), + )); + } + + // Grab the nonce from the front
of the frame, and then use it with the remainder + // of the frame to tease out the decrypted frame item + let item = match self { + Self::XChaCha20Poly1305 { cipher } => { + use chacha20poly1305::{aead::Aead, XNonce}; + let nonce = XNonce::from_slice(&frame.as_item()[..nonce_size]); + cipher + .decrypt(nonce, &frame.as_item()[nonce_size..]) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))? + } + }; + + Ok(Frame::from(item)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + #[test] + fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() { + let ty = EncryptionType::XChaCha20Poly1305; + let key = ty.generate_secret_key_bytes().unwrap(); + let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap(); + + let frame = codec + .encode(Frame::new(b"hello world")) + .expect("Failed to encode"); + + let nonce = &frame.as_item()[..codec.nonce_size()]; + let ciphertext = &frame.as_item()[codec.nonce_size()..]; + + // Manually build our key & cipher so we can decrypt the frame manually to ensure it is + // correct + let item = { + use chacha20poly1305::{aead::Aead, KeyInit, XChaCha20Poly1305, XNonce}; + let cipher = XChaCha20Poly1305::new_from_slice(&key).unwrap(); + cipher + .decrypt(XNonce::from_slice(nonce), ciphertext) + .expect("Failed to decrypt") + }; + assert_eq!(item, b"hello world"); + } + + #[test] + fn decode_should_fail_if_frame_length_is_smaller_than_nonce_plus_data() { + let ty = EncryptionType::XChaCha20Poly1305; + let key = ty.generate_secret_key_bytes().unwrap(); + let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap(); + + // NONCE_SIZE + 1 is minimum for frame length + let frame = Frame::from(b"a".repeat(codec.nonce_size())); + + let result = codec.decode(frame); + match result { + Err(x) if x.kind() == io::ErrorKind::InvalidData => {} + x => panic!("Unexpected result: {:?}", x), + } + } + + #[test] + fn 
decode_should_fail_if_unable_to_decrypt_frame_item() { + let ty = EncryptionType::XChaCha20Poly1305; + let key = ty.generate_secret_key_bytes().unwrap(); + let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap(); + + // NONCE_SIZE + 1 is minimum for frame length + let frame = Frame::from(b"a".repeat(codec.nonce_size() + 1)); + + let result = codec.decode(frame); + match result { + Err(x) if x.kind() == io::ErrorKind::InvalidData => {} + x => panic!("Unexpected result: {:?}", x), + } + } + + #[test] + fn decode_should_return_decrypted_frame_when_successful() { + let ty = EncryptionType::XChaCha20Poly1305; + let key = ty.generate_secret_key_bytes().unwrap(); + let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap(); + + let frame = codec + .encode(Frame::new(b"hello, world")) + .expect("Failed to encode"); + + let frame = codec.decode(frame).expect("Failed to decode"); + assert_eq!(frame, b"hello, world"); + } +} diff --git a/distant-net/src/common/transport/framed/codec/encryption/key.rs b/distant-net/src/common/transport/framed/codec/encryption/key.rs new file mode 100644 index 0000000..f1e47da --- /dev/null +++ b/distant-net/src/common/transport/framed/codec/encryption/key.rs @@ -0,0 +1,318 @@ +use derive_more::{Display, Error}; +use rand::{rngs::OsRng, RngCore}; +use std::{fmt, str::FromStr}; + +#[derive(Debug, Display, Error)] +pub struct SecretKeyError; + +impl From for std::io::Error { + fn from(_: SecretKeyError) -> Self { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + "not valid secret key format", + ) + } +} + +/// Represents a 16-byte (128-bit) secret key +pub type SecretKey16 = SecretKey<16>; + +/// Represents a 24-byte (192-bit) secret key +pub type SecretKey24 = SecretKey<24>; + +/// Represents a 32-byte (256-bit) secret key +pub type SecretKey32 = SecretKey<32>; + +/// Represents a secret key used with transport encryption and authentication +#[derive(Clone, PartialEq, Eq)] +pub struct SecretKey([u8; N]); + 
+impl fmt::Debug for SecretKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("SecretKey") + .field(&"**OMITTED**".to_string()) + .finish() + } +} + +impl Default for SecretKey { + /// Creates a new secret key of the size `N` + /// + /// ### Panic + /// + /// Will panic if `N` is less than 1 or greater than `isize::MAX` + fn default() -> Self { + Self::generate().unwrap() + } +} + +impl SecretKey { + /// Returns byte slice to the key's bytes + pub fn unprotected_as_bytes(&self) -> &[u8] { + &self.0 + } + + /// Returns reference to array of key's bytes + pub fn unprotected_as_byte_array(&self) -> &[u8; N] { + &self.0 + } + + /// Consumes the secret key and returns the array of key's bytes + pub fn unprotected_into_byte_array(self) -> [u8; N] { + self.0 + } + + /// Consumes the secret key and returns the key's bytes as a [`HeapSecretKey`] + pub fn into_heap_secret_key(self) -> HeapSecretKey { + HeapSecretKey(self.0.to_vec()) + } + + /// Returns the length of the key + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + N + } + + /// Generates a new secret key, returning success if key created or + /// failing if the desired key length is not between 1 and `isize::MAX` + pub fn generate() -> Result { + // Limitation described in https://github.com/orion-rs/orion/issues/130 + if N < 1 || N > (isize::MAX as usize) { + return Err(SecretKeyError); + } + + let mut key = [0; N]; + OsRng.fill_bytes(&mut key); + + Ok(Self(key)) + } + + /// Creates the key from the given byte slice, returning success if key created + /// or failing if the byte slice does not match the desired key length + pub fn from_slice(slice: &[u8]) -> Result { + if slice.len() != N { + return Err(SecretKeyError); + } + + let mut value = [0u8; N]; + value[..N].copy_from_slice(slice); + + Ok(Self(value)) + } +} + +impl From<[u8; N]> for SecretKey { + fn from(arr: [u8; N]) -> Self { + Self(arr) + } +} + +impl FromStr for SecretKey { + type Err = 
SecretKeyError; + + /// Parse a str of hex as an N-byte secret key + fn from_str(s: &str) -> Result { + let bytes = hex::decode(s).map_err(|_| SecretKeyError)?; + Self::from_slice(&bytes) + } +} + +impl fmt::Display for SecretKey { + /// Display an N-byte secret key as a hex string + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(self.unprotected_as_bytes())) + } +} + +/// Represents a secret key used with transport encryption and authentication that is stored on the +/// heap +#[derive(Clone, PartialEq, Eq)] +pub struct HeapSecretKey(Vec); + +impl fmt::Debug for HeapSecretKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("HeapSecretKey") + .field(&"**OMITTED**".to_string()) + .finish() + } +} + +impl HeapSecretKey { + /// Returns byte slice to the key's bytes + pub fn unprotected_as_bytes(&self) -> &[u8] { + &self.0 + } + + /// Consumes the secret key and returns the key's bytes + pub fn unprotected_into_bytes(self) -> Vec { + self.0.to_vec() + } + + /// Returns the length of the key + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.0.len() + } + + /// Generates a random key of `n` bytes in length. + /// + /// ### Note + /// + /// Will return an error if `n` < 1 or `n` > `isize::MAX`. 
+ pub fn generate(n: usize) -> Result { + // Limitation described in https://github.com/orion-rs/orion/issues/130 + if n < 1 || n > (isize::MAX as usize) { + return Err(SecretKeyError); + } + + let mut key = Vec::new(); + let mut buf = [0; 32]; + + // Continually generate a chunk of bytes and extend our key until we've reached + // the appropriate length + while key.len() < n { + OsRng.fill_bytes(&mut buf); + key.extend_from_slice(&buf[..std::cmp::min(n - key.len(), 32)]); + } + + Ok(Self(key)) + } +} + +impl From> for HeapSecretKey { + fn from(bytes: Vec) -> Self { + Self(bytes) + } +} + +impl From<[u8; N]> for HeapSecretKey { + fn from(arr: [u8; N]) -> Self { + Self::from(arr.to_vec()) + } +} + +impl From> for HeapSecretKey { + fn from(key: SecretKey) -> Self { + key.into_heap_secret_key() + } +} + +impl FromStr for HeapSecretKey { + type Err = SecretKeyError; + + /// Parse a str of hex as secret key on heap + fn from_str(s: &str) -> Result { + Ok(Self(hex::decode(s).map_err(|_| SecretKeyError)?)) + } +} + +impl fmt::Display for HeapSecretKey { + /// Display an N-byte secret key as a hex string + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(self.unprotected_as_bytes())) + } +} + +impl PartialEq<[u8; N]> for HeapSecretKey { + fn eq(&self, other: &[u8; N]) -> bool { + self.0.eq(other) + } +} + +impl PartialEq for [u8; N] { + fn eq(&self, other: &HeapSecretKey) -> bool { + other.eq(self) + } +} + +impl PartialEq for &[u8; N] { + fn eq(&self, other: &HeapSecretKey) -> bool { + other.eq(*self) + } +} + +impl PartialEq<[u8]> for HeapSecretKey { + fn eq(&self, other: &[u8]) -> bool { + self.0.eq(other) + } +} + +impl PartialEq for [u8] { + fn eq(&self, other: &HeapSecretKey) -> bool { + other.eq(self) + } +} + +impl PartialEq for &[u8] { + fn eq(&self, other: &HeapSecretKey) -> bool { + other.eq(*self) + } +} + +impl PartialEq for HeapSecretKey { + fn eq(&self, other: &String) -> bool { + self.0.eq(other.as_bytes()) + } +} + 
+impl PartialEq for String { + fn eq(&self, other: &HeapSecretKey) -> bool { + other.eq(self) + } +} + +impl PartialEq for &String { + fn eq(&self, other: &HeapSecretKey) -> bool { + other.eq(*self) + } +} + +impl PartialEq for HeapSecretKey { + fn eq(&self, other: &str) -> bool { + self.0.eq(other.as_bytes()) + } +} + +impl PartialEq for str { + fn eq(&self, other: &HeapSecretKey) -> bool { + other.eq(self) + } +} + +impl PartialEq for &str { + fn eq(&self, other: &HeapSecretKey) -> bool { + other.eq(*self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + #[test] + fn secret_key_should_be_able_to_be_generated() { + SecretKey::<0>::generate().unwrap_err(); + + let key = SecretKey::<1>::generate().unwrap(); + assert_eq!(key.len(), 1); + + // NOTE: We aren't going to validate generating isize::MAX or +1 of that size because it + // takes a lot of time to do so + let key = SecretKey::<100>::generate().unwrap(); + assert_eq!(key.len(), 100); + } + + #[test] + fn heap_secret_key_should_be_able_to_be_generated() { + HeapSecretKey::generate(0).unwrap_err(); + + let key = HeapSecretKey::generate(1).unwrap(); + assert_eq!(key.len(), 1); + + // NOTE: We aren't going to validate generating isize::MAX or +1 of that size because it + // takes a lot of time to do so + let key = HeapSecretKey::generate(100).unwrap(); + assert_eq!(key.len(), 100); + } +} diff --git a/distant-net/src/common/transport/framed/codec/plain.rs b/distant-net/src/common/transport/framed/codec/plain.rs new file mode 100644 index 0000000..d8c180b --- /dev/null +++ b/distant-net/src/common/transport/framed/codec/plain.rs @@ -0,0 +1,22 @@ +use super::{Codec, Frame}; +use std::io; + +/// Represents a codec that does not alter the frame (synonymous with "plain text") +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub struct PlainCodec; + +impl PlainCodec { + pub fn new() -> Self { + Self::default() + } +} + +impl Codec for PlainCodec { + fn encode<'a>(&mut self, frame: 
Frame<'a>) -> io::Result> { + Ok(frame) + } + + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + Ok(frame) + } +} diff --git a/distant-net/src/common/transport/framed/codec/predicate.rs b/distant-net/src/common/transport/framed/codec/predicate.rs new file mode 100644 index 0000000..e4e744c --- /dev/null +++ b/distant-net/src/common/transport/framed/codec/predicate.rs @@ -0,0 +1,180 @@ +use super::{Codec, Frame}; +use std::{io, sync::Arc}; + +/// Represents a codec that invokes one of two codecs based on the given predicate +#[derive(Debug, Default, PartialEq, Eq)] +pub struct PredicateCodec { + left: T, + right: U, + predicate: Arc

, +} + +impl PredicateCodec { + /// Creates a new predicate codec where the left codec is invoked if the predicate returns true + /// and the right codec is invoked if the predicate returns false + pub fn new(left: T, right: U, predicate: P) -> Self { + Self { + left, + right, + predicate: Arc::new(predicate), + } + } + + /// Returns reference to left codec + pub fn as_left(&self) -> &T { + &self.left + } + + /// Consumes the chain and returns the left codec + pub fn into_left(self) -> T { + self.left + } + + /// Returns reference to right codec + pub fn as_right(&self) -> &U { + &self.right + } + + /// Consumes the chain and returns the right codec + pub fn into_right(self) -> U { + self.right + } + + /// Consumes the chain and returns the left and right codecs + pub fn into_left_right(self) -> (T, U) { + (self.left, self.right) + } +} + +impl Clone for PredicateCodec +where + T: Clone, + U: Clone, +{ + fn clone(&self) -> Self { + Self { + left: self.left.clone(), + right: self.right.clone(), + predicate: Arc::clone(&self.predicate), + } + } +} + +impl Codec for PredicateCodec +where + T: Codec + Clone, + U: Codec + Clone, + P: Fn(&Frame) -> bool, +{ + fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + if (self.predicate)(&frame) { + Codec::encode(&mut self.left, frame) + } else { + Codec::encode(&mut self.right, frame) + } + } + + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + if (self.predicate)(&frame) { + Codec::decode(&mut self.left, frame) + } else { + Codec::decode(&mut self.right, frame) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + #[derive(Copy, Clone)] + struct TestCodec<'a> { + msg: &'a str, + } + + impl<'a> TestCodec<'a> { + pub fn new(msg: &'a str) -> Self { + Self { msg } + } + } + + impl Codec for TestCodec<'_> { + fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + let mut item = frame.into_item().to_vec(); + item.extend_from_slice(self.msg.as_bytes()); + Ok(Frame::from(item)) 
+ } + + fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result> { + let item = frame.into_item().to_vec(); + let frame = Frame::new(item.strip_suffix(self.msg.as_bytes()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + format!( + "Decode failed because did not end with suffix: {}", + self.msg + ), + ) + })?); + Ok(frame.into_owned()) + } + } + + #[derive(Copy, Clone)] + struct ErrCodec; + + impl Codec for ErrCodec { + fn encode<'a>(&mut self, _frame: Frame<'a>) -> io::Result> { + Err(io::Error::from(io::ErrorKind::InvalidData)) + } + + fn decode<'a>(&mut self, _frame: Frame<'a>) -> io::Result> { + Err(io::Error::from(io::ErrorKind::InvalidData)) + } + } + + #[test] + fn encode_should_invoke_left_codec_if_predicate_returns_true() { + let mut codec = PredicateCodec::new( + TestCodec::new("hello"), + TestCodec::new("world"), + |_: &Frame| true, + ); + let frame = codec.encode(Frame::new(b"some bytes")).unwrap(); + assert_eq!(frame, b"some byteshello"); + } + + #[test] + fn encode_should_invoke_right_codec_if_predicate_returns_false() { + let mut codec = PredicateCodec::new( + TestCodec::new("hello"), + TestCodec::new("world"), + |_: &Frame| false, + ); + let frame = codec.encode(Frame::new(b"some bytes")).unwrap(); + assert_eq!(frame, b"some bytesworld"); + } + + #[test] + fn decode_should_invoke_left_codec_if_predicate_returns_true() { + let mut codec = PredicateCodec::new( + TestCodec::new("hello"), + TestCodec::new("world"), + |_: &Frame| true, + ); + let frame = codec.decode(Frame::new(b"some byteshello")).unwrap(); + assert_eq!(frame, b"some bytes"); + } + + #[test] + fn decode_should_invoke_right_codec_if_predicate_returns_false() { + let mut codec = PredicateCodec::new( + TestCodec::new("hello"), + TestCodec::new("world"), + |_: &Frame| false, + ); + let frame = codec.decode(Frame::new(b"some bytesworld")).unwrap(); + assert_eq!(frame, b"some bytes"); + } +} diff --git a/distant-net/src/auth/handshake.rs 
b/distant-net/src/common/transport/framed/exchange.rs similarity index 71% rename from distant-net/src/auth/handshake.rs rename to distant-net/src/common/transport/framed/exchange.rs index b342720..4ef88fd 100644 --- a/distant-net/src/auth/handshake.rs +++ b/distant-net/src/common/transport/framed/exchange.rs @@ -1,3 +1,4 @@ +use super::SecretKey32; use p256::{ecdh::EphemeralSecret, PublicKey}; use rand::rngs::OsRng; use sha2::Sha256; @@ -9,16 +10,14 @@ pub use pkb::PublicKeyBytes; mod salt; pub use salt::Salt; -/// 32-byte key shared by handshake -pub type SharedKey = [u8; 32]; - -/// Utility to perform a handshake -pub struct Handshake { +/// Utility to support performing an exchange of public keys and salts in order to derive a shared +/// key between two separate entities +pub struct KeyExchange { secret: EphemeralSecret, salt: Salt, } -impl Default for Handshake { +impl Default for KeyExchange { // Create a new handshake instance with a secret and salt fn default() -> Self { let secret = EphemeralSecret::random(&mut OsRng); @@ -28,7 +27,7 @@ impl Default for Handshake { } } -impl Handshake { +impl KeyExchange { // Return encoded bytes of public key pub fn pk_bytes(&self) -> PublicKeyBytes { PublicKeyBytes::from(self.secret.public_key()) @@ -39,8 +38,13 @@ impl Handshake { &self.salt } - pub fn handshake(&self, public_key: PublicKeyBytes, salt: Salt) -> io::Result { - // Decode the public key of the client + /// Derives a shared secret using another key exchange's public key and salt + pub fn derive_shared_secret( + &self, + public_key: PublicKeyBytes, + salt: Salt, + ) -> io::Result { + // Decode the public key of the other side let decoded_public_key = PublicKey::try_from(public_key)?; // Produce a salt that is consistent with what the other side will do @@ -55,7 +59,7 @@ impl Handshake { // Derive a shared key (32 bytes) let mut shared_key = [0u8; 32]; match hkdf.expand(&[], &mut shared_key) { - Ok(_) => Ok(shared_key), + Ok(_) => 
Ok(SecretKey32::from(shared_key)), Err(x) => Err(io::Error::new(io::ErrorKind::InvalidData, x.to_string())), } } diff --git a/distant-net/src/auth/handshake/pkb.rs b/distant-net/src/common/transport/framed/exchange/pkb.rs similarity index 100% rename from distant-net/src/auth/handshake/pkb.rs rename to distant-net/src/common/transport/framed/exchange/pkb.rs diff --git a/distant-net/src/auth/handshake/salt.rs b/distant-net/src/common/transport/framed/exchange/salt.rs similarity index 100% rename from distant-net/src/auth/handshake/salt.rs rename to distant-net/src/common/transport/framed/exchange/salt.rs diff --git a/distant-net/src/common/transport/framed/frame.rs b/distant-net/src/common/transport/framed/frame.rs new file mode 100644 index 0000000..47ffb7f --- /dev/null +++ b/distant-net/src/common/transport/framed/frame.rs @@ -0,0 +1,343 @@ +use bytes::{Buf, BufMut, BytesMut}; +use std::{borrow::Cow, io}; + +/// Represents a frame whose lifetime is static +pub type OwnedFrame = Frame<'static>; + +/// Represents some data wrapped in a frame in order to ship it over the network. The format is +/// simple and follows `{len}{item}` where `len` is the length of the item as a `u64`. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Frame<'a> { + /// Represents the item that will be shipped across the network + item: Cow<'a, [u8]>, +} + +impl<'a> Frame<'a> { + /// Creates a new frame wrapping the `item` that will be shipped across the network + pub fn new(item: &'a [u8]) -> Self { + Self { + item: Cow::Borrowed(item), + } + } + + /// Consumes the frame and returns its underlying item. 
+ pub fn into_item(self) -> Cow<'a, [u8]> { + self.item + } +} + +impl Frame<'_> { + /// Total bytes to use as the header field denoting a frame's size + pub const HEADER_SIZE: usize = 8; + + /// Returns the len (in bytes) of the item wrapped by the frame + pub fn len(&self) -> usize { + self.item.len() + } + + /// Returns true if the frame is comprised of zero bytes + pub fn is_empty(&self) -> bool { + self.item.is_empty() + } + + /// Returns a reference to the bytes of the frame's item + pub fn as_item(&self) -> &[u8] { + &self.item + } + + /// Writes the frame to a new [`Vec`] of bytes, returning them on success + pub fn try_to_bytes(&self) -> io::Result> { + let mut bytes = BytesMut::new(); + self.write(&mut bytes)?; + Ok(bytes.to_vec()) + } + + /// Writes the frame to the end of `dst`, including the header representing the length of the + /// item as part of the written bytes + pub fn write(&self, dst: &mut BytesMut) -> io::Result<()> { + if self.item.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Empty item provided", + )); + } + + dst.reserve(Self::HEADER_SIZE + self.item.len()); + + // Add data in form of {LEN}{ITEM} + dst.put_u64((self.item.len()) as u64); + dst.put_slice(&self.item); + + Ok(()) + } + + /// Attempts to read a frame from `src`, returning `Some(Frame)` if a frame was found + /// (including the header) or `None` if the current `src` does not contain a frame + pub fn read(src: &mut BytesMut) -> io::Result> { + // First, check if we have more data than just our frame's message length + if src.len() <= Self::HEADER_SIZE { + return Ok(None); + } + + // Second, retrieve total size of our frame's message + let item_len = u64::from_be_bytes(src[..Self::HEADER_SIZE].try_into().unwrap()) as usize; + + // In the case that our item len is 0, we skip over the invalid frame + if item_len == 0 { + // Ensure we advance to remove the frame + src.advance(Self::HEADER_SIZE); + + return Err(io::Error::new( + 
io::ErrorKind::InvalidData, + "Frame's msg cannot have length of 0", + )); + } + + // Third, check if we have all data for our frame; if not, exit early + if src.len() < item_len + Self::HEADER_SIZE { + return Ok(None); + } + + // Fourth, get and return our item + let item = src[Self::HEADER_SIZE..(Self::HEADER_SIZE + item_len)].to_vec(); + + // Fifth, advance so frame is no longer kept around + src.advance(Self::HEADER_SIZE + item_len); + + Ok(Some(Frame::from(item))) + } + + /// Checks if a full frame is available from `src`, returning true if a frame was found false + /// if the current `src` does not contain a frame. Does not consume the frame. + pub fn available(src: &BytesMut) -> bool { + matches!(Frame::read(&mut src.clone()), Ok(Some(_))) + } + + /// Returns a new frame which is identical but has a lifetime tied to this frame. + pub fn as_borrowed(&self) -> Frame<'_> { + let item = match &self.item { + Cow::Borrowed(x) => x, + Cow::Owned(x) => x.as_slice(), + }; + + Frame { + item: Cow::Borrowed(item), + } + } + + /// Converts the [`Frame`] into an owned copy. + /// + /// If you construct the frame from an item with a non-static lifetime, you may run into + /// lifetime problems due to the way the struct is designed. Calling this function will ensure + /// that the returned value has a static lifetime. + /// + /// This is different from just cloning. Cloning the frame will just copy the references, and + /// thus the lifetime will remain the same. + pub fn into_owned(self) -> OwnedFrame { + Frame { + item: Cow::from(self.item.into_owned()), + } + } +} + +impl<'a> From<&'a [u8]> for Frame<'a> { + /// Consumes the byte slice and returns a [`Frame`] whose item references those bytes. + fn from(item: &'a [u8]) -> Self { + Self { + item: Cow::Borrowed(item), + } + } +} + +impl<'a, const N: usize> From<&'a [u8; N]> for Frame<'a> { + /// Consumes the byte array slice and returns a [`Frame`] whose item references those bytes. 
+ fn from(item: &'a [u8; N]) -> Self { + Self { + item: Cow::Borrowed(item), + } + } +} + +impl From<[u8; N]> for OwnedFrame { + /// Consumes an array of bytes and returns a [`Frame`] with an owned item of those bytes + /// allocated as a [`Vec`]. + fn from(item: [u8; N]) -> Self { + Self { + item: Cow::Owned(item.to_vec()), + } + } +} + +impl From> for OwnedFrame { + /// Consumes a [`Vec`] of bytes and returns a [`Frame`] with an owned item of those bytes. + fn from(item: Vec) -> Self { + Self { + item: Cow::Owned(item), + } + } +} + +impl AsRef<[u8]> for Frame<'_> { + /// Returns a reference to this [`Frame`]'s item as bytes. + fn as_ref(&self) -> &[u8] { + AsRef::as_ref(&self.item) + } +} + +impl Extend for Frame<'_> { + /// Extends the [`Frame`]'s item with the provided bytes, allocating an owned [`Vec`] + /// underneath if this frame had borrowed bytes as an item. + fn extend>(&mut self, iter: T) { + match &mut self.item { + // If we only have a borrowed item, we need to allocate it into a new vec so we can + // extend it with additional bytes + Cow::Borrowed(item) => { + let mut item = item.to_vec(); + item.extend(iter); + self.item = Cow::Owned(item); + } + + // Otherwise, if we already have an owned allocation of bytes, we just extend it + Cow::Owned(item) => { + item.extend(iter); + } + } + } +} + +impl PartialEq<[u8]> for Frame<'_> { + /// Test if [`Frame`]'s item matches the provided bytes. + fn eq(&self, item: &[u8]) -> bool { + self.item.as_ref().eq(item) + } +} + +impl<'a> PartialEq<&'a [u8]> for Frame<'_> { + /// Test if [`Frame`]'s item matches the provided bytes. + fn eq(&self, item: &&'a [u8]) -> bool { + self.item.as_ref().eq(*item) + } +} + +impl PartialEq<[u8; N]> for Frame<'_> { + /// Test if [`Frame`]'s item matches the provided bytes. + fn eq(&self, item: &[u8; N]) -> bool { + self.item.as_ref().eq(item) + } +} + +impl<'a, const N: usize> PartialEq<&'a [u8; N]> for Frame<'_> { + /// Test if [`Frame`]'s item matches the provided bytes.
+ fn eq(&self, item: &&'a [u8; N]) -> bool { + self.item.as_ref().eq(*item) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + #[test] + fn write_should_fail_when_item_is_zero_bytes() { + let frame = Frame::new(&[]); + + let mut buf = BytesMut::new(); + let result = frame.write(&mut buf); + + match result { + Err(x) if x.kind() == io::ErrorKind::InvalidInput => {} + x => panic!("Unexpected result: {:?}", x), + } + } + + #[test] + fn write_should_build_a_frame_containing_a_length_and_item() { + let frame = Frame::new(b"hello, world"); + + let mut buf = BytesMut::new(); + frame.write(&mut buf).expect("Failed to write"); + + let len = buf.get_u64() as usize; + assert_eq!(len, 12, "Wrong length writed"); + assert_eq!(buf.as_ref(), b"hello, world"); + } + + #[test] + fn read_should_return_none_if_data_smaller_than_or_equal_to_item_length_field() { + let mut buf = BytesMut::new(); + buf.put_bytes(0, Frame::HEADER_SIZE); + + let result = Frame::read(&mut buf); + assert!( + matches!(result, Ok(None)), + "Unexpected result: {:?}", + result + ); + } + + #[test] + fn read_should_return_none_if_not_enough_data_for_frame() { + let mut buf = BytesMut::new(); + buf.put_u64(0); + + let result = Frame::read(&mut buf); + assert!( + matches!(result, Ok(None)), + "Unexpected result: {:?}", + result + ); + } + + #[test] + fn read_should_fail_if_writed_item_length_is_zero() { + let mut buf = BytesMut::new(); + buf.put_u64(0); + buf.put_u8(255); + + let result = Frame::read(&mut buf); + match result { + Err(x) if x.kind() == io::ErrorKind::InvalidData => {} + x => panic!("Unexpected result: {:?}", x), + } + } + + #[test] + fn read_should_advance_src_by_frame_size_even_if_item_length_is_zero() { + let mut buf = BytesMut::new(); + buf.put_u64(0); + buf.put_bytes(0, 3); + + assert!( + Frame::read(&mut buf).is_err(), + "read unexpectedly succeeded" + ); + assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf"); + } + + #[test] + fn 
read_should_advance_src_by_frame_size_when_successful() { + // Add 3 extra bytes after a full frame + let mut buf = BytesMut::new(); + Frame::new(b"hello, world") + .write(&mut buf) + .expect("Failed to write"); + buf.put_bytes(0, 3); + + assert!(Frame::read(&mut buf).is_ok(), "read unexpectedly failed"); + assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf"); + } + + #[test] + fn read_should_return_some_byte_vec_when_successful() { + let mut buf = BytesMut::new(); + Frame::new(b"hello, world") + .write(&mut buf) + .expect("Failed to write"); + + let item = Frame::read(&mut buf) + .expect("Failed to read") + .expect("Item not properly captured"); + assert_eq!(item, b"hello, world"); + } +} diff --git a/distant-net/src/common/transport/framed/handshake.rs b/distant-net/src/common/transport/framed/handshake.rs new file mode 100644 index 0000000..5cd14e1 --- /dev/null +++ b/distant-net/src/common/transport/framed/handshake.rs @@ -0,0 +1,57 @@ +use super::{CompressionLevel, CompressionType, EncryptionType}; + +/// Definition of the handshake to perform for a transport +#[derive(Clone, Debug)] +pub enum Handshake { + /// Indicates that the handshake is being performed from the client-side + Client { + /// Preferred compression algorithm when presented options by server + preferred_compression_type: Option, + + /// Preferred compression level when presented options by server + preferred_compression_level: Option, + + /// Preferred encryption algorithm when presented options by server + preferred_encryption_type: Option, + }, + + /// Indicates that the handshake is being performed from the server-side + Server { + /// List of available compression algorithms for use between client and server + compression_types: Vec, + + /// List of available encryption algorithms for use between client and server + encryption_types: Vec, + }, +} + +impl Handshake { + /// Creates a new client handshake definition, providing defaults for the preferred compression + /// 
type, compression level, and encryption type + pub fn client() -> Self { + Self::Client { + preferred_compression_type: None, + preferred_compression_level: None, + preferred_encryption_type: Some(EncryptionType::XChaCha20Poly1305), + } + } + + /// Creates a new server handshake definition, providing defaults for the compression types and + /// encryption types by including all known variants + pub fn server() -> Self { + Self::Server { + compression_types: CompressionType::known_variants().to_vec(), + encryption_types: EncryptionType::known_variants().to_vec(), + } + } + + /// Returns true if handshake is from client-side + pub fn is_client(&self) -> bool { + matches!(self, Self::Client { .. }) + } + + /// Returns true if handshake is from server-side + pub fn is_server(&self) -> bool { + matches!(self, Self::Server { .. }) + } +} diff --git a/distant-net/src/common/transport/inmemory.rs b/distant-net/src/common/transport/inmemory.rs new file mode 100644 index 0000000..5d5ffdc --- /dev/null +++ b/distant-net/src/common/transport/inmemory.rs @@ -0,0 +1,512 @@ +use super::{Interest, Ready, Reconnectable, Transport}; +use async_trait::async_trait; +use std::{ + io, + sync::{Mutex, MutexGuard}, +}; +use tokio::sync::mpsc::{ + self, + error::{TryRecvError, TrySendError}, +}; + +/// Represents a [`Transport`] comprised of two inmemory channels +#[derive(Debug)] +pub struct InmemoryTransport { + tx: mpsc::Sender>, + rx: Mutex>>, + + /// Internal storage used when we get more data from a `try_read` than can be returned + buf: Mutex>>, +} + +impl InmemoryTransport { + /// Creates a new transport where `tx` is used to send data out of the transport during + /// [`try_write`] and `rx` is used to receive data into the transport during [`try_read`]. 
+ /// + /// [`try_read`]: Transport::try_read + /// [`try_write`]: Transport::try_write + pub fn new(tx: mpsc::Sender>, rx: mpsc::Receiver>) -> Self { + Self { + tx, + rx: Mutex::new(rx), + buf: Mutex::new(None), + } + } + + /// Returns (incoming_tx, outgoing_rx, transport) where `incoming_tx` is used to send data to + /// the transport where it will be consumed during [`try_read`] and `outgoing_rx` is used to + /// receive data from the transport when it is written using [`try_write`]. + /// + /// [`try_read`]: Transport::try_read + /// [`try_write`]: Transport::try_write + pub fn make(buffer: usize) -> (mpsc::Sender>, mpsc::Receiver>, Self) { + let (incoming_tx, incoming_rx) = mpsc::channel(buffer); + let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer); + + ( + incoming_tx, + outgoing_rx, + Self::new(outgoing_tx, incoming_rx), + ) + } + + /// Returns pair of transports that are connected such that one sends to the other and + /// vice versa + pub fn pair(buffer: usize) -> (Self, Self) { + let (tx, rx, transport) = Self::make(buffer); + (transport, Self::new(tx, rx)) + } + + /// Links two independent [`InmemoryTransport`] together by dropping their internal channels + /// and generating new ones of `buffer` capacity to connect these transports. + /// + /// ### Note + /// + /// This will drop any pre-existing data in the internal storage to avoid corruption. + pub fn link(&mut self, other: &mut InmemoryTransport, buffer: usize) { + let (incoming_tx, incoming_rx) = mpsc::channel(buffer); + let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer); + + self.buf = Mutex::new(None); + self.tx = outgoing_tx; + self.rx = Mutex::new(incoming_rx); + + other.buf = Mutex::new(None); + other.tx = incoming_tx; + other.rx = Mutex::new(outgoing_rx); + } + + /// Returns true if the read channel is closed, meaning it will no longer receive more data. 
+ /// This does not factor in data remaining in the internal buffer, meaning that this may return + /// true while the transport still has data remaining in the internal buffer. + /// + /// NOTE: Because there is no `is_closed` on the receiver, we have to actually try to + /// read from the receiver to see if it is disconnected, adding any received data + /// to our internal buffer if it is not disconnected and has data available + /// + /// Track https://github.com/tokio-rs/tokio/issues/4638 for future `is_closed` on rx + fn is_rx_closed(&self) -> bool { + match self.rx.lock().unwrap().try_recv() { + Ok(mut data) => { + let mut buf_lock = self.buf.lock().unwrap(); + + let data = match buf_lock.take() { + Some(mut existing) => { + existing.append(&mut data); + existing + } + None => data, + }; + + *buf_lock = Some(data); + + false + } + Err(TryRecvError::Empty) => false, + Err(TryRecvError::Disconnected) => true, + } + } +} + +#[async_trait] +impl Reconnectable for InmemoryTransport { + /// Once the underlying channels have closed, there is no way for this transport to + /// re-establish those channels; therefore, reconnecting will fail with + /// [`ErrorKind::ConnectionRefused`] if either underlying channel has closed. 
+ /// + /// [`ErrorKind::ConnectionRefused`]: io::ErrorKind::ConnectionRefused + async fn reconnect(&mut self) -> io::Result<()> { + if self.tx.is_closed() || self.is_rx_closed() { + Err(io::Error::from(io::ErrorKind::ConnectionRefused)) + } else { + Ok(()) + } + } +} + +#[async_trait] +impl Transport for InmemoryTransport { + fn try_read(&self, buf: &mut [u8]) -> io::Result { + // Lock our internal storage to ensure that nothing else mutates it for the lifetime of + // this call as we want to make sure that data is read and stored in order + let mut buf_lock = self.buf.lock().unwrap(); + + // Check if we have data in our internal buffer, and if so feed it into the outgoing buf + if let Some(data) = buf_lock.take() { + return Ok(copy_and_store(buf_lock, data, buf)); + } + + match self.rx.lock().unwrap().try_recv() { + Ok(data) => Ok(copy_and_store(buf_lock, data, buf)), + Err(TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)), + Err(TryRecvError::Disconnected) => Ok(0), + } + } + + fn try_write(&self, buf: &[u8]) -> io::Result { + match self.tx.try_send(buf.to_vec()) { + Ok(()) => Ok(buf.len()), + Err(TrySendError::Full(_)) => Err(io::Error::from(io::ErrorKind::WouldBlock)), + Err(TrySendError::Closed(_)) => Ok(0), + } + } + + async fn ready(&self, interest: Interest) -> io::Result { + let mut status = Ready::EMPTY; + + if interest.is_readable() { + // TODO: Replace `self.is_rx_closed()` with `self.rx.is_closed()` once the tokio issue + // is resolved that adds `is_closed` to the `mpsc::Receiver` + // + // See https://github.com/tokio-rs/tokio/issues/4638 + status |= if self.is_rx_closed() && self.buf.lock().unwrap().is_none() { + Ready::READ_CLOSED + } else { + Ready::READABLE + }; + } + + if interest.is_writable() { + status |= if self.tx.is_closed() { + Ready::WRITE_CLOSED + } else { + Ready::WRITABLE + }; + } + + Ok(status) + } +} + +/// Copies `data` into `out`, storing any overflow from `data` into the storage pointed to by the +/// mutex 
`buf_lock` +fn copy_and_store( + mut buf_lock: MutexGuard>>, + mut data: Vec, + out: &mut [u8], +) -> usize { + // NOTE: We can get data that is larger than the destination buf; so, + // we store as much as we can and queue up the rest in our temporary + // storage for future retrievals + if data.len() > out.len() { + let n = out.len(); + out.copy_from_slice(&data[..n]); + *buf_lock = Some(data.split_off(n)); + n + } else { + let n = data.len(); + out[..n].copy_from_slice(&data); + n + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::TransportExt; + use test_log::test; + + #[test] + fn is_rx_closed_should_properly_reflect_if_internal_rx_channel_is_closed() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (read_tx, read_rx) = mpsc::channel(1); + + let transport = InmemoryTransport::new(write_tx, read_rx); + + // Not closed when the channel is empty + assert!(!transport.is_rx_closed()); + + read_tx.try_send(b"some bytes".to_vec()).unwrap(); + + // Not closed when the channel has data (will queue up data) + assert!(!transport.is_rx_closed()); + assert_eq!( + transport.buf.lock().unwrap().as_deref().unwrap(), + b"some bytes" + ); + + // Queue up one more set of bytes and then close the channel + read_tx.try_send(b"more".to_vec()).unwrap(); + drop(read_tx); + + // Not closed when channel has closed but has something remaining in the queue + assert!(!transport.is_rx_closed()); + assert_eq!( + transport.buf.lock().unwrap().as_deref().unwrap(), + b"some bytesmore" + ); + + // Closed once there is nothing left in the channel and it has closed + assert!(transport.is_rx_closed()); + assert_eq!( + transport.buf.lock().unwrap().as_deref().unwrap(), + b"some bytesmore" + ); + } + + #[test] + fn try_read_should_succeed_if_able_to_read_entire_data_through_channel() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (read_tx, read_rx) = mpsc::channel(1); + + let transport = InmemoryTransport::new(write_tx, read_rx); + + // Queue up some data to be 
read + read_tx.try_send(b"some bytes".to_vec()).unwrap(); + + let mut buf = [0; 10]; + assert_eq!(transport.try_read(&mut buf).unwrap(), 10); + assert_eq!(&buf[..10], b"some bytes"); + } + + #[test] + fn try_read_should_succeed_if_reading_cached_data_from_previous_read() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (read_tx, read_rx) = mpsc::channel(1); + + let transport = InmemoryTransport::new(write_tx, read_rx); + + // Queue up some data to be read + read_tx.try_send(b"some bytes".to_vec()).unwrap(); + + let mut buf = [0; 5]; + assert_eq!(transport.try_read(&mut buf).unwrap(), 5); + assert_eq!(&buf[..5], b"some "); + + // Queue up some new data to be read (previous data already consumed) + read_tx.try_send(b"more".to_vec()).unwrap(); + + let mut buf = [0; 2]; + assert_eq!(transport.try_read(&mut buf).unwrap(), 2); + assert_eq!(&buf[..2], b"by"); + + // Inmemory still separates buffered bytes from next channel recv() + let mut buf = [0; 5]; + assert_eq!(transport.try_read(&mut buf).unwrap(), 3); + assert_eq!(&buf[..3], b"tes"); + + let mut buf = [0; 5]; + assert_eq!(transport.try_read(&mut buf).unwrap(), 4); + assert_eq!(&buf[..4], b"more"); + } + + #[test] + fn try_read_should_fail_with_would_block_if_channel_is_empty() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (_read_tx, read_rx) = mpsc::channel(1); + + let transport = InmemoryTransport::new(write_tx, read_rx); + + assert_eq!( + transport.try_read(&mut [0; 5]).unwrap_err().kind(), + io::ErrorKind::WouldBlock + ); + } + + #[test] + fn try_read_should_succeed_with_zero_bytes_read_if_channel_closed() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (read_tx, read_rx) = mpsc::channel(1); + + // Drop to close the read channel + drop(read_tx); + + let transport = InmemoryTransport::new(write_tx, read_rx); + assert_eq!(transport.try_read(&mut [0; 5]).unwrap(), 0); + } + + #[test] + fn try_write_should_succeed_if_able_to_send_data_through_channel() { + let (write_tx, _write_rx) = 
mpsc::channel(1); + let (_read_tx, read_rx) = mpsc::channel(1); + + let transport = InmemoryTransport::new(write_tx, read_rx); + + let value = b"some bytes"; + assert_eq!(transport.try_write(value).unwrap(), value.len()); + } + + #[test] + fn try_write_should_fail_with_would_block_if_channel_capacity_has_been_reached() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (_read_tx, read_rx) = mpsc::channel(1); + + let transport = InmemoryTransport::new(write_tx, read_rx); + + // Fill up the channel + transport + .try_write(b"some bytes") + .expect("Failed to fill channel"); + + assert_eq!( + transport.try_write(b"some bytes").unwrap_err().kind(), + io::ErrorKind::WouldBlock + ); + } + + #[test] + fn try_write_should_succeed_with_zero_bytes_written_if_channel_closed() { + let (write_tx, write_rx) = mpsc::channel(1); + let (_read_tx, read_rx) = mpsc::channel(1); + + // Drop to close the write channel + drop(write_rx); + + let transport = InmemoryTransport::new(write_tx, read_rx); + assert_eq!(transport.try_write(b"some bytes").unwrap(), 0); + } + + #[test(tokio::test)] + async fn reconnect_should_fail_if_read_channel_closed() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (_, read_rx) = mpsc::channel(1); + let mut transport = InmemoryTransport::new(write_tx, read_rx); + + assert_eq!( + transport.reconnect().await.unwrap_err().kind(), + io::ErrorKind::ConnectionRefused + ); + } + + #[test(tokio::test)] + async fn reconnect_should_fail_if_write_channel_closed() { + let (write_tx, _) = mpsc::channel(1); + let (_read_tx, read_rx) = mpsc::channel(1); + let mut transport = InmemoryTransport::new(write_tx, read_rx); + + assert_eq!( + transport.reconnect().await.unwrap_err().kind(), + io::ErrorKind::ConnectionRefused + ); + } + + #[test(tokio::test)] + async fn reconnect_should_succeed_if_both_channels_open() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (_read_tx, read_rx) = mpsc::channel(1); + let mut transport = InmemoryTransport::new(write_tx, 
read_rx); + + transport.reconnect().await.unwrap(); + } + + #[test(tokio::test)] + async fn ready_should_report_read_closed_if_channel_closed_and_internal_buf_empty() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (read_tx, read_rx) = mpsc::channel(1); + + // Drop to close the read channel + drop(read_tx); + + let transport = InmemoryTransport::new(write_tx, read_rx); + let ready = transport.ready(Interest::READABLE).await.unwrap(); + assert!(ready.is_readable()); + assert!(ready.is_read_closed()); + } + + #[test(tokio::test)] + async fn ready_should_report_readable_if_channel_not_closed() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (_read_tx, read_rx) = mpsc::channel(1); + + let transport = InmemoryTransport::new(write_tx, read_rx); + let ready = transport.ready(Interest::READABLE).await.unwrap(); + assert!(ready.is_readable()); + assert!(!ready.is_read_closed()); + } + + #[test(tokio::test)] + async fn ready_should_report_readable_if_internal_buf_not_empty() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (read_tx, read_rx) = mpsc::channel(1); + + // Drop to close the read channel + drop(read_tx); + + let transport = InmemoryTransport::new(write_tx, read_rx); + + // Assign some data to our buffer to ensure that we test this condition + *transport.buf.lock().unwrap() = Some(vec![1]); + + let ready = transport.ready(Interest::READABLE).await.unwrap(); + assert!(ready.is_readable()); + assert!(!ready.is_read_closed()); + } + + #[test(tokio::test)] + async fn ready_should_report_writable_if_channel_not_closed() { + let (write_tx, _write_rx) = mpsc::channel(1); + let (_read_tx, read_rx) = mpsc::channel(1); + + let transport = InmemoryTransport::new(write_tx, read_rx); + let ready = transport.ready(Interest::WRITABLE).await.unwrap(); + assert!(ready.is_writable()); + assert!(!ready.is_write_closed()); + } + + #[test(tokio::test)] + async fn ready_should_report_write_closed_if_channel_closed() { + let (write_tx, write_rx) = mpsc::channel(1); 
+ let (_read_tx, read_rx) = mpsc::channel(1); + + // Drop to close the write channel + drop(write_rx); + + let transport = InmemoryTransport::new(write_tx, read_rx); + let ready = transport.ready(Interest::WRITABLE).await.unwrap(); + assert!(ready.is_writable()); + assert!(ready.is_write_closed()); + } + + #[test(tokio::test)] + async fn make_should_return_sender_that_sends_data_to_transport() { + let (tx, _, transport) = InmemoryTransport::make(3); + + tx.send(b"test msg 1".to_vec()).await.unwrap(); + tx.send(b"test msg 2".to_vec()).await.unwrap(); + tx.send(b"test msg 3".to_vec()).await.unwrap(); + + // Should get data matching a singular message + let mut buf = [0; 256]; + let len = transport.try_read(&mut buf).unwrap(); + assert_eq!(&buf[..len], b"test msg 1"); + + // Next call would get the second message + let len = transport.try_read(&mut buf).unwrap(); + assert_eq!(&buf[..len], b"test msg 2"); + + // When the last of the senders is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no more data + drop(tx); + + let len = transport.try_read(&mut buf).unwrap(); + assert_eq!(&buf[..len], b"test msg 3"); + + let len = transport.try_read(&mut buf).unwrap(); + assert_eq!(len, 0, "Unexpectedly got more data"); + } + + #[test(tokio::test)] + async fn make_should_return_receiver_that_receives_data_from_transport() { + let (_, mut rx, transport) = InmemoryTransport::make(3); + + transport.write_all(b"test msg 1").await.unwrap(); + transport.write_all(b"test msg 2").await.unwrap(); + transport.write_all(b"test msg 3").await.unwrap(); + + // Should get data matching a singular message + assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec())); + + // Next call would get the second message + assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec())); + + // When the transport is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no 
more data + drop(transport); + + assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec())); + + assert_eq!(rx.recv().await, None, "Unexpectedly got more data"); + } +} diff --git a/distant-net/src/common/transport/tcp.rs b/distant-net/src/common/transport/tcp.rs new file mode 100644 index 0000000..86c833c --- /dev/null +++ b/distant-net/src/common/transport/tcp.rs @@ -0,0 +1,222 @@ +use super::{Interest, Ready, Reconnectable, Transport}; +use async_trait::async_trait; +use std::{fmt, io, net::IpAddr}; +use tokio::net::{TcpStream, ToSocketAddrs}; + +/// Represents a [`Transport`] that leverages a TCP stream +pub struct TcpTransport { + pub(crate) addr: IpAddr, + pub(crate) port: u16, + pub(crate) inner: TcpStream, +} + +impl TcpTransport { + /// Creates a new stream by connecting to a remote machine at the specified + /// IP address and port + pub async fn connect(addrs: impl ToSocketAddrs) -> io::Result { + let stream = TcpStream::connect(addrs).await?; + let addr = stream.peer_addr()?; + Ok(Self { + addr: addr.ip(), + port: addr.port(), + inner: stream, + }) + } + + /// Returns the IP address that the stream is connected to + pub fn ip_addr(&self) -> IpAddr { + self.addr + } + + /// Returns the port that the stream is connected to + pub fn port(&self) -> u16 { + self.port + } +} + +impl fmt::Debug for TcpTransport { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TcpTransport") + .field("addr", &self.addr) + .field("port", &self.port) + .finish() + } +} + +#[async_trait] +impl Reconnectable for TcpTransport { + async fn reconnect(&mut self) -> io::Result<()> { + self.inner = TcpStream::connect((self.addr, self.port)).await?; + Ok(()) + } +} + +#[async_trait] +impl Transport for TcpTransport { + fn try_read(&self, buf: &mut [u8]) -> io::Result { + self.inner.try_read(buf) + } + + fn try_write(&self, buf: &[u8]) -> io::Result { + self.inner.try_write(buf) + } + + async fn ready(&self, interest: Interest) -> io::Result { + 
self.inner.ready(interest).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::TransportExt; + use std::net::{Ipv6Addr, SocketAddr}; + use test_log::test; + use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle}; + + async fn find_ephemeral_addr() -> SocketAddr { + // Start a listener on a distinct port, get its port, and kill it + // NOTE: This is a race condition as something else could bind to + // this port inbetween us killing it and us attempting to + // connect to it. We're willing to take that chance + let addr = IpAddr::V6(Ipv6Addr::LOCALHOST); + + let listener = TcpListener::bind((addr, 0)) + .await + .expect("Failed to bind on an ephemeral port"); + + let port = listener + .local_addr() + .expect("Failed to look up ephemeral port") + .port(); + + SocketAddr::from((addr, port)) + } + + async fn start_and_run_server(tx: oneshot::Sender) -> io::Result<()> { + let addr = find_ephemeral_addr().await; + + // Start listening at the distinct address + let listener = TcpListener::bind(addr).await?; + + // Send the address back to our main test thread + tx.send(addr) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string()))?; + + run_server(listener).await + } + + async fn run_server(listener: TcpListener) -> io::Result<()> { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + // Get the connection + let (mut conn, _) = listener.accept().await?; + + // Send some data to the connection (10 bytes) + conn.write_all(b"hello conn").await?; + + // Receive some data from the connection (12 bytes) + let mut buf: [u8; 12] = [0; 12]; + let _ = conn.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server"); + + Ok(()) + } + + #[test(tokio::test)] + async fn should_fail_to_connect_if_nothing_listening() { + let addr = find_ephemeral_addr().await; + + // Now this should fail as we've stopped what was listening + TcpTransport::connect(addr).await.expect_err(&format!( + "Unexpectedly succeeded in connecting to ghost address: 
{}", + addr + )); + } + + #[test(tokio::test)] + async fn should_be_able_to_read_and_write_data() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for a connection, send data, + // and receive data that it will return in the task + let task: JoinHandle> = tokio::spawn(start_and_run_server(tx)); + + // Wait for the server to be ready + let addr = rx.await.expect("Failed to get server server address"); + + // Connect to the socket, send some bytes, and get some bytes + let mut buf: [u8; 10] = [0; 10]; + + let conn = TcpTransport::connect(&addr) + .await + .expect("Conn failed to connect"); + + // Continually read until we get all of the data + conn.read_exact(&mut buf) + .await + .expect("Conn failed to read"); + assert_eq!(&buf, b"hello conn"); + + conn.write_all(b"hello server") + .await + .expect("Conn failed to write"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Server task failed unexpectedly"); + } + + #[test(tokio::test)] + async fn should_be_able_to_reconnect() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for a connection, send data, + // and receive data that it will return in the task + let task: JoinHandle> = tokio::spawn(start_and_run_server(tx)); + + // Wait for the server to be ready + let addr = rx.await.expect("Failed to get server server address"); + + // Connect to the server + let mut conn = TcpTransport::connect(&addr) + .await + .expect("Conn failed to connect"); + + // Kill the server to make the connection fail + task.abort(); + + // Verify the connection fails by trying to read from it (should get connection reset) + conn.readable() + .await + .expect("Failed to wait for conn to be readable"); + let res = conn.read_exact(&mut [0; 10]).await; + assert!( + matches!(res, Ok(0) | Err(_)), + "Unexpected read result: {res:?}" + ); + + // Restart the server + let task: JoinHandle> = tokio::spawn(run_server( + TcpListener::bind(addr) + .await + 
.expect("Failed to rebind server"), + )); + + // Reconnect to the socket, send some bytes, and get some bytes + let mut buf: [u8; 10] = [0; 10]; + conn.reconnect().await.expect("Conn failed to reconnect"); + + // Continually read until we get all of the data + conn.read_exact(&mut buf) + .await + .expect("Conn failed to read"); + assert_eq!(&buf, b"hello conn"); + + conn.write_all(b"hello server") + .await + .expect("Conn failed to write"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Server task failed unexpectedly"); + } +} diff --git a/distant-net/src/common/transport/test.rs b/distant-net/src/common/transport/test.rs new file mode 100644 index 0000000..d5ac10a --- /dev/null +++ b/distant-net/src/common/transport/test.rs @@ -0,0 +1,48 @@ +use super::{Interest, Ready, Reconnectable, Transport}; +use async_trait::async_trait; +use std::io; + +pub type TryReadFn = Box io::Result + Send + Sync>; +pub type TryWriteFn = Box io::Result + Send + Sync>; +pub type ReadyFn = Box io::Result + Send + Sync>; +pub type ReconnectFn = Box io::Result<()> + Send + Sync>; + +pub struct TestTransport { + pub f_try_read: TryReadFn, + pub f_try_write: TryWriteFn, + pub f_ready: ReadyFn, + pub f_reconnect: ReconnectFn, +} + +impl Default for TestTransport { + fn default() -> Self { + Self { + f_try_read: Box::new(|_| unimplemented!()), + f_try_write: Box::new(|_| unimplemented!()), + f_ready: Box::new(|_| unimplemented!()), + f_reconnect: Box::new(|| unimplemented!()), + } + } +} + +#[async_trait] +impl Reconnectable for TestTransport { + async fn reconnect(&mut self) -> io::Result<()> { + (self.f_reconnect)() + } +} + +#[async_trait] +impl Transport for TestTransport { + fn try_read(&self, buf: &mut [u8]) -> io::Result { + (self.f_try_read)(buf) + } + + fn try_write(&self, buf: &[u8]) -> io::Result { + (self.f_try_write)(buf) + } + + async fn ready(&self, interest: Interest) -> io::Result { + (self.f_ready)(interest) + } +} diff --git 
a/distant-net/src/common/transport/unix.rs b/distant-net/src/common/transport/unix.rs new file mode 100644 index 0000000..0fe496f --- /dev/null +++ b/distant-net/src/common/transport/unix.rs @@ -0,0 +1,216 @@ +use super::{Interest, Ready, Reconnectable, Transport}; +use async_trait::async_trait; +use std::{ + fmt, io, + path::{Path, PathBuf}, +}; +use tokio::net::UnixStream; + +/// Represents a [`Transport`] that leverages a Unix socket +pub struct UnixSocketTransport { + pub(crate) path: PathBuf, + pub(crate) inner: UnixStream, +} + +impl UnixSocketTransport { + /// Creates a new stream by connecting to the specified path + pub async fn connect(path: impl AsRef) -> io::Result { + let stream = UnixStream::connect(path.as_ref()).await?; + Ok(Self { + path: path.as_ref().to_path_buf(), + inner: stream, + }) + } + + /// Returns the path to the socket + pub fn path(&self) -> &Path { + &self.path + } +} + +impl fmt::Debug for UnixSocketTransport { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UnixSocketTransport") + .field("path", &self.path) + .finish() + } +} + +#[async_trait] +impl Reconnectable for UnixSocketTransport { + async fn reconnect(&mut self) -> io::Result<()> { + self.inner = UnixStream::connect(self.path.as_path()).await?; + Ok(()) + } +} + +#[async_trait] +impl Transport for UnixSocketTransport { + fn try_read(&self, buf: &mut [u8]) -> io::Result { + self.inner.try_read(buf) + } + + fn try_write(&self, buf: &[u8]) -> io::Result { + self.inner.try_write(buf) + } + + async fn ready(&self, interest: Interest) -> io::Result { + self.inner.ready(interest).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::TransportExt; + use tempfile::NamedTempFile; + use test_log::test; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::UnixListener, + sync::oneshot, + task::JoinHandle, + }; + + async fn start_and_run_server(tx: oneshot::Sender) -> io::Result<()> { + // Generate a socket path and delete the 
file after so there is nothing there + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .path() + .to_path_buf(); + + // Start listening at the socket path + let listener = UnixListener::bind(&path)?; + + // Send the path back to our main test thread + tx.send(path) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?; + + run_server(listener).await + } + + async fn run_server(listener: UnixListener) -> io::Result<()> { + // Get the connection + let (mut conn, _) = listener.accept().await?; + + // Send some data to the connection (10 bytes) + conn.write_all(b"hello conn").await?; + + // Receive some data from the connection (12 bytes) + let mut buf: [u8; 12] = [0; 12]; + let _ = conn.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server"); + + Ok(()) + } + + #[test(tokio::test)] + async fn should_fail_to_connect_if_socket_does_not_exist() { + // Generate a socket path and delete the file after so there is nothing there + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .path() + .to_path_buf(); + + // Now this should fail as we're already bound to the name + UnixSocketTransport::connect(&path) + .await + .expect_err("Unexpectedly succeeded in connecting to missing socket"); + } + + #[test(tokio::test)] + async fn should_fail_to_connect_if_path_is_not_a_socket() { + // Generate a regular file + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .into_temp_path(); + + // Now this should fail as this file is not a socket + UnixSocketTransport::connect(&path) + .await + .expect_err("Unexpectedly succeeded in connecting to regular file"); + } + + #[test(tokio::test)] + async fn should_be_able_to_read_and_write_data() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for a connection, send data, + // and receive data that it will return in the task + let task: JoinHandle> = tokio::spawn(start_and_run_server(tx)); + + // Wait for the 
server to be ready + let path = rx.await.expect("Failed to get server socket path"); + + // Connect to the socket, send some bytes, and get some bytes + let mut buf: [u8; 10] = [0; 10]; + + let conn = UnixSocketTransport::connect(&path) + .await + .expect("Conn failed to connect"); + conn.read_exact(&mut buf) + .await + .expect("Conn failed to read"); + assert_eq!(&buf, b"hello conn"); + + conn.write_all(b"hello server") + .await + .expect("Conn failed to write"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Server task failed unexpectedly"); + } + + #[test(tokio::test)] + async fn should_be_able_to_reconnect() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for a connection, send data, + // and receive data that it will return in the task + let task: JoinHandle> = tokio::spawn(start_and_run_server(tx)); + + // Wait for the server to be ready + let path = rx.await.expect("Failed to get server socket path"); + + // Connect to the server + let mut conn = UnixSocketTransport::connect(&path) + .await + .expect("Conn failed to connect"); + + // Kill the server to make the connection fail + task.abort(); + + // Verify the connection fails by trying to read from it (should get connection reset) + conn.readable() + .await + .expect("Failed to wait for conn to be readable"); + let res = conn.read_exact(&mut [0; 10]).await; + assert!( + matches!(res, Ok(0) | Err(_)), + "Unexpected read result: {res:?}" + ); + + // Restart the server (need to remove the socket file) + let _ = tokio::fs::remove_file(&path).await; + let task: JoinHandle> = tokio::spawn(run_server( + UnixListener::bind(&path).expect("Failed to rebind server"), + )); + + // Reconnect to the socket, send some bytes, and get some bytes + let mut buf: [u8; 10] = [0; 10]; + conn.reconnect().await.expect("Conn failed to reconnect"); + + // Continually read until we get all of the data + conn.read_exact(&mut buf) + .await + .expect("Conn failed to 
read"); + assert_eq!(&buf, b"hello conn"); + + conn.write_all(b"hello server") + .await + .expect("Conn failed to write"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Server task failed unexpectedly"); + } +} diff --git a/distant-net/src/common/transport/windows.rs b/distant-net/src/common/transport/windows.rs new file mode 100644 index 0000000..7bd902d --- /dev/null +++ b/distant-net/src/common/transport/windows.rs @@ -0,0 +1,186 @@ +use super::{Interest, Ready, Reconnectable, Transport}; +use async_trait::async_trait; +use std::{ + ffi::{OsStr, OsString}, + fmt, io, +}; + +mod pipe; +pub use pipe::NamedPipe; + +/// Represents a [`Transport`] that leverages a named Windows pipe (client or server) +pub struct WindowsPipeTransport { + pub(crate) addr: OsString, + pub(crate) inner: NamedPipe, +} + +impl WindowsPipeTransport { + /// Establishes a connection to the pipe with the specified name, using the + /// name for a local pipe address in the form of `\\.\pipe\my_pipe_name` where + /// `my_pipe_name` is provided to this function + pub async fn connect_local(name: impl AsRef) -> io::Result { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + Self::connect(addr).await + } + + /// Establishes a connection to the pipe at the specified address + /// + /// Address may be something like `\.\pipe\my_pipe_name` + pub async fn connect(addr: impl Into) -> io::Result { + let addr = addr.into(); + let inner = NamedPipe::connect_as_client(&addr).await?; + + Ok(Self { addr, inner }) + } + + /// Returns the addr that the listener is bound to + pub fn addr(&self) -> &OsStr { + &self.addr + } +} + +impl fmt::Debug for WindowsPipeTransport { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WindowsPipeTransport") + .field("addr", &self.addr) + .finish() + } +} + +#[async_trait] +impl Reconnectable for WindowsPipeTransport { + async fn reconnect(&mut self) -> io::Result<()> { + // We 
cannot reconnect from server-side + if self.inner.is_server() { + return Err(io::Error::from(io::ErrorKind::Unsupported)); + } + + self.inner = NamedPipe::connect_as_client(&self.addr).await?; + Ok(()) + } +} + +#[async_trait] +impl Transport for WindowsPipeTransport { + fn try_read(&self, buf: &mut [u8]) -> io::Result { + match &self.inner { + NamedPipe::Client(x) => x.try_read(buf), + NamedPipe::Server(x) => x.try_read(buf), + } + } + + fn try_write(&self, buf: &[u8]) -> io::Result { + match &self.inner { + NamedPipe::Client(x) => x.try_write(buf), + NamedPipe::Server(x) => x.try_write(buf), + } + } + + async fn ready(&self, interest: Interest) -> io::Result { + match &self.inner { + NamedPipe::Client(x) => x.ready(interest).await, + NamedPipe::Server(x) => x.ready(interest).await, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::TransportExt; + use test_log::test; + use tokio::{ + net::windows::named_pipe::{NamedPipeServer, ServerOptions}, + sync::oneshot, + task::JoinHandle, + }; + + async fn start_and_run_server(tx: oneshot::Sender) -> io::Result<()> { + let pipe = start_server(tx).await?; + run_server(pipe).await + } + + async fn start_server(tx: oneshot::Sender) -> io::Result { + // Generate a pipe address (not just a name) + let addr = format!(r"\\.\pipe\test_pipe_{}", rand::random::()); + + // Listen at the pipe + let pipe = ServerOptions::new() + .first_pipe_instance(true) + .create(&addr)?; + + // Send the address back to our main test thread + tx.send(addr) + .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; + + Ok(pipe) + } + + async fn run_server(pipe: NamedPipeServer) -> io::Result<()> { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + // Get the connection + let mut conn = { + pipe.connect().await?; + pipe + }; + + // Send some data to the connection (10 bytes) + conn.write_all(b"hello conn").await?; + + // Receive some data from the connection (12 bytes) + let mut buf: [u8; 12] = [0; 12]; + let _ = 
conn.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello server"); + + Ok(()) + } + + #[test(tokio::test)] + async fn should_fail_to_connect_if_pipe_does_not_exist() { + // Generate a pipe name + let name = format!("test_pipe_{}", rand::random::()); + + // Now this should fail as we're already bound to the name + WindowsPipeTransport::connect_local(&name) + .await + .expect_err("Unexpectedly succeeded in connecting to missing pipe"); + } + + #[test(tokio::test)] + async fn should_be_able_to_read_and_write_data() { + let (tx, rx) = oneshot::channel(); + + // Spawn a task that will wait for a connection, send data, + // and receive data that it will return in the task + let task: JoinHandle> = tokio::spawn(start_and_run_server(tx)); + + // Wait for the server to be ready + let address = rx.await.expect("Failed to get server address"); + + // Connect to the pipe, send some bytes, and get some bytes + let mut buf: [u8; 10] = [0; 10]; + + let conn = WindowsPipeTransport::connect(&address) + .await + .expect("Conn failed to connect"); + conn.read_exact(&mut buf) + .await + .expect("Conn failed to read"); + assert_eq!(&buf, b"hello conn"); + + conn.write_all(b"hello server") + .await + .expect("Conn failed to write"); + + // Verify that the task has completed by waiting on it + let _ = task.await.expect("Server task failed unexpectedly"); + } + + #[test(tokio::test)] + #[ignore] + async fn should_be_able_to_reconnect() { + todo!(); + } +} diff --git a/distant-net/src/common/transport/windows/pipe.rs b/distant-net/src/common/transport/windows/pipe.rs new file mode 100644 index 0000000..331bfad --- /dev/null +++ b/distant-net/src/common/transport/windows/pipe.rs @@ -0,0 +1,92 @@ +use derive_more::{From, TryInto}; +use std::{ffi::OsStr, io, time::Duration}; +use tokio::net::windows::named_pipe::{ClientOptions, NamedPipeClient, NamedPipeServer}; + +// Equivalent to winapi::shared::winerror::ERROR_PIPE_BUSY +// DWORD -> c_uLong -> u32 +const ERROR_PIPE_BUSY: u32 = 231; + 
+// Time between attempts to connect to a busy pipe +const BUSY_PIPE_SLEEP_DURATION: Duration = Duration::from_millis(50); + +/// Represents a named pipe from either a client or server perspective +#[derive(From, TryInto)] +pub enum NamedPipe { + Client(NamedPipeClient), + Server(NamedPipeServer), +} + +impl NamedPipe { + /// Returns true if the underlying named pipe is a client named pipe + pub fn is_client(&self) -> bool { + matches!(self, Self::Client(_)) + } + + /// Returns a reference to the underlying named client pipe + pub fn as_client(&self) -> Option<&NamedPipeClient> { + match self { + Self::Client(x) => Some(x), + _ => None, + } + } + + /// Returns a mutable reference to the underlying named client pipe + pub fn as_mut_client(&mut self) -> Option<&mut NamedPipeClient> { + match self { + Self::Client(x) => Some(x), + _ => None, + } + } + + /// Consumes and returns the underlying named client pipe + pub fn into_client(self) -> Option { + match self { + Self::Client(x) => Some(x), + _ => None, + } + } + + /// Returns true if the underlying named pipe is a server named pipe + pub fn is_server(&self) -> bool { + matches!(self, Self::Server(_)) + } + + /// Returns a reference to the underlying named server pipe + pub fn as_server(&self) -> Option<&NamedPipeServer> { + match self { + Self::Server(x) => Some(x), + _ => None, + } + } + + /// Returns a mutable reference to the underlying named server pipe + pub fn as_mut_server(&mut self) -> Option<&mut NamedPipeServer> { + match self { + Self::Server(x) => Some(x), + _ => None, + } + } + + /// Consumes and returns the underlying named server pipe + pub fn into_server(self) -> Option { + match self { + Self::Server(x) => Some(x), + _ => None, + } + } + + /// Attempts to connect as a client pipe + pub(super) async fn connect_as_client(addr: &OsStr) -> io::Result { + let pipe = loop { + match ClientOptions::new().open(addr) { + Ok(client) => break client, + Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) 
=> (), + Err(e) => return Err(e), + } + + tokio::time::sleep(BUSY_PIPE_SLEEP_DURATION).await; + }; + + Ok(NamedPipe::from(pipe)) + } +} diff --git a/distant-net/src/utils.rs b/distant-net/src/common/utils.rs similarity index 73% rename from distant-net/src/utils.rs rename to distant-net/src/common/utils.rs index d4ef150..413a1b8 100644 --- a/distant-net/src/utils.rs +++ b/distant-net/src/common/utils.rs @@ -1,5 +1,9 @@ -use serde::{de::DeserializeOwned, Serialize}; -use std::{future::Future, io, time::Duration}; +use serde::{ + de::{DeserializeOwned, Deserializer, Error as SerdeError, Visitor}, + ser::Serializer, + Serialize, +}; +use std::{fmt, future::Future, io, marker::PhantomData, str::FromStr, time::Duration}; use tokio::{sync::mpsc, task::JoinHandle}; pub fn serialize_to_vec(value: &T) -> io::Result> { @@ -20,6 +24,46 @@ pub fn deserialize_from_slice(slice: &[u8]) -> io::Result(deserializer: D) -> Result +where + D: Deserializer<'de>, + T: FromStr, + T::Err: fmt::Display, +{ + struct Helper(PhantomData); + + impl<'de, S> Visitor<'de> for Helper + where + S: FromStr, + ::Err: fmt::Display, + { + type Value = S; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "a string") + } + + fn visit_str(self, value: &str) -> Result + where + E: SerdeError, + { + value.parse::().map_err(SerdeError::custom) + } + } + + deserializer.deserialize_str(Helper(PhantomData)) +} + +/// From https://docs.rs/serde_with/1.14.0/src/serde_with/rust.rs.html#121-127 +pub fn serialize_to_str(value: &T, serializer: S) -> Result +where + T: fmt::Display, + S: Serializer, +{ + serializer.collect_str(&value) +} + pub(crate) struct Timer where T: Send + 'static, @@ -54,17 +98,13 @@ where } } - /// Returns duration of the timer - pub fn duration(&self) -> Duration { - self.duration - } - /// Starts the timer, re-starting the countdown if already running. 
If the callback has already /// been completed, this timer will not invoke it again; however, this will start the timer /// itself, which will wait the duration and then fail to trigger the callback pub fn start(&mut self) { // Cancel the active timer task self.stop(); + self.active_timer = None; // Exit early if callback completed as starting will do nothing if self.callback.is_finished() { @@ -82,20 +122,21 @@ where /// Stops the timer, cancelling the internal task, but leaving the callback in place in case /// the timer is re-started later - pub fn stop(&mut self) { - // Delete the active timer task - if let Some(task) = self.active_timer.take() { + pub fn stop(&self) { + if let Some(task) = self.active_timer.as_ref() { task.abort(); } } + /// Returns true if the timer is actively running + pub fn is_running(&self) -> bool { + self.active_timer.is_some() && !self.active_timer.as_ref().unwrap().is_finished() + } + /// Aborts the timer's callback task and internal task to trigger the callback, which means /// that the timer will never complete the callback and starting will have no effect pub fn abort(&self) { - if let Some(task) = self.active_timer.as_ref() { - task.abort(); - } - + self.stop(); self.callback.abort(); } } @@ -106,8 +147,9 @@ mod tests { mod timer { use super::*; + use test_log::test; - #[tokio::test] + #[test(tokio::test)] async fn should_not_invoke_callback_regardless_of_time_if_not_started() { let timer = Timer::new(Duration::default(), async {}); @@ -119,9 +161,9 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn should_not_invoke_callback_if_only_stop_called() { - let mut timer = Timer::new(Duration::default(), async {}); + let timer = Timer::new(Duration::default(), async {}); timer.stop(); tokio::time::sleep(Duration::from_millis(300)).await; @@ -132,7 +174,7 @@ mod tests { ); } - #[tokio::test] + #[test(tokio::test)] async fn should_finish_callback_but_not_trigger_it_if_abort_called() { let (tx, mut rx) = 
mpsc::channel(1); @@ -147,7 +189,7 @@ mod tests { assert!(rx.try_recv().is_err(), "Callback triggered unexpectedly"); } - #[tokio::test] + #[test(tokio::test)] async fn should_trigger_callback_after_time_elapses_once_started() { let (tx, mut rx) = mpsc::channel(1); @@ -162,7 +204,7 @@ mod tests { assert!(rx.try_recv().is_ok(), "Callback not triggered"); } - #[tokio::test] + #[test(tokio::test)] async fn should_trigger_callback_even_if_timer_dropped() { let (tx, mut rx) = mpsc::channel(1); diff --git a/distant-net/src/id.rs b/distant-net/src/id.rs deleted file mode 100644 index b2ccda8..0000000 --- a/distant-net/src/id.rs +++ /dev/null @@ -1,2 +0,0 @@ -/// Id associated with an active connection -pub type ConnectionId = u64; diff --git a/distant-net/src/key.rs b/distant-net/src/key.rs deleted file mode 100644 index 4217a92..0000000 --- a/distant-net/src/key.rs +++ /dev/null @@ -1,100 +0,0 @@ -use derive_more::{Display, Error}; -use rand::{rngs::OsRng, RngCore}; -use std::{fmt, str::FromStr}; - -#[derive(Debug, Display, Error)] -pub struct SecretKeyError; - -/// Represents a 32-byte secret key -pub type SecretKey32 = SecretKey<32>; - -/// Represents a secret key used with transport encryption and authentication -#[derive(Clone, PartialEq, Eq)] -pub struct SecretKey([u8; N]); - -impl fmt::Debug for SecretKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("SecretKey") - .field(&"**OMITTED**".to_string()) - .finish() - } -} - -impl Default for SecretKey { - /// Creates a new secret key of the size `N` - /// - /// ### Panic - /// - /// Will panic if `N` is less than 1 or greater than `isize::MAX` - fn default() -> Self { - Self::generate().unwrap() - } -} - -impl SecretKey { - /// Returns byte slice to the key's bytes - pub fn unprotected_as_bytes(&self) -> &[u8] { - &self.0 - } - - /// Returns reference to array of key's bytes - pub fn unprotected_as_byte_array(&self) -> &[u8; N] { - &self.0 - } - - /// Returns the length of the key - 
#[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> usize { - N - } - - /// Generates a new secret key, returning success if key created or - /// failing if the desired key length is not between 1 and `isize::MAX` - pub fn generate() -> Result { - // Limitation described in https://github.com/orion-rs/orion/issues/130 - if N < 1 || N > (isize::MAX as usize) { - return Err(SecretKeyError); - } - - let mut key = [0; N]; - OsRng.fill_bytes(&mut key); - - Ok(Self(key)) - } - - /// Creates the key from the given byte slice, returning success if key created - /// or failing if the byte slice does not match the desired key length - pub fn from_slice(slice: &[u8]) -> Result { - if slice.len() != N { - return Err(SecretKeyError); - } - - let mut value = [0u8; N]; - value[..N].copy_from_slice(slice); - - Ok(Self(value)) - } -} - -impl From<[u8; N]> for SecretKey { - fn from(arr: [u8; N]) -> Self { - Self(arr) - } -} - -impl FromStr for SecretKey { - type Err = SecretKeyError; - - /// Parse a str of hex as an N-byte secret key - fn from_str(s: &str) -> Result { - let bytes = hex::decode(s).map_err(|_| SecretKeyError)?; - Self::from_slice(&bytes) - } -} - -impl fmt::Display for SecretKey { - /// Display an N-byte secret key as a hex string - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", hex::encode(self.unprotected_as_bytes())) - } -} diff --git a/distant-net/src/lib.rs b/distant-net/src/lib.rs index b85de6a..f23a7c8 100644 --- a/distant-net/src/lib.rs +++ b/distant-net/src/lib.rs @@ -1,27 +1,10 @@ -mod any; -mod auth; -mod client; -mod codec; -mod id; -mod key; -mod listener; -mod packet; -mod port; -mod server; -mod transport; -mod utils; +pub mod client; +pub mod common; +pub mod manager; +pub mod server; -pub use any::*; -pub use auth::*; -pub use client::*; -pub use codec::*; -pub use id::*; -pub use key::*; -pub use listener::*; -pub use packet::*; -pub use port::*; -pub use server::*; -pub use transport::*; +pub use 
client::{Client, ReconnectStrategy}; +pub use server::Server; pub use log; pub use paste; diff --git a/distant-core/src/manager.rs b/distant-net/src/manager.rs similarity index 100% rename from distant-core/src/manager.rs rename to distant-net/src/manager.rs diff --git a/distant-net/src/manager/client.rs b/distant-net/src/manager/client.rs new file mode 100644 index 0000000..7142d86 --- /dev/null +++ b/distant-net/src/manager/client.rs @@ -0,0 +1,626 @@ +use crate::{ + client::Client, + common::{ + authentication::{ + msg::{Authentication, AuthenticationResponse}, + AuthHandler, + }, + ConnectionId, Destination, Map, Request, + }, + manager::data::{ + ConnectionInfo, ConnectionList, ManagerCapabilities, ManagerRequest, ManagerResponse, + }, +}; +use log::*; +use std::io; + +mod channel; +pub use channel::*; + +/// Represents a client that can connect to a remote server manager. +pub type ManagerClient = Client; + +impl ManagerClient { + /// Request that the manager launches a new server at the given `destination` with `options` + /// being passed for destination-specific details, returning the new `destination` of the + /// spawned server. + /// + /// The provided `handler` will be used for any authentication requirements when connecting to + /// the remote machine to spawn the server. 
+ pub async fn launch( + &mut self, + destination: impl Into, + options: impl Into, + mut handler: impl AuthHandler + Send, + ) -> io::Result { + let destination = Box::new(destination.into()); + let options = options.into(); + trace!("launch({}, {})", destination, options); + + let mut mailbox = self + .mail(ManagerRequest::Launch { + destination: destination.clone(), + options, + }) + .await?; + + // Continue to process authentication challenges and other details until we are either + // launched or fail + while let Some(res) = mailbox.next().await { + match res.payload { + ManagerResponse::Authenticate { id, msg } => match msg { + Authentication::Initialization(x) => { + if log::log_enabled!(Level::Debug) { + debug!( + "Initializing authentication, supporting {}", + x.methods + .iter() + .map(ToOwned::to_owned) + .collect::>() + .join(",") + ); + } + let msg = AuthenticationResponse::Initialization( + handler.on_initialization(x).await?, + ); + self.fire(Request::new(ManagerRequest::Authenticate { id, msg })) + .await?; + } + Authentication::StartMethod(x) => { + debug!("Starting authentication method {}", x.method); + } + Authentication::Challenge(x) => { + if log::log_enabled!(Level::Debug) { + for question in x.questions.iter() { + debug!( + "Received challenge question [{}]: {}", + question.label, question.text + ); + } + } + let msg = AuthenticationResponse::Challenge(handler.on_challenge(x).await?); + self.fire(Request::new(ManagerRequest::Authenticate { id, msg })) + .await?; + } + Authentication::Verification(x) => { + debug!("Received verification request {}: {}", x.kind, x.text); + let msg = + AuthenticationResponse::Verification(handler.on_verification(x).await?); + self.fire(Request::new(ManagerRequest::Authenticate { id, msg })) + .await?; + } + Authentication::Info(x) => { + info!("{}", x.text); + } + Authentication::Error(x) => { + error!("{}", x.text); + if x.is_fatal() { + return Err(x.into_io_permission_denied()); + } + } + 
Authentication::Finished => { + debug!("Finished authentication for {destination}"); + } + }, + ManagerResponse::Launched { destination } => return Ok(destination), + ManagerResponse::Error { description } => { + return Err(io::Error::new(io::ErrorKind::Other, description)) + } + x => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )) + } + } + } + + Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "Missing connection confirmation", + )) + } + + /// Request that the manager establishes a new connection at the given `destination` + /// with `options` being passed for destination-specific details. + /// + /// The provided `handler` will be used for any authentication requirements when connecting to + /// the server. + pub async fn connect( + &mut self, + destination: impl Into, + options: impl Into, + mut handler: impl AuthHandler + Send, + ) -> io::Result { + let destination = Box::new(destination.into()); + let options = options.into(); + trace!("connect({}, {})", destination, options); + + let mut mailbox = self + .mail(ManagerRequest::Connect { + destination: destination.clone(), + options, + }) + .await?; + + // Continue to process authentication challenges and other details until we are either + // connected or fail + while let Some(res) = mailbox.next().await { + match res.payload { + ManagerResponse::Authenticate { id, msg } => match msg { + Authentication::Initialization(x) => { + if log::log_enabled!(Level::Debug) { + debug!( + "Initializing authentication, supporting {}", + x.methods + .iter() + .map(ToOwned::to_owned) + .collect::>() + .join(",") + ); + } + let msg = AuthenticationResponse::Initialization( + handler.on_initialization(x).await?, + ); + self.fire(Request::new(ManagerRequest::Authenticate { id, msg })) + .await?; + } + Authentication::StartMethod(x) => { + debug!("Starting authentication method {}", x.method); + } + Authentication::Challenge(x) => { + if 
log::log_enabled!(Level::Debug) { + for question in x.questions.iter() { + debug!( + "Received challenge question [{}]: {}", + question.label, question.text + ); + } + } + let msg = AuthenticationResponse::Challenge(handler.on_challenge(x).await?); + self.fire(Request::new(ManagerRequest::Authenticate { id, msg })) + .await?; + } + Authentication::Verification(x) => { + debug!("Received verification request {}: {}", x.kind, x.text); + let msg = + AuthenticationResponse::Verification(handler.on_verification(x).await?); + self.fire(Request::new(ManagerRequest::Authenticate { id, msg })) + .await?; + } + Authentication::Info(x) => { + info!("{}", x.text); + } + Authentication::Error(x) => { + error!("{}", x.text); + if x.is_fatal() { + return Err(x.into_io_permission_denied()); + } + } + Authentication::Finished => { + debug!("Finished authentication for {destination}"); + } + }, + ManagerResponse::Connected { id } => return Ok(id), + ManagerResponse::Error { description } => { + return Err(io::Error::new(io::ErrorKind::Other, description)) + } + x => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )) + } + } + } + + Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "Missing connection confirmation", + )) + } + + /// Establishes a channel with the server represented by the `connection_id`, + /// returning a [`RawChannel`] acting as the connection. + /// + /// ### Note + /// + /// Multiple calls to open a channel against the same connection will result in establishing a + /// duplicate channel to the same server, so take care when using this method. 
+ pub async fn open_raw_channel( + &mut self, + connection_id: ConnectionId, + ) -> io::Result { + trace!("open_raw_channel({})", connection_id); + RawChannel::spawn(connection_id, self).await + } + + /// Retrieves a list of supported capabilities + pub async fn capabilities(&mut self) -> io::Result { + trace!("capabilities()"); + let res = self.send(ManagerRequest::Capabilities).await?; + match res.payload { + ManagerResponse::Capabilities { supported } => Ok(supported), + ManagerResponse::Error { description } => { + Err(io::Error::new(io::ErrorKind::Other, description)) + } + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } + + /// Retrieves information about a specific connection + pub async fn info(&mut self, id: ConnectionId) -> io::Result { + trace!("info({})", id); + let res = self.send(ManagerRequest::Info { id }).await?; + match res.payload { + ManagerResponse::Info(info) => Ok(info), + ManagerResponse::Error { description } => { + Err(io::Error::new(io::ErrorKind::Other, description)) + } + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } + + /// Kills the specified connection + pub async fn kill(&mut self, id: ConnectionId) -> io::Result<()> { + trace!("kill({})", id); + let res = self.send(ManagerRequest::Kill { id }).await?; + match res.payload { + ManagerResponse::Killed => Ok(()), + ManagerResponse::Error { description } => { + Err(io::Error::new(io::ErrorKind::Other, description)) + } + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } + + /// Retrieves a list of active connections + pub async fn list(&mut self) -> io::Result { + trace!("list()"); + let res = self.send(ManagerRequest::List).await?; + match res.payload { + ManagerResponse::List(list) => Ok(list), + ManagerResponse::Error { description } => { + Err(io::Error::new(io::ErrorKind::Other, 
description)) + } + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Got unexpected response: {:?}", x), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::{ReconnectStrategy, UntypedClient}; + use crate::common::authentication::DummyAuthHandler; + use crate::common::{Connection, InmemoryTransport, Request, Response}; + + fn setup() -> (ManagerClient, Connection) { + let (client, server) = Connection::pair(100); + let client = UntypedClient::spawn(client, ReconnectStrategy::Fail).into_typed_client(); + (client, server) + } + + #[inline] + fn test_error() -> io::Error { + io::Error::new(io::ErrorKind::Interrupted, "test error") + } + + #[inline] + fn test_error_response() -> ManagerResponse { + ManagerResponse::from(test_error()) + } + + #[tokio::test] + async fn connect_should_report_error_if_receives_error_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new(request.id, test_error_response())) + .await + .unwrap(); + }); + + let err = client + .connect( + "scheme://host".parse::().unwrap(), + "key=value".parse::().unwrap(), + DummyAuthHandler, + ) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!(err.to_string(), test_error().to_string()); + } + + #[tokio::test] + async fn connect_should_report_error_if_receives_unexpected_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new(request.id, ManagerResponse::Killed)) + .await + .unwrap(); + }); + + let err = client + .connect( + "scheme://host".parse::().unwrap(), + "key=value".parse::().unwrap(), + DummyAuthHandler, + ) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + 
#[tokio::test] + async fn connect_should_return_id_from_successful_response() { + let (mut client, mut transport) = setup(); + + let expected_id = 999; + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new( + request.id, + ManagerResponse::Connected { id: expected_id }, + )) + .await + .unwrap(); + }); + + let id = client + .connect( + "scheme://host".parse::().unwrap(), + "key=value".parse::().unwrap(), + DummyAuthHandler, + ) + .await + .unwrap(); + assert_eq!(id, expected_id); + } + + #[tokio::test] + async fn info_should_report_error_if_receives_error_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new(request.id, test_error_response())) + .await + .unwrap(); + }); + + let err = client.info(123).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!(err.to_string(), test_error().to_string()); + } + + #[tokio::test] + async fn info_should_report_error_if_receives_unexpected_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new(request.id, ManagerResponse::Killed)) + .await + .unwrap(); + }); + + let err = client.info(123).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn info_should_return_connection_info_from_successful_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + let info = ConnectionInfo { + id: 123, + destination: "scheme://host".parse::().unwrap(), + options: "key=value".parse::().unwrap(), + }; + + transport + 
.write_frame_for(&Response::new(request.id, ManagerResponse::Info(info))) + .await + .unwrap(); + }); + + let info = client.info(123).await.unwrap(); + assert_eq!(info.id, 123); + assert_eq!( + info.destination, + "scheme://host".parse::().unwrap() + ); + assert_eq!(info.options, "key=value".parse::().unwrap()); + } + + #[tokio::test] + async fn list_should_report_error_if_receives_error_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new(request.id, test_error_response())) + .await + .unwrap(); + }); + + let err = client.list().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!(err.to_string(), test_error().to_string()); + } + + #[tokio::test] + async fn list_should_report_error_if_receives_unexpected_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new(request.id, ManagerResponse::Killed)) + .await + .unwrap(); + }); + + let err = client.list().await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn list_should_return_connection_list_from_successful_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + let mut list = ConnectionList::new(); + list.insert(123, "scheme://host".parse::().unwrap()); + + transport + .write_frame_for(&Response::new(request.id, ManagerResponse::List(list))) + .await + .unwrap(); + }); + + let list = client.list().await.unwrap(); + assert_eq!(list.len(), 1); + assert_eq!( + list.get(&123).expect("Connection list missing item"), + &"scheme://host".parse::().unwrap() + ); + } + + #[tokio::test] + async fn 
kill_should_report_error_if_receives_error_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new(request.id, test_error_response())) + .await + .unwrap(); + }); + + let err = client.kill(123).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!(err.to_string(), test_error().to_string()); + } + + #[tokio::test] + async fn kill_should_report_error_if_receives_unexpected_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new( + request.id, + ManagerResponse::Connected { id: 0 }, + )) + .await + .unwrap(); + }); + + let err = client.kill(123).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn kill_should_return_success_from_successful_response() { + let (mut client, mut transport) = setup(); + + tokio::spawn(async move { + let request = transport + .read_frame_as::>() + .await + .unwrap() + .unwrap(); + + transport + .write_frame_for(&Response::new(request.id, ManagerResponse::Killed)) + .await + .unwrap(); + }); + + client.kill(123).await.unwrap(); + } +} diff --git a/distant-net/src/manager/client/channel.rs b/distant-net/src/manager/client/channel.rs new file mode 100644 index 0000000..0d01308 --- /dev/null +++ b/distant-net/src/manager/client/channel.rs @@ -0,0 +1,174 @@ +use crate::{ + client::{Client, ReconnectStrategy, UntypedClient}, + common::{ConnectionId, FramedTransport, InmemoryTransport, UntypedRequest}, + manager::data::{ManagerRequest, ManagerResponse}, +}; +use log::*; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + io, + ops::{Deref, DerefMut}, +}; +use tokio::task::JoinHandle; + +/// Represents a raw channel between a manager client 
and server. Underneath, this routes incoming +/// and outgoing data from a proxied server to an inmemory transport. +pub struct RawChannel { + transport: FramedTransport, + task: JoinHandle<()>, +} + +impl RawChannel { + pub fn abort(&self) { + self.task.abort(); + } + + /// Consumes this channel, returning a typed client wrapping the transport. + /// + /// ### Note + /// + /// This does not perform any additional handshakes or authentication. All authentication was + /// performed during separate connection and this merely wraps an inmemory transport that maps + /// to the primary connection. + pub fn into_client(self) -> Client + where + T: Send + Sync + Serialize + 'static, + U: Send + Sync + DeserializeOwned + 'static, + { + Client::spawn_inmemory(self.transport, ReconnectStrategy::Fail) + } + + /// Consumes this channel, returning an untyped client wrapping the transport. + /// + /// ### Note + /// + /// This does not perform any additional handshakes or authentication. All authentication was + /// performed during separate connection and this merely wraps an inmemory transport that maps + /// to the primary connection. + pub fn into_untyped_client(self) -> UntypedClient { + UntypedClient::spawn_inmemory(self.transport, ReconnectStrategy::Fail) + } + + /// Returns reference to the underlying framed transport. + pub fn as_framed_transport(&self) -> &FramedTransport { + &self.transport + } + + /// Returns mutable reference to the underlying framed transport. + pub fn as_mut_framed_transport(&mut self) -> &mut FramedTransport { + &mut self.transport + } + + /// Consumes the channel, returning the underlying framed transport. 
+ pub fn into_framed_transport(self) -> FramedTransport { + self.transport + } +} + +impl Deref for RawChannel { + type Target = FramedTransport; + + fn deref(&self) -> &Self::Target { + &self.transport + } +} + +impl DerefMut for RawChannel { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.transport + } +} + +impl RawChannel { + pub(super) async fn spawn( + connection_id: ConnectionId, + client: &mut Client, + ) -> io::Result { + let mut mailbox = client + .mail(ManagerRequest::OpenChannel { id: connection_id }) + .await?; + + // Wait for the first response, which should be channel confirmation + let channel_id = match mailbox.next().await { + Some(response) => match response.payload { + ManagerResponse::ChannelOpened { id } => Ok(id), + ManagerResponse::Error { description } => { + Err(io::Error::new(io::ErrorKind::Other, description)) + } + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("[Conn {connection_id}] Raw channel open unexpected response: {x:?}"), + )), + }, + None => Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + format!("[Conn {connection_id}] Raw channel mailbox aborted"), + )), + }?; + + // Spawn our channel proxy transport + let (mut proxy, transport) = FramedTransport::pair(1); + + let mut manager_channel = client.clone_channel(); + let task = tokio::spawn(async move { + loop { + tokio::select! { + maybe_response = mailbox.next() => { + if maybe_response.is_none() { + debug!("[Conn {connection_id} :: Chan {channel_id}] Closing from no more responses"); + break; + } + + match maybe_response.unwrap().payload { + ManagerResponse::Channel { response, .. } => { + if let Err(x) = proxy.write_frame(response.to_bytes()).await { + error!( + "[Conn {connection_id} :: Chan {channel_id}] Write response failed: {x}" + ); + } + } + ManagerResponse::ChannelClosed { .. 
} => { + break; + } + _ => continue, + } + } + result = proxy.read_frame() => { + match result { + Ok(Some(frame)) => { + let request = match UntypedRequest::from_slice(frame.as_item()) { + Ok(x) => x.into_owned(), + Err(x) => { + error!("[Conn {connection_id} :: Chan {channel_id}] Parse request failed: {x}"); + continue; + } + }; + + // NOTE: In this situation, we do not expect a response to this + // request (even if the server sends something back) + if let Err(x) = manager_channel + .fire(ManagerRequest::Channel { + id: channel_id, + request, + }) + .await + { + error!("[Conn {connection_id} :: Chan {channel_id}] Forward failed: {x}"); + } + } + Ok(None) => { + debug!("[Conn {connection_id} :: Chan {channel_id}] Closing from no more requests"); + break; + } + Err(x) => { + error!("[Conn {connection_id} :: Chan {channel_id}] Read frame failed: {x}"); + } + } + } + } + } + }); + + Ok(RawChannel { transport, task }) + } +} diff --git a/distant-core/src/manager/data.rs b/distant-net/src/manager/data.rs similarity index 69% rename from distant-core/src/manager/data.rs rename to distant-net/src/manager/data.rs index 11f6022..3a9cc80 100644 --- a/distant-core/src/manager/data.rs +++ b/distant-net/src/manager/data.rs @@ -1,12 +1,9 @@ +pub type ManagerChannelId = u32; +pub type ManagerAuthenticationId = u32; + mod capabilities; pub use capabilities::*; -mod destination; -pub use destination::*; - -mod id; -pub use id::*; - mod info; pub use info::*; diff --git a/distant-core/src/manager/data/capabilities.rs b/distant-net/src/manager/data/capabilities.rs similarity index 100% rename from distant-core/src/manager/data/capabilities.rs rename to distant-net/src/manager/data/capabilities.rs diff --git a/distant-core/src/manager/data/info.rs b/distant-net/src/manager/data/info.rs similarity index 86% rename from distant-core/src/manager/data/info.rs rename to distant-net/src/manager/data/info.rs index b20c9b7..10a2624 100644 --- a/distant-core/src/manager/data/info.rs +++ 
b/distant-net/src/manager/data/info.rs @@ -1,5 +1,4 @@ -use super::{ConnectionId, Destination}; -use crate::data::Map; +use crate::common::{ConnectionId, Destination, Map}; use serde::{Deserialize, Serialize}; /// Information about a specific connection diff --git a/distant-core/src/manager/data/list.rs b/distant-net/src/manager/data/list.rs similarity index 80% rename from distant-core/src/manager/data/list.rs rename to distant-net/src/manager/data/list.rs index dfbed3d..e3fc754 100644 --- a/distant-core/src/manager/data/list.rs +++ b/distant-net/src/manager/data/list.rs @@ -1,4 +1,4 @@ -use super::{ConnectionId, Destination}; +use crate::common::{ConnectionId, Destination}; use derive_more::IntoIterator; use serde::{Deserialize, Serialize}; use std::{ @@ -41,16 +41,16 @@ impl DerefMut for ConnectionList { } } -impl Index for ConnectionList { +impl Index for ConnectionList { type Output = Destination; - fn index(&self, connection_id: u64) -> &Self::Output { + fn index(&self, connection_id: ConnectionId) -> &Self::Output { &self.0[&connection_id] } } -impl IndexMut for ConnectionList { - fn index_mut(&mut self, connection_id: u64) -> &mut Self::Output { +impl IndexMut for ConnectionList { + fn index_mut(&mut self, connection_id: ConnectionId) -> &mut Self::Output { self.0 .get_mut(&connection_id) .expect("No connection with id") diff --git a/distant-core/src/manager/data/request.rs b/distant-net/src/manager/data/request.rs similarity index 67% rename from distant-core/src/manager/data/request.rs rename to distant-net/src/manager/data/request.rs index b8e0906..aab5b60 100644 --- a/distant-core/src/manager/data/request.rs +++ b/distant-net/src/manager/data/request.rs @@ -1,13 +1,13 @@ -use super::{ChannelId, ConnectionId, Destination}; -use crate::{DistantMsg, DistantRequestData, Map}; +use super::{ManagerAuthenticationId, ManagerChannelId}; +use crate::common::{ + authentication::msg::AuthenticationResponse, ConnectionId, Destination, Map, UntypedRequest, +}; use 
derive_more::IsVariant; -use distant_net::Request; use serde::{Deserialize, Serialize}; use strum::{AsRefStr, EnumDiscriminants, EnumIter, EnumMessage, EnumString}; #[allow(clippy::large_enum_variant)] #[derive(Clone, Debug, EnumDiscriminants, Serialize, Deserialize)] -#[cfg_attr(feature = "clap", derive(clap::Subcommand))] #[strum_discriminants(derive( AsRefStr, strum::Display, @@ -34,13 +34,12 @@ pub enum ManagerRequest { Capabilities, /// Launch a server using the manager - #[strum_discriminants(strum(message = "Supports launching distant on remote servers"))] + #[strum_discriminants(strum(message = "Supports launching a server on remote machines"))] Launch { // NOTE: Boxed per clippy's large_enum_variant warning destination: Box, /// Additional options specific to the connection - #[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))] options: Map, }, @@ -51,12 +50,20 @@ pub enum ManagerRequest { destination: Box, /// Additional options specific to the connection - #[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))] options: Map, }, - /// Opens a channel for communication with a server - #[cfg_attr(feature = "clap", clap(skip))] + /// Submit some authentication message for the manager to use with an active connection + #[strum_discriminants(strum(message = "Supports authenticating with a remote server"))] + Authenticate { + /// Id of the authentication request that is being responded to + id: ManagerAuthenticationId, + + /// Response being sent to some active connection + msg: AuthenticationResponse, + }, + + /// Opens a channel for communication with an already-connected server #[strum_discriminants(strum(message = "Supports opening a channel with a remote server"))] OpenChannel { /// Id of the connection @@ -64,25 +71,22 @@ pub enum ManagerRequest { }, /// Sends data through channel - #[cfg_attr(feature = "clap", clap(skip))] #[strum_discriminants(strum( message = "Supports sending data through a 
channel with a remote server" ))] Channel { /// Id of the channel - id: ChannelId, + id: ManagerChannelId, - /// Request to send to through the channel - #[cfg_attr(feature = "clap", clap(skip = skipped_request()))] - request: Request>, + /// Untyped request to send through the channel + request: UntypedRequest<'static>, }, /// Closes an open channel - #[cfg_attr(feature = "clap", clap(skip))] #[strum_discriminants(strum(message = "Supports closing a channel with a remote server"))] CloseChannel { /// Id of the channel to close - id: ChannelId, + id: ManagerChannelId, }, /// Retrieve information about a specific connection @@ -96,14 +100,4 @@ pub enum ManagerRequest { /// Retrieve list of connections being managed #[strum_discriminants(strum(message = "Supports retrieving a list of managed connections"))] List, - - /// Signals the manager to shutdown - #[strum_discriminants(strum(message = "Supports being shut down on demand"))] - Shutdown, -} - -/// Produces some default request, purely to satisfy clap -#[cfg(feature = "clap")] -fn skipped_request() -> Request> { - Request::new(DistantMsg::Single(DistantRequestData::SystemInfo {})) } diff --git a/distant-core/src/manager/data/response.rs b/distant-net/src/manager/data/response.rs similarity index 55% rename from distant-core/src/manager/data/response.rs rename to distant-net/src/manager/data/response.rs index e5b7687..e69289b 100644 --- a/distant-core/src/manager/data/response.rs +++ b/distant-net/src/manager/data/response.rs @@ -1,6 +1,9 @@ -use crate::{data::Error, ConnectionInfo, ConnectionList, Destination, ManagerCapabilities}; -use crate::{ChannelId, ConnectionId, DistantMsg, DistantResponseData}; -use distant_net::Response; +use super::{ + ConnectionInfo, ConnectionList, ManagerAuthenticationId, ManagerCapabilities, ManagerChannelId, +}; +use crate::common::{ + authentication::msg::Authentication, ConnectionId, Destination, UntypedResponse, +}; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, 
Serialize, Deserialize)] @@ -9,16 +12,13 @@ pub enum ManagerResponse { /// Acknowledgement that a connection was killed Killed, - /// Broadcast that the manager is shutting down (not guaranteed to be sent) - Shutdown, - /// Indicates that some error occurred during a request - Error(Error), + Error { description: String }, /// Response to retrieving information about the manager's capabilities Capabilities { supported: ManagerCapabilities }, - /// Confirmation of a distant server being launched + /// Confirmation of a server being launched Launched { /// Updated location of the spawned server destination: Destination, @@ -27,6 +27,15 @@ pub enum ManagerResponse { /// Confirmation of a connection being established Connected { id: ConnectionId }, + /// Authentication information being sent to a client + Authenticate { + /// Id tied to authentication information in case a response is needed + id: ManagerAuthenticationId, + + /// Authentication message + msg: Authentication, + }, + /// Information about a specific connection Info(ConnectionInfo), @@ -36,21 +45,29 @@ pub enum ManagerResponse { /// Forward a response back to a specific channel that made a request Channel { /// Id of the channel - id: ChannelId, + id: ManagerChannelId, - /// Response to an earlier channel request - response: Response>, + /// Untyped response to send through the channel + response: UntypedResponse<'static>, }, /// Indicates that a channel has been opened ChannelOpened { /// Id of the channel - id: ChannelId, + id: ManagerChannelId, }, /// Indicates that a channel has been closed ChannelClosed { /// Id of the channel - id: ChannelId, + id: ManagerChannelId, }, } + +impl From for ManagerResponse { + fn from(x: T) -> Self { + Self::Error { + description: x.to_string(), + } + } +} diff --git a/distant-net/src/manager/server.rs b/distant-net/src/manager/server.rs new file mode 100644 index 0000000..1dc2286 --- /dev/null +++ b/distant-net/src/manager/server.rs @@ -0,0 +1,584 @@ +use crate::{ + 
common::{authentication::msg::AuthenticationResponse, ConnectionId, Destination, Map}, + manager::{ + ConnectionInfo, ConnectionList, ManagerAuthenticationId, ManagerCapabilities, + ManagerChannelId, ManagerRequest, ManagerResponse, + }, + server::{Server, ServerCtx, ServerHandler}, +}; +use async_trait::async_trait; +use log::*; +use std::{collections::HashMap, io, sync::Arc}; +use tokio::sync::{oneshot, RwLock}; + +mod authentication; +pub use authentication::*; + +mod config; +pub use config::*; + +mod connection; +pub use connection::*; + +mod handler; +pub use handler::*; + +/// Represents a manager of multiple server connections. +pub struct ManagerServer { + /// Configuration settings for the server + config: Config, + + /// Mapping of connection id -> connection + connections: RwLock>, + + /// Mapping of auth id -> callback + registry: + Arc>>>, +} + +impl ManagerServer { + /// Creates a new [`Server`] starting with a default configuration and no authentication + /// methods. The provided `config` will be used to configure the launch and connect handlers + /// for the server as well as provide other defaults. 
+ pub fn new(config: Config) -> Server { + Server::new().handler(Self { + config, + connections: RwLock::new(HashMap::new()), + registry: Arc::new(RwLock::new(HashMap::new())), + }) + } + + /// Launches a new server at the specified `destination` using the given `options` information + /// and authentication client (if needed) to retrieve additional information needed to + /// enter the destination prior to starting the server, returning the destination of the + /// launched server + async fn launch( + &self, + destination: Destination, + options: Map, + mut authenticator: ManagerAuthenticator, + ) -> io::Result { + let scheme = match destination.scheme.as_deref() { + Some(scheme) => { + trace!("Using scheme {}", scheme); + scheme + } + None => { + trace!( + "Using fallback scheme of {}", + self.config.launch_fallback_scheme.as_str() + ); + self.config.launch_fallback_scheme.as_str() + } + } + .to_lowercase(); + + let credentials = { + let handler = self.config.launch_handlers.get(&scheme).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("No launch handler registered for {}", scheme), + ) + })?; + handler + .launch(&destination, &options, &mut authenticator) + .await? 
+ }; + + Ok(credentials) + } + + /// Connects to a new server at the specified `destination` using the given `options` information + /// and authentication client (if needed) to retrieve additional information needed to + /// establish the connection to the server + async fn connect( + &self, + destination: Destination, + options: Map, + mut authenticator: ManagerAuthenticator, + ) -> io::Result { + let scheme = match destination.scheme.as_deref() { + Some(scheme) => { + trace!("Using scheme {}", scheme); + scheme + } + None => { + trace!( + "Using fallback scheme of {}", + self.config.connect_fallback_scheme.as_str() + ); + self.config.connect_fallback_scheme.as_str() + } + } + .to_lowercase(); + + let client = { + let handler = self.config.connect_handlers.get(&scheme).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("No connect handler registered for {}", scheme), + ) + })?; + handler + .connect(&destination, &options, &mut authenticator) + .await? + }; + + let connection = ManagerConnection::spawn(destination, options, client).await?; + let id = connection.id; + self.connections.write().await.insert(id, connection); + Ok(id) + } + + /// Retrieves the list of supported capabilities for this manager + async fn capabilities(&self) -> io::Result { + Ok(ManagerCapabilities::all()) + } + + /// Retrieves information about the connection to the server with the specified `id` + async fn info(&self, id: ConnectionId) -> io::Result { + match self.connections.read().await.get(&id) { + Some(connection) => Ok(ConnectionInfo { + id: connection.id, + destination: connection.destination.clone(), + options: connection.options.clone(), + }), + None => Err(io::Error::new( + io::ErrorKind::NotConnected, + "No connection found", + )), + } + } + + /// Retrieves a list of connections to servers + async fn list(&self) -> io::Result { + Ok(ConnectionList( + self.connections + .read() + .await + .values() + .map(|conn| (conn.id, conn.destination.clone())) + 
.collect(), + )) + } + + /// Kills the connection to the server with the specified `id` + async fn kill(&self, id: ConnectionId) -> io::Result<()> { + match self.connections.write().await.remove(&id) { + Some(_) => Ok(()), + None => Err(io::Error::new( + io::ErrorKind::NotConnected, + "No connection found", + )), + } + } +} + +#[derive(Default)] +pub struct DistantManagerServerConnection { + /// Holds on to open channels feeding data back from a server to some connected client, + /// enabling us to cancel the tasks on demand + channels: RwLock>, +} + +#[async_trait] +impl ServerHandler for ManagerServer { + type Request = ManagerRequest; + type Response = ManagerResponse; + type LocalData = DistantManagerServerConnection; + + async fn on_request(&self, ctx: ServerCtx) { + let ServerCtx { + connection_id, + request, + reply, + local_data, + } = ctx; + + let response = match request.payload { + ManagerRequest::Capabilities {} => match self.capabilities().await { + Ok(supported) => ManagerResponse::Capabilities { supported }, + Err(x) => ManagerResponse::from(x), + }, + ManagerRequest::Launch { + destination, + options, + } => match self + .launch( + *destination, + options, + ManagerAuthenticator { + reply: reply.clone(), + registry: Arc::clone(&self.registry), + }, + ) + .await + { + Ok(destination) => ManagerResponse::Launched { destination }, + Err(x) => ManagerResponse::from(x), + }, + ManagerRequest::Connect { + destination, + options, + } => match self + .connect( + *destination, + options, + ManagerAuthenticator { + reply: reply.clone(), + registry: Arc::clone(&self.registry), + }, + ) + .await + { + Ok(id) => ManagerResponse::Connected { id }, + Err(x) => ManagerResponse::from(x), + }, + ManagerRequest::Authenticate { id, msg } => { + match self.registry.write().await.remove(&id) { + Some(cb) => match cb.send(msg) { + Ok(_) => return, + Err(_) => ManagerResponse::Error { + description: "Unable to forward authentication callback".to_string(), + }, + }, + None 
=> ManagerResponse::from(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid authentication id", + )), + } + } + ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) { + Some(connection) => match connection.open_channel(reply.clone()) { + Ok(channel) => { + debug!("[Conn {id}] Channel {} has been opened", channel.id()); + let id = channel.id(); + local_data.channels.write().await.insert(id, channel); + ManagerResponse::ChannelOpened { id } + } + Err(x) => ManagerResponse::from(x), + }, + None => ManagerResponse::from(io::Error::new( + io::ErrorKind::NotConnected, + "Connection does not exist", + )), + }, + ManagerRequest::Channel { id, request } => { + match local_data.channels.read().await.get(&id) { + // TODO: For now, we are NOT sending back a response to acknowledge + // a successful channel send. We could do this in order for + // the client to listen for a complete send, but is it worth it? + Some(channel) => match channel.send(request) { + Ok(_) => return, + Err(x) => ManagerResponse::from(x), + }, + None => ManagerResponse::from(io::Error::new( + io::ErrorKind::NotConnected, + "Channel is not open or does not exist", + )), + } + } + ManagerRequest::CloseChannel { id } => { + match local_data.channels.write().await.remove(&id) { + Some(channel) => match channel.close() { + Ok(_) => { + debug!("Channel {id} has been closed"); + ManagerResponse::ChannelClosed { id } + } + Err(x) => ManagerResponse::from(x), + }, + None => ManagerResponse::from(io::Error::new( + io::ErrorKind::NotConnected, + "Channel is not open or does not exist", + )), + } + } + ManagerRequest::Info { id } => match self.info(id).await { + Ok(info) => ManagerResponse::Info(info), + Err(x) => ManagerResponse::from(x), + }, + ManagerRequest::List => match self.list().await { + Ok(list) => ManagerResponse::List(list), + Err(x) => ManagerResponse::from(x), + }, + ManagerRequest::Kill { id } => match self.kill(id).await { + Ok(()) => ManagerResponse::Killed, + 
Err(x) => ManagerResponse::from(x), + }, + }; + + if let Err(x) = reply.send(response).await { + error!("[Conn {}] {}", connection_id, x); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::{ReconnectStrategy, UntypedClient}; + use crate::common::FramedTransport; + use crate::server::ServerReply; + use crate::{boxed_connect_handler, boxed_launch_handler}; + use tokio::sync::mpsc; + + fn test_config() -> Config { + Config { + launch_fallback_scheme: "ssh".to_string(), + connect_fallback_scheme: "distant".to_string(), + connection_buffer_size: 100, + user: false, + launch_handlers: HashMap::new(), + connect_handlers: HashMap::new(), + } + } + + /// Create an untyped client that is detached such that reads and writes will fail + fn detached_untyped_client() -> UntypedClient { + UntypedClient::spawn_inmemory(FramedTransport::pair(1).0, ReconnectStrategy::Fail) + } + + /// Create a new server and authenticator + fn setup(config: Config) -> (ManagerServer, ManagerAuthenticator) { + let registry = Arc::new(RwLock::new(HashMap::new())); + + let authenticator = ManagerAuthenticator { + reply: ServerReply { + origin_id: format!("{}", rand::random::()), + tx: mpsc::channel(1).0, + }, + registry: Arc::clone(®istry), + }; + + let server = ManagerServer { + config, + connections: RwLock::new(HashMap::new()), + registry, + }; + + (server, authenticator) + } + + #[tokio::test] + async fn launch_should_fail_if_destination_scheme_is_unsupported() { + let (server, authenticator) = setup(test_config()); + + let destination = "scheme://host".parse::().unwrap(); + let options = "".parse::().unwrap(); + let err = server + .launch(destination, options, authenticator) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err); + } + + #[tokio::test] + async fn launch_should_fail_if_handler_tied_to_scheme_fails() { + let mut config = test_config(); + + let handler = boxed_launch_handler!(|_a, _b, _c| { + 
Err(io::Error::new(io::ErrorKind::Other, "test failure")) + }); + + config.launch_handlers.insert("scheme".to_string(), handler); + + let (server, authenticator) = setup(config); + let destination = "scheme://host".parse::().unwrap(); + let options = "".parse::().unwrap(); + let err = server + .launch(destination, options, authenticator) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!(err.to_string(), "test failure"); + } + + #[tokio::test] + async fn launch_should_return_new_destination_on_success() { + let mut config = test_config(); + + let handler = boxed_launch_handler!(|_a, _b, _c| { + Ok("scheme2://host2".parse::().unwrap()) + }); + + config.launch_handlers.insert("scheme".to_string(), handler); + + let (server, authenticator) = setup(config); + let destination = "scheme://host".parse::().unwrap(); + let options = "key=value".parse::().unwrap(); + let destination = server + .launch(destination, options, authenticator) + .await + .unwrap(); + + assert_eq!( + destination, + "scheme2://host2".parse::().unwrap() + ); + } + + #[tokio::test] + async fn connect_should_fail_if_destination_scheme_is_unsupported() { + let (server, authenticator) = setup(test_config()); + + let destination = "scheme://host".parse::().unwrap(); + let options = "".parse::().unwrap(); + let err = server + .connect(destination, options, authenticator) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err); + } + + #[tokio::test] + async fn connect_should_fail_if_handler_tied_to_scheme_fails() { + let mut config = test_config(); + + let handler = boxed_connect_handler!(|_a, _b, _c| { + Err(io::Error::new(io::ErrorKind::Other, "test failure")) + }); + + config + .connect_handlers + .insert("scheme".to_string(), handler); + + let (server, authenticator) = setup(config); + let destination = "scheme://host".parse::().unwrap(); + let options = "".parse::().unwrap(); + let err = server + .connect(destination, options, 
authenticator) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!(err.to_string(), "test failure"); + } + + #[tokio::test] + async fn connect_should_return_id_of_new_connection_on_success() { + let mut config = test_config(); + + let handler = boxed_connect_handler!(|_a, _b, _c| { Ok(detached_untyped_client()) }); + + config + .connect_handlers + .insert("scheme".to_string(), handler); + + let (server, authenticator) = setup(config); + let destination = "scheme://host".parse::().unwrap(); + let options = "key=value".parse::().unwrap(); + let id = server + .connect(destination, options, authenticator) + .await + .unwrap(); + + let lock = server.connections.read().await; + let connection = lock.get(&id).unwrap(); + assert_eq!(connection.id, id); + assert_eq!(connection.destination, "scheme://host"); + assert_eq!(connection.options, "key=value".parse().unwrap()); + } + + #[tokio::test] + async fn info_should_fail_if_no_connection_found_for_specified_id() { + let (server, _) = setup(test_config()); + + let err = server.info(999).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err); + } + + #[tokio::test] + async fn info_should_return_information_about_established_connection() { + let (server, _) = setup(test_config()); + + let connection = ManagerConnection::spawn( + "scheme://host".parse().unwrap(), + "key=value".parse().unwrap(), + detached_untyped_client(), + ) + .await + .unwrap(); + let id = connection.id; + server.connections.write().await.insert(id, connection); + + let info = server.info(id).await.unwrap(); + assert_eq!( + info, + ConnectionInfo { + id, + destination: "scheme://host".parse().unwrap(), + options: "key=value".parse().unwrap(), + } + ); + } + + #[tokio::test] + async fn list_should_return_empty_connection_list_if_no_established_connections() { + let (server, _) = setup(test_config()); + + let list = server.list().await.unwrap(); + assert_eq!(list, 
ConnectionList(HashMap::new())); + } + + #[tokio::test] + async fn list_should_return_a_list_of_established_connections() { + let (server, _) = setup(test_config()); + + let connection = ManagerConnection::spawn( + "scheme://host".parse().unwrap(), + "key=value".parse().unwrap(), + detached_untyped_client(), + ) + .await + .unwrap(); + let id_1 = connection.id; + server.connections.write().await.insert(id_1, connection); + + let connection = ManagerConnection::spawn( + "other://host2".parse().unwrap(), + "key=value".parse().unwrap(), + detached_untyped_client(), + ) + .await + .unwrap(); + let id_2 = connection.id; + server.connections.write().await.insert(id_2, connection); + + let list = server.list().await.unwrap(); + assert_eq!( + list.get(&id_1).unwrap(), + &"scheme://host".parse::().unwrap() + ); + assert_eq!( + list.get(&id_2).unwrap(), + &"other://host2".parse::().unwrap() + ); + } + + #[tokio::test] + async fn kill_should_fail_if_no_connection_found_for_specified_id() { + let (server, _) = setup(test_config()); + + let err = server.kill(999).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err); + } + + #[tokio::test] + async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() { + let (server, _) = setup(test_config()); + + let connection = ManagerConnection::spawn( + "scheme://host".parse().unwrap(), + "key=value".parse().unwrap(), + detached_untyped_client(), + ) + .await + .unwrap(); + let id = connection.id; + server.connections.write().await.insert(id, connection); + + server.kill(id).await.unwrap(); + + let lock = server.connections.read().await; + assert!(!lock.contains_key(&id), "Connection still exists"); + } +} diff --git a/distant-net/src/manager/server/authentication.rs b/distant-net/src/manager/server/authentication.rs new file mode 100644 index 0000000..5c035d4 --- /dev/null +++ b/distant-net/src/manager/server/authentication.rs @@ -0,0 +1,103 @@ +use crate::{ + 
common::authentication::{msg::*, Authenticator}, + manager::data::{ManagerAuthenticationId, ManagerResponse}, + server::ServerReply, +}; +use async_trait::async_trait; +use std::{collections::HashMap, io, sync::Arc}; +use tokio::sync::{oneshot, RwLock}; + +/// Implementation of [`Authenticator`] used by a manger to perform authentication with +/// remote servers it is managing. +#[derive(Clone)] +pub struct ManagerAuthenticator { + /// Used to communicate authentication requests + pub(super) reply: ServerReply, + + /// Used to store one-way response senders that are used to return callbacks + pub(super) registry: + Arc>>>, +} + +impl ManagerAuthenticator { + /// Sends an [`Authentication`] `msg` that expects a reply, storing a callback. + async fn send(&self, msg: Authentication) -> io::Result { + let (tx, rx) = oneshot::channel(); + let id = rand::random(); + + self.registry.write().await.insert(id, tx); + self.reply + .send(ManagerResponse::Authenticate { id, msg }) + .await?; + rx.await + .map_err(|x| io::Error::new(io::ErrorKind::Other, x)) + } + + /// Sends an [`Authentication`] `msg` without expecting a reply. No callback is stored. + async fn fire(&self, msg: Authentication) -> io::Result<()> { + let id = rand::random(); + self.reply + .send(ManagerResponse::Authenticate { id, msg }) + .await?; + Ok(()) + } +} + +/// Represents an interface for submitting challenges for authentication. 
+#[async_trait] +impl Authenticator for ManagerAuthenticator { + async fn initialize( + &mut self, + initialization: Initialization, + ) -> io::Result { + match self + .send(Authentication::Initialization(initialization)) + .await + { + Ok(AuthenticationResponse::Initialization(x)) => Ok(x), + Ok(x) => Err(io::Error::new( + io::ErrorKind::Other, + format!("Unexpected response: {x:?}"), + )), + Err(x) => Err(x), + } + } + + async fn challenge(&mut self, challenge: Challenge) -> io::Result { + match self.send(Authentication::Challenge(challenge)).await { + Ok(AuthenticationResponse::Challenge(x)) => Ok(x), + Ok(x) => Err(io::Error::new( + io::ErrorKind::Other, + format!("Unexpected response: {x:?}"), + )), + Err(x) => Err(x), + } + } + + async fn verify(&mut self, verification: Verification) -> io::Result { + match self.send(Authentication::Verification(verification)).await { + Ok(AuthenticationResponse::Verification(x)) => Ok(x), + Ok(x) => Err(io::Error::new( + io::ErrorKind::Other, + format!("Unexpected response: {x:?}"), + )), + Err(x) => Err(x), + } + } + + async fn info(&mut self, info: Info) -> io::Result<()> { + self.fire(Authentication::Info(info)).await + } + + async fn error(&mut self, error: Error) -> io::Result<()> { + self.fire(Authentication::Error(error)).await + } + + async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> { + self.fire(Authentication::StartMethod(start_method)).await + } + + async fn finished(&mut self) -> io::Result<()> { + self.fire(Authentication::Finished).await + } +} diff --git a/distant-core/src/manager/server/config.rs b/distant-net/src/manager/server/config.rs similarity index 88% rename from distant-core/src/manager/server/config.rs rename to distant-net/src/manager/server/config.rs index 419940e..d7d898a 100644 --- a/distant-core/src/manager/server/config.rs +++ b/distant-net/src/manager/server/config.rs @@ -1,7 +1,8 @@ -use crate::{BoxedConnectHandler, BoxedLaunchHandler}; +use 
super::{BoxedConnectHandler, BoxedLaunchHandler}; use std::collections::HashMap; -pub struct DistantManagerConfig { +/// Configuration settings for a manager. +pub struct Config { /// Scheme to use when none is provided in a destination for launch pub launch_fallback_scheme: String, @@ -21,7 +22,7 @@ pub struct DistantManagerConfig { pub connect_handlers: HashMap, } -impl Default for DistantManagerConfig { +impl Default for Config { fn default() -> Self { Self { // Default to using ssh to launch distant diff --git a/distant-net/src/manager/server/connection.rs b/distant-net/src/manager/server/connection.rs new file mode 100644 index 0000000..7a673be --- /dev/null +++ b/distant-net/src/manager/server/connection.rs @@ -0,0 +1,218 @@ +use crate::{ + client::{Mailbox, UntypedClient}, + common::{ConnectionId, Destination, Map, UntypedRequest, UntypedResponse}, + manager::data::{ManagerChannelId, ManagerResponse}, + server::ServerReply, +}; +use log::*; +use std::{collections::HashMap, io}; +use tokio::{sync::mpsc, task::JoinHandle}; + +/// Represents a connection a distant manager has with some distant-compatible server +pub struct ManagerConnection { + pub id: ConnectionId, + pub destination: Destination, + pub options: Map, + tx: mpsc::UnboundedSender, + + action_task: JoinHandle<()>, + request_task: JoinHandle<()>, + response_task: JoinHandle<()>, +} + +#[derive(Clone)] +pub struct ManagerChannel { + channel_id: ManagerChannelId, + tx: mpsc::UnboundedSender, +} + +impl ManagerChannel { + /// Returns the id associated with the channel. + pub fn id(&self) -> ManagerChannelId { + self.channel_id + } + + /// Sends the untyped request to the server on the other side of the channel. 
+ pub fn send(&self, req: UntypedRequest<'static>) -> io::Result<()> { + let id = self.channel_id; + + self.tx.send(Action::Write { id, req }).map_err(|x| { + io::Error::new( + io::ErrorKind::BrokenPipe, + format!("channel {id} send failed: {x}"), + ) + }) + } + + /// Closes the channel, unregistering it with the connection. + pub fn close(&self) -> io::Result<()> { + let id = self.channel_id; + self.tx.send(Action::Unregister { id }).map_err(|x| { + io::Error::new( + io::ErrorKind::BrokenPipe, + format!("channel {id} close failed: {x}"), + ) + }) + } +} + +impl ManagerConnection { + pub async fn spawn( + spawn: Destination, + options: Map, + client: UntypedClient, + ) -> io::Result { + let connection_id = rand::random(); + let (tx, rx) = mpsc::unbounded_channel(); + + let (request_tx, request_rx) = mpsc::unbounded_channel(); + let action_task = tokio::spawn(action_task(connection_id, rx, request_tx)); + let response_task = tokio::spawn(response_task( + connection_id, + client.assign_default_mailbox(100).await?, + tx.clone(), + )); + let request_task = tokio::spawn(request_task(connection_id, client, request_rx)); + + Ok(Self { + id: connection_id, + destination: spawn, + options, + tx, + action_task, + request_task, + response_task, + }) + } + + pub fn open_channel(&self, reply: ServerReply) -> io::Result { + let channel_id = rand::random(); + self.tx + .send(Action::Register { + id: channel_id, + reply, + }) + .map_err(|x| { + io::Error::new( + io::ErrorKind::BrokenPipe, + format!("open_channel failed: {x}"), + ) + })?; + Ok(ManagerChannel { + channel_id, + tx: self.tx.clone(), + }) + } +} + +impl Drop for ManagerConnection { + fn drop(&mut self) { + self.action_task.abort(); + self.request_task.abort(); + self.response_task.abort(); + } +} + +enum Action { + Register { + id: ManagerChannelId, + reply: ServerReply, + }, + + Unregister { + id: ManagerChannelId, + }, + + Read { + res: UntypedResponse<'static>, + }, + + Write { + id: ManagerChannelId, + req: 
UntypedRequest<'static>, + }, +} + +/// Internal task to process outgoing [`UntypedRequest`]s. +async fn request_task( + id: ConnectionId, + mut client: UntypedClient, + mut rx: mpsc::UnboundedReceiver>, +) { + while let Some(req) = rx.recv().await { + if let Err(x) = client.fire(req).await { + error!("[Conn {id}] Failed to send request: {x}"); + } + } +} + +/// Internal task to process incoming [`UntypedResponse`]s. +async fn response_task( + id: ConnectionId, + mut mailbox: Mailbox>, + tx: mpsc::UnboundedSender, +) { + while let Some(res) = mailbox.next().await { + if let Err(x) = tx.send(Action::Read { res }) { + error!("[Conn {id}] Failed to forward received response: {x}"); + } + } +} + +/// Internal task to process [`Action`] items. +/// +/// * `id` - the id of the connection. +/// * `rx` - used to receive new [`Action`]s to process. +/// * `tx` - used to send outgoing requests through the connection. +async fn action_task( + id: ConnectionId, + mut rx: mpsc::UnboundedReceiver, + tx: mpsc::UnboundedSender>, +) { + let mut registered = HashMap::new(); + + while let Some(action) = rx.recv().await { + match action { + Action::Register { id, reply } => { + registered.insert(id, reply); + } + Action::Unregister { id } => { + registered.remove(&id); + } + Action::Read { mut res } => { + // Split {channel id}_{request id} back into pieces and + // update the origin id to match the request id only + let channel_id = match res.origin_id.split_once('_') { + Some((cid_str, oid_str)) => { + if let Ok(cid) = cid_str.parse::() { + res.set_origin_id(oid_str.to_string()); + cid + } else { + continue; + } + } + None => continue, + }; + + if let Some(reply) = registered.get(&channel_id) { + let response = ManagerResponse::Channel { + id: channel_id, + response: res, + }; + if let Err(x) = reply.send(response).await { + error!("[Conn {id}] {x}"); + } + } + } + Action::Write { id, mut req } => { + // Combine channel id with request id so we can properly forward + // the response 
containing this in the origin id + req.set_id(format!("{id}_{}", req.id)); + + if let Err(x) = tx.send(req) { + error!("[Conn {id}] {x}"); + } + } + } + } +} diff --git a/distant-net/src/manager/server/handler.rs b/distant-net/src/manager/server/handler.rs new file mode 100644 index 0000000..ae7eacd --- /dev/null +++ b/distant-net/src/manager/server/handler.rs @@ -0,0 +1,312 @@ +use crate::client::UntypedClient; +use crate::common::{authentication::Authenticator, Destination, Map}; +use async_trait::async_trait; +use std::{future::Future, io}; + +pub type BoxedLaunchHandler = Box; +pub type BoxedConnectHandler = Box; + +/// Represents an interface to start a server at some remote `destination`. +/// +/// * `destination` is the location where the server will be started. +/// * `options` is provided to include extra information needed to launch or establish the +/// connection. +/// * `authenticator` is provided to support a challenge-based authentication while launching. +/// +/// Returns a [`Destination`] representing the new origin to use if a connection is desired. 
+#[async_trait] +pub trait LaunchHandler: Send + Sync { + async fn launch( + &self, + destination: &Destination, + options: &Map, + authenticator: &mut dyn Authenticator, + ) -> io::Result; +} + +#[async_trait] +impl LaunchHandler for F +where + F: Fn(&Destination, &Map, &mut dyn Authenticator) -> R + Send + Sync + 'static, + R: Future> + Send + 'static, +{ + async fn launch( + &self, + destination: &Destination, + options: &Map, + authenticator: &mut dyn Authenticator, + ) -> io::Result { + self(destination, options, authenticator).await + } +} + +/// Generates a new [`LaunchHandler`] for the provided anonymous function in the form of +/// +/// ``` +/// use distant_net::boxed_launch_handler; +/// +/// let _handler = boxed_launch_handler!(|destination, options, authenticator| { +/// todo!("Implement handler logic."); +/// }); +/// +/// let _handler = boxed_launch_handler!(|destination, options, authenticator| async { +/// todo!("We support async within as well regardless of the keyword!"); +/// }); +/// +/// let _handler = boxed_launch_handler!(move |destination, options, authenticator| { +/// todo!("You can also explicitly mark to move into the closure"); +/// }); +/// ``` +#[macro_export] +macro_rules! boxed_launch_handler { + (|$destination:ident, $options:ident, $authenticator:ident| $(async)? $body:block) => {{ + let x: $crate::manager::BoxedLaunchHandler = Box::new( + |$destination: &$crate::common::Destination, + $options: &$crate::common::Map, + $authenticator: &mut dyn $crate::common::authentication::Authenticator| async { + $body + }, + ); + x + }}; + (move |$destination:ident, $options:ident, $authenticator:ident| $(async)? 
$body:block) => {{ + let x: $crate::manager::BoxedLaunchHandler = Box::new( + move |$destination: &$crate::common::Destination, + $options: &$crate::common::Map, + $authenticator: &mut dyn $crate::common::authentication::Authenticator| async move { + $body + }, + ); + x + }}; +} + +/// Represents an interface to perform a connection to some remote `destination`. +/// +/// * `destination` is the location of the server to connect to. +/// * `options` is provided to include extra information needed to establish the connection. +/// * `authenticator` is provided to support a challenge-based authentication while connecting. +/// +/// Returns an [`UntypedClient`] representing the connection. +#[async_trait] +pub trait ConnectHandler: Send + Sync { + async fn connect( + &self, + destination: &Destination, + options: &Map, + authenticator: &mut dyn Authenticator, + ) -> io::Result; +} + +#[async_trait] +impl ConnectHandler for F +where + F: Fn(&Destination, &Map, &mut dyn Authenticator) -> R + Send + Sync + 'static, + R: Future> + Send + 'static, +{ + async fn connect( + &self, + destination: &Destination, + options: &Map, + authenticator: &mut dyn Authenticator, + ) -> io::Result { + self(destination, options, authenticator).await + } +} + +/// Generates a new [`ConnectHandler`] for the provided anonymous function in the form of +/// +/// ``` +/// use distant_net::boxed_connect_handler; +/// +/// let _handler = boxed_connect_handler!(|destination, options, authenticator| { +/// todo!("Implement handler logic."); +/// }); +/// +/// let _handler = boxed_connect_handler!(|destination, options, authenticator| async { +/// todo!("We support async within as well regardless of the keyword!"); +/// }); +/// +/// let _handler = boxed_connect_handler!(move |destination, options, authenticator| { +/// todo!("You can also explicitly mark to move into the closure"); +/// }); +/// ``` +#[macro_export] +macro_rules! 
boxed_connect_handler { + (|$destination:ident, $options:ident, $authenticator:ident| $(async)? $body:block) => {{ + let x: $crate::manager::BoxedConnectHandler = Box::new( + |$destination: &$crate::common::Destination, + $options: &$crate::common::Map, + $authenticator: &mut dyn $crate::common::authentication::Authenticator| async { + $body + }, + ); + x + }}; + (move |$destination:ident, $options:ident, $authenticator:ident| $(async)? $body:block) => {{ + let x: $crate::manager::BoxedConnectHandler = Box::new( + move |$destination: &$crate::common::Destination, + $options: &$crate::common::Map, + $authenticator: &mut dyn $crate::common::authentication::Authenticator| async move { + $body + }, + ); + x + }}; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::FramedTransport; + use test_log::test; + + #[inline] + fn test_destination() -> Destination { + "scheme://host:1234".parse().unwrap() + } + + #[inline] + fn test_options() -> Map { + Map::default() + } + + #[inline] + fn test_authenticator() -> impl Authenticator { + FramedTransport::pair(1).0 + } + + #[test(tokio::test)] + async fn boxed_launch_handler_should_generate_valid_boxed_launch_handler() { + let handler = boxed_launch_handler!(|_destination, _options, _authenticator| { + Err(io::Error::from(io::ErrorKind::Other)) + }); + assert_eq!( + handler + .launch( + &test_destination(), + &test_options(), + &mut test_authenticator() + ) + .await + .unwrap_err() + .kind(), + io::ErrorKind::Other + ); + + let handler = boxed_launch_handler!(|_destination, _options, _authenticator| async { + Err(io::Error::from(io::ErrorKind::Other)) + }); + assert_eq!( + handler + .launch( + &test_destination(), + &test_options(), + &mut test_authenticator() + ) + .await + .unwrap_err() + .kind(), + io::ErrorKind::Other + ); + + let handler = boxed_launch_handler!(move |_destination, _options, _authenticator| { + Err(io::Error::from(io::ErrorKind::Other)) + }); + assert_eq!( + handler + .launch( + 
&test_destination(), + &test_options(), + &mut test_authenticator() + ) + .await + .unwrap_err() + .kind(), + io::ErrorKind::Other + ); + + let handler = boxed_launch_handler!(move |_destination, _options, _authenticator| async { + Err(io::Error::from(io::ErrorKind::Other)) + }); + assert_eq!( + handler + .launch( + &test_destination(), + &test_options(), + &mut test_authenticator() + ) + .await + .unwrap_err() + .kind(), + io::ErrorKind::Other + ); + } + + #[test(tokio::test)] + async fn boxed_connect_handler_should_generate_valid_boxed_connect_handler() { + let handler = boxed_connect_handler!(|_destination, _options, _authenticator| { + Err(io::Error::from(io::ErrorKind::Other)) + }); + assert_eq!( + handler + .connect( + &test_destination(), + &test_options(), + &mut test_authenticator() + ) + .await + .unwrap_err() + .kind(), + io::ErrorKind::Other + ); + + let handler = boxed_connect_handler!(|_destination, _options, _authenticator| async { + Err(io::Error::from(io::ErrorKind::Other)) + }); + assert_eq!( + handler + .connect( + &test_destination(), + &test_options(), + &mut test_authenticator() + ) + .await + .unwrap_err() + .kind(), + io::ErrorKind::Other + ); + + let handler = boxed_connect_handler!(move |_destination, _options, _authenticator| { + Err(io::Error::from(io::ErrorKind::Other)) + }); + assert_eq!( + handler + .connect( + &test_destination(), + &test_options(), + &mut test_authenticator() + ) + .await + .unwrap_err() + .kind(), + io::ErrorKind::Other + ); + + let handler = boxed_connect_handler!(move |_destination, _options, _authenticator| async { + Err(io::Error::from(io::ErrorKind::Other)) + }); + assert_eq!( + handler + .connect( + &test_destination(), + &test_options(), + &mut test_authenticator() + ) + .await + .unwrap_err() + .kind(), + io::ErrorKind::Other + ); + } +} diff --git a/distant-net/src/packet.rs b/distant-net/src/packet.rs deleted file mode 100644 index 0c26f7a..0000000 --- a/distant-net/src/packet.rs +++ /dev/null @@ -1,254 
+0,0 @@ -/// Represents a generic id type -pub type Id = String; - -mod request; -mod response; - -pub use request::*; -pub use response::*; - -#[derive(Clone, Debug, PartialEq, Eq)] -enum MsgPackStrParseError { - InvalidFormat, - Utf8Error(std::str::Utf8Error), -} - -/// Parse msgpack str, returning remaining bytes and str on success, or error on failure -fn parse_msg_pack_str(input: &[u8]) -> Result<(&[u8], &str), MsgPackStrParseError> { - let ilen = input.len(); - if ilen == 0 { - return Err(MsgPackStrParseError::InvalidFormat); - } - - // * fixstr using 0xa0 - 0xbf to mark the start of the str where < 32 bytes - // * str 8 (0xd9) if up to (2^8)-1 bytes, using next byte for len - // * str 16 (0xda) if up to (2^16)-1 bytes, using next two bytes for len - // * str 32 (0xdb) if up to (2^32)-1 bytes, using next four bytes for len - let (input, len): (&[u8], usize) = if input[0] >= 0xa0 && input[0] <= 0xbf { - (&input[1..], (input[0] & 0b00011111).into()) - } else if input[0] == 0xd9 && ilen > 2 { - (&input[2..], input[1].into()) - } else if input[0] == 0xda && ilen > 3 { - (&input[3..], u16::from_be_bytes([input[1], input[2]]).into()) - } else if input[0] == 0xdb && ilen > 5 { - ( - &input[5..], - u32::from_be_bytes([input[1], input[2], input[3], input[4]]) - .try_into() - .unwrap(), - ) - } else { - return Err(MsgPackStrParseError::InvalidFormat); - }; - - let s = match std::str::from_utf8(&input[..len]) { - Ok(s) => s, - Err(x) => return Err(MsgPackStrParseError::Utf8Error(x)), - }; - - Ok((&input[len..], s)) -} - -#[cfg(test)] -mod tests { - use super::*; - - mod parse_msg_pack_str { - use super::*; - - #[test] - fn should_be_able_to_parse_fixstr() { - // Empty str - let (input, s) = parse_msg_pack_str(&[0xa0]).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, ""); - - // Single character - let (input, s) = parse_msg_pack_str(&[0xa1, b'a']).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, "a"); - - // 31 byte str - let (input, s) = 
parse_msg_pack_str(&[ - 0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', - ]) - .unwrap(); - assert!(input.is_empty()); - assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); - - // Verify that we only consume up to fixstr length - assert_eq!(parse_msg_pack_str(&[0xa0, b'a']).unwrap().0, b"a"); - assert_eq!( - parse_msg_pack_str(&[ - 0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', b'b' - ]) - .unwrap() - .0, - b"b" - ); - } - - #[test] - fn should_be_able_to_parse_str_8() { - // 32 byte str - let (input, s) = parse_msg_pack_str(&[ - 0xd9, 32, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', - b'a', b'a', b'a', b'a', b'a', b'a', - ]) - .unwrap(); - assert!(input.is_empty()); - assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); - - // 2^8 - 1 (255) byte str - let test_str = "a".repeat(2usize.pow(8) - 1); - let mut input = vec![0xd9, 255]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, test_str); - - // Verify that we only consume up to 2^8 - 1 length - let mut input = vec![0xd9, 255]; - input.extend_from_slice(test_str.as_bytes()); - input.extend_from_slice(b"hello"); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert_eq!(input, b"hello"); - assert_eq!(s, test_str); - } - - #[test] - fn should_be_able_to_parse_str_16() { - // 2^8 byte str (256) - let test_str = "a".repeat(2usize.pow(8)); - let mut input = vec![0xda, 1, 0]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - 
assert_eq!(s, test_str); - - // 2^16 - 1 (65535) byte str - let test_str = "a".repeat(2usize.pow(16) - 1); - let mut input = vec![0xda, 255, 255]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, test_str); - - // Verify that we only consume up to 2^16 - 1 length - let mut input = vec![0xda, 255, 255]; - input.extend_from_slice(test_str.as_bytes()); - input.extend_from_slice(b"hello"); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert_eq!(input, b"hello"); - assert_eq!(s, test_str); - } - - #[test] - fn should_be_able_to_parse_str_32() { - // 2^16 byte str - let test_str = "a".repeat(2usize.pow(16)); - let mut input = vec![0xdb, 0, 1, 0, 0]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, test_str); - - // NOTE: We are not going to run the below tests, not because they aren't valid but - // because this generates a 4GB str which takes 20+ seconds to run - - // 2^32 - 1 byte str (4294967295 bytes) - /* let test_str = "a".repeat(2usize.pow(32) - 1); - let mut input = vec![0xdb, 255, 255, 255, 255]; - input.extend_from_slice(test_str.as_bytes()); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert!(input.is_empty()); - assert_eq!(s, test_str); */ - - // Verify that we only consume up to 2^32 - 1 length - /* let mut input = vec![0xdb, 255, 255, 255, 255]; - input.extend_from_slice(test_str.as_bytes()); - input.extend_from_slice(b"hello"); - let (input, s) = parse_msg_pack_str(&input).unwrap(); - assert_eq!(input, b"hello"); - assert_eq!(s, test_str); */ - } - - #[test] - fn should_fail_parsing_str_with_invalid_length() { - // Make sure that parse doesn't fail looking for bytes after str 8 len - assert_eq!( - parse_msg_pack_str(&[0xd9]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xd9, 0]), - 
Err(MsgPackStrParseError::InvalidFormat) - ); - - // Make sure that parse doesn't fail looking for bytes after str 16 len - assert_eq!( - parse_msg_pack_str(&[0xda]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xda, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xda, 0, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - - // Make sure that parse doesn't fail looking for bytes after str 32 len - assert_eq!( - parse_msg_pack_str(&[0xdb]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xdb, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xdb, 0, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xdb, 0, 0, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - assert_eq!( - parse_msg_pack_str(&[0xdb, 0, 0, 0, 0]), - Err(MsgPackStrParseError::InvalidFormat) - ); - } - - #[test] - fn should_fail_parsing_other_types() { - assert_eq!( - parse_msg_pack_str(&[0xc3]), // Boolean (true) - Err(MsgPackStrParseError::InvalidFormat) - ); - } - - #[test] - fn should_fail_if_empty_input() { - assert_eq!( - parse_msg_pack_str(&[]), - Err(MsgPackStrParseError::InvalidFormat) - ); - } - - #[test] - fn should_fail_if_str_is_not_utf8() { - assert!(matches!( - parse_msg_pack_str(&[0xa4, 0, 159, 146, 150]), - Err(MsgPackStrParseError::Utf8Error(_)) - )); - } - } -} diff --git a/distant-net/src/server.rs b/distant-net/src/server.rs index 9626733..ea5fdae 100644 --- a/distant-net/src/server.rs +++ b/distant-net/src/server.rs @@ -1,18 +1,22 @@ +use crate::common::{authentication::Verifier, Listener, Transport}; use async_trait::async_trait; +use log::*; use serde::{de::DeserializeOwned, Serialize}; +use std::{io, sync::Arc, time::Duration}; +use tokio::sync::RwLock; + +mod builder; +pub use builder::*; mod config; pub use config::*; mod connection; -pub use connection::*; +use 
connection::*; mod context; pub use context::*; -mod ext; -pub use ext::*; - mod r#ref; pub use r#ref::*; @@ -20,38 +24,417 @@ mod reply; pub use reply::*; mod state; -pub use state::*; +use state::*; + +mod shutdown_timer; +use shutdown_timer::*; + +/// Represents a server that can be used to receive requests & send responses to clients. +pub struct Server { + /// Custom configuration details associated with the server + config: ServerConfig, + + /// Handler used to process various server events + handler: T, + + /// Performs authentication using various methods + verifier: Verifier, +} -/// Interface for a general-purpose server that receives requests to handle +/// Interface for a handler that receives connections and requests #[async_trait] -pub trait Server: Send { +pub trait ServerHandler: Send { /// Type of data received by the server - type Request: DeserializeOwned + Send + Sync; + type Request; /// Type of data sent back by the server - type Response: Serialize + Send; + type Response; /// Type of data to store locally tied to the specific connection - type LocalData: Send + Sync; + type LocalData: Send; - /// Returns configuration tied to server instance - fn config(&self) -> ServerConfig { - ServerConfig::default() - } - - /// Invoked immediately on server start, being provided the raw listener to use (untyped - /// transport), and returning the listener when ready to start (enabling servers that need to - /// tweak a listener to do so) - /* async fn on_start(&mut self, listener: L) -> Box> { - } */ - - /// Invoked upon a new connection becoming established, which provides a mutable reference to - /// the data created for the connection. This can be useful in performing some additional - /// initialization on the data prior to it being used anywhere else. + /// Invoked upon a new connection becoming established. 
+ /// + /// ### Note + /// + /// This can be useful in performing some additional initialization on the connection's local + /// data prior to it being used anywhere else. #[allow(unused_variables)] - async fn on_accept(&self, local_data: &mut Self::LocalData) {} + async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { + Ok(()) + } /// Invoked upon receiving a request from a client. The server should process this /// request, which can be found in `ctx`, and send one or more replies in response. async fn on_request(&self, ctx: ServerCtx); } + +impl Server<()> { + /// Creates a new [`Server`], starting with a default configuration, no authentication methods, + /// and no [`ServerHandler`]. + pub fn new() -> Self { + Self { + config: Default::default(), + handler: (), + verifier: Verifier::empty(), + } + } + + /// Creates a new [`TcpServerBuilder`] that is used to construct a [`Server`]. + pub fn tcp() -> TcpServerBuilder<()> { + TcpServerBuilder::default() + } + + /// Creates a new [`UnixSocketServerBuilder`] that is used to construct a [`Server`]. + #[cfg(unix)] + pub fn unix_socket() -> UnixSocketServerBuilder<()> { + UnixSocketServerBuilder::default() + } + + /// Creates a new [`WindowsPipeServerBuilder`] that is used to construct a [`Server`]. + #[cfg(windows)] + pub fn windows_pipe() -> WindowsPipeServerBuilder<()> { + WindowsPipeServerBuilder::default() + } +} + +impl Default for Server<()> { + fn default() -> Self { + Self::new() + } +} + +impl Server { + /// Consumes the current server, replacing its config with `config` and returning it. + pub fn config(self, config: ServerConfig) -> Self { + Self { + config, + handler: self.handler, + verifier: self.verifier, + } + } + + /// Consumes the current server, replacing its handler with `handler` and returning it. 
+ pub fn handler(self, handler: U) -> Server { + Server { + config: self.config, + handler, + verifier: self.verifier, + } + } + + /// Consumes the current server, replacing its verifier with `verifier` and returning it. + pub fn verifier(self, verifier: Verifier) -> Self { + Self { + config: self.config, + handler: self.handler, + verifier, + } + } +} + +impl Server +where + T: ServerHandler + Sync + 'static, + T::Request: DeserializeOwned + Send + Sync + 'static, + T::Response: Serialize + Send + 'static, + T::LocalData: Default + Send + Sync + 'static, +{ + /// Consumes the server, starting a task to process connections from the `listener` and + /// returning a [`ServerRef`] that can be used to control the active server instance. + pub fn start(self, listener: L) -> io::Result> + where + L: Listener + 'static, + L::Output: Transport + 'static, + { + let state = Arc::new(ServerState::new()); + let task = tokio::spawn(self.task(Arc::clone(&state), listener)); + + Ok(Box::new(GenericServerRef { state, task })) + } + + /// Internal task that is run to receive connections and spawn connection tasks + async fn task(self, state: Arc, mut listener: L) + where + L: Listener + 'static, + L::Output: Transport + 'static, + { + let Server { + config, + handler, + verifier, + } = self; + + let handler = Arc::new(handler); + let timer = ShutdownTimer::start(config.shutdown); + let mut notification = timer.clone_notification(); + let timer = Arc::new(RwLock::new(timer)); + let verifier = Arc::new(verifier); + + loop { + // Receive a new connection, exiting if no longer accepting connections or if the shutdown + // signal has been received + let transport = tokio::select! 
{ + result = listener.accept() => { + match result { + Ok(x) => x, + Err(x) => { + error!("Server no longer accepting connections: {x}"); + timer.read().await.abort(); + break; + } + } + } + _ = notification.wait() => { + info!( + "Server shutdown triggered after {}s", + config.shutdown.duration().unwrap_or_default().as_secs_f32(), + ); + + for (id, task) in state.connections.write().await.drain() { + info!("Terminating task {id}"); + task.abort(); + } + + break; + } + }; + + // Ensure that the shutdown timer is cancelled now that we have a connection + timer.read().await.stop(); + + let connection = ConnectionTask::build() + .handler(Arc::downgrade(&handler)) + .state(Arc::downgrade(&state)) + .keychain(state.keychain.clone()) + .transport(transport) + .shutdown_timer(Arc::downgrade(&timer)) + .sleep_duration(config.connection_sleep) + .verifier(Arc::downgrade(&verifier)) + .spawn(); + + state + .connections + .write() + .await + .insert(connection.id(), connection); + } + + // Once we stop listening, we still want to wait until all connections have terminated + info!("Server waiting for active connections to terminate"); + while state.has_active_connections().await { + tokio::time::sleep(Duration::from_millis(50)).await; + } + info!("Server task terminated"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::{ + authentication::{AuthenticationMethod, DummyAuthHandler, NoneAuthenticationMethod}, + Connection, InmemoryTransport, MpscListener, Request, Response, + }; + use async_trait::async_trait; + use std::time::Duration; + use test_log::test; + use tokio::sync::mpsc; + + pub struct TestServerHandler; + + #[async_trait] + impl ServerHandler for TestServerHandler { + type Request = u16; + type Response = String; + type LocalData = (); + + async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { + Ok(()) + } + + async fn on_request(&self, ctx: ServerCtx) { + // Always send back "hello" + 
ctx.reply.send("hello".to_string()).await.unwrap(); + } + } + + #[inline] + fn make_test_server(config: ServerConfig) -> Server { + let methods: Vec> = + vec![Box::new(NoneAuthenticationMethod::new())]; + + Server { + config, + handler: TestServerHandler, + verifier: Verifier::new(methods), + } + } + + #[allow(clippy::type_complexity)] + fn make_listener( + buffer: usize, + ) -> ( + mpsc::Sender, + MpscListener, + ) { + MpscListener::channel(buffer) + } + + #[test(tokio::test)] + async fn should_invoke_handler_upon_receiving_a_request() { + // Create a test listener where we will forward a connection + let (tx, listener) = make_listener(100); + + // Make bounded transport pair and send off one of them to act as our connection + let (transport, connection) = InmemoryTransport::pair(100); + tx.send(connection) + .await + .expect("Failed to feed listener a connection"); + + let _server = make_test_server(ServerConfig::default()) + .start(listener) + .expect("Failed to start server"); + + // Perform handshake and authentication with the server before beginning to send data + let mut connection = Connection::client(transport, DummyAuthHandler) + .await + .expect("Failed to connect to server"); + + connection + .write_frame(Request::new(123).to_vec().unwrap()) + .await + .expect("Failed to send request"); + + // Wait for a response + let frame = connection.read_frame().await.unwrap().unwrap(); + let response: Response = Response::from_slice(frame.as_item()).unwrap(); + assert_eq!(response.payload, "hello"); + } + + #[test(tokio::test)] + async fn should_lonely_shutdown_if_no_connections_received_after_n_secs_when_config_set() { + let (_tx, listener) = make_listener(100); + + let server = make_test_server(ServerConfig { + shutdown: Shutdown::Lonely(Duration::from_millis(100)), + ..Default::default() + }) + .start(listener) + .expect("Failed to start server"); + + // Wait for some time + tokio::time::sleep(Duration::from_millis(300)).await; + + 
assert!(server.is_finished(), "Server shutdown not triggered!"); + } + + #[test(tokio::test)] + async fn should_lonely_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs( + ) { + // Create a test listener where we will forward a connection + let (tx, listener) = make_listener(100); + + // Make bounded transport pair and send off one of them to act as our connection + let (transport, connection) = InmemoryTransport::pair(100); + tx.send(connection) + .await + .expect("Failed to feed listener a connection"); + + let server = make_test_server(ServerConfig { + shutdown: Shutdown::Lonely(Duration::from_millis(100)), + ..Default::default() + }) + .start(listener) + .expect("Failed to start server"); + + // Drop the connection by dropping the transport + drop(transport); + + // Wait for some time + tokio::time::sleep(Duration::from_millis(300)).await; + + assert!(server.is_finished(), "Server shutdown not triggered!"); + } + + #[test(tokio::test)] + async fn should_not_lonely_shutdown_as_long_as_a_connection_exists() { + // Create a test listener where we will forward a connection + let (tx, listener) = make_listener(100); + + // Make bounded transport pair and send off one of them to act as our connection + let (_transport, connection) = InmemoryTransport::pair(100); + tx.send(connection) + .await + .expect("Failed to feed listener a connection"); + + let server = make_test_server(ServerConfig { + shutdown: Shutdown::Lonely(Duration::from_millis(100)), + ..Default::default() + }) + .start(listener) + .expect("Failed to start server"); + + // Wait for some time + tokio::time::sleep(Duration::from_millis(300)).await; + + assert!(!server.is_finished(), "Server shutdown when it should not!"); + } + + #[test(tokio::test)] + async fn should_shutdown_after_n_seconds_even_with_connections_if_config_set_to_after() { + let (tx, listener) = make_listener(100); + + // Make bounded transport pair and send off one of them to act as our connection + let 
(_transport, connection) = InmemoryTransport::pair(100); + tx.send(connection) + .await + .expect("Failed to feed listener a connection"); + + let server = make_test_server(ServerConfig { + shutdown: Shutdown::After(Duration::from_millis(100)), + ..Default::default() + }) + .start(listener) + .expect("Failed to start server"); + + // Wait for some time + tokio::time::sleep(Duration::from_millis(300)).await; + + assert!(server.is_finished(), "Server shutdown not triggered!"); + } + + #[test(tokio::test)] + async fn should_shutdown_after_n_seconds_if_config_set_to_after() { + let (_tx, listener) = make_listener(100); + + let server = make_test_server(ServerConfig { + shutdown: Shutdown::After(Duration::from_millis(100)), + ..Default::default() + }) + .start(listener) + .expect("Failed to start server"); + + // Wait for some time + tokio::time::sleep(Duration::from_millis(300)).await; + + assert!(server.is_finished(), "Server shutdown not triggered!"); + } + + #[test(tokio::test)] + async fn should_never_shutdown_if_config_set_to_never() { + let (_tx, listener) = make_listener(100); + + let server = make_test_server(ServerConfig { + shutdown: Shutdown::Never, + ..Default::default() + }) + .start(listener) + .expect("Failed to start server"); + + // Wait for some time + tokio::time::sleep(Duration::from_millis(300)).await; + + assert!(!server.is_finished(), "Server shutdown when it should not!"); + } +} diff --git a/distant-net/src/client/ext.rs b/distant-net/src/server/builder.rs similarity index 99% rename from distant-net/src/client/ext.rs rename to distant-net/src/server/builder.rs index d23a3d2..a732474 100644 --- a/distant-net/src/client/ext.rs +++ b/distant-net/src/server/builder.rs @@ -1,14 +1,15 @@ mod tcp; -pub use tcp::*; #[cfg(unix)] mod unix; -#[cfg(unix)] -pub use unix::*; - #[cfg(windows)] mod windows; +pub use tcp::*; + +#[cfg(unix)] +pub use unix::*; + #[cfg(windows)] pub use windows::*; diff --git a/distant-net/src/server/builder/tcp.rs 
b/distant-net/src/server/builder/tcp.rs new file mode 100644 index 0000000..5feb2ce --- /dev/null +++ b/distant-net/src/server/builder/tcp.rs @@ -0,0 +1,102 @@ +use crate::common::{authentication::Verifier, PortRange, TcpListener}; +use crate::server::{Server, ServerConfig, ServerHandler, TcpServerRef}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{io, net::IpAddr}; + +pub struct TcpServerBuilder(Server); + +impl Server { + /// Consume [`Server`] and produce a builder for a TCP variant. + pub fn into_tcp_builder(self) -> TcpServerBuilder { + TcpServerBuilder(self) + } +} + +impl Default for TcpServerBuilder<()> { + fn default() -> Self { + Self(Server::new()) + } +} + +impl TcpServerBuilder { + pub fn config(self, config: ServerConfig) -> Self { + Self(self.0.config(config)) + } + + pub fn handler(self, handler: U) -> TcpServerBuilder { + TcpServerBuilder(self.0.handler(handler)) + } + + pub fn verifier(self, verifier: Verifier) -> Self { + Self(self.0.verifier(verifier)) + } +} + +impl TcpServerBuilder +where + T: ServerHandler + Sync + 'static, + T::Request: DeserializeOwned + Send + Sync + 'static, + T::Response: Serialize + Send + 'static, + T::LocalData: Default + Send + Sync + 'static, +{ + pub async fn start

(self, addr: IpAddr, port: P) -> io::Result + where + P: Into + Send, + { + let listener = TcpListener::bind(addr, port).await?; + let port = listener.port(); + let inner = self.0.start(listener)?; + Ok(TcpServerRef { addr, port, inner }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::Client; + use crate::common::{authentication::DummyAuthHandler, Request}; + use crate::server::ServerCtx; + use async_trait::async_trait; + use std::net::{Ipv6Addr, SocketAddr}; + use test_log::test; + + pub struct TestServerHandler; + + #[async_trait] + impl ServerHandler for TestServerHandler { + type Request = String; + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + // Echo back what we received + ctx.reply + .send(ctx.request.payload.to_string()) + .await + .unwrap(); + } + } + + #[test(tokio::test)] + async fn should_invoke_handler_upon_receiving_a_request() { + let server = TcpServerBuilder::default() + .handler(TestServerHandler) + .verifier(Verifier::none()) + .start(IpAddr::V6(Ipv6Addr::LOCALHOST), 0) + .await + .expect("Failed to start TCP server"); + + let mut client: Client = + Client::tcp(SocketAddr::from((server.ip_addr(), server.port()))) + .auth_handler(DummyAuthHandler) + .connect() + .await + .expect("Client failed to connect"); + + let response = client + .send(Request::new("hello".to_string())) + .await + .expect("Failed to send message"); + assert_eq!(response.payload, "hello"); + } +} diff --git a/distant-net/src/server/builder/unix.rs b/distant-net/src/server/builder/unix.rs new file mode 100644 index 0000000..94fd95e --- /dev/null +++ b/distant-net/src/server/builder/unix.rs @@ -0,0 +1,108 @@ +use crate::common::{authentication::Verifier, UnixSocketListener}; +use crate::server::{Server, ServerConfig, ServerHandler, UnixSocketServerRef}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{io, path::Path}; + +pub struct UnixSocketServerBuilder(Server); + +impl Server { + /// 
Consume [`Server`] and produce a builder for a Unix socket variant. + pub fn into_unix_socket_builder(self) -> UnixSocketServerBuilder { + UnixSocketServerBuilder(self) + } +} + +impl Default for UnixSocketServerBuilder<()> { + fn default() -> Self { + Self(Server::new()) + } +} + +impl UnixSocketServerBuilder { + pub fn config(self, config: ServerConfig) -> Self { + Self(self.0.config(config)) + } + + pub fn handler(self, handler: U) -> UnixSocketServerBuilder { + UnixSocketServerBuilder(self.0.handler(handler)) + } + + pub fn verifier(self, verifier: Verifier) -> Self { + Self(self.0.verifier(verifier)) + } +} + +impl UnixSocketServerBuilder +where + T: ServerHandler + Sync + 'static, + T::Request: DeserializeOwned + Send + Sync + 'static, + T::Response: Serialize + Send + 'static, + T::LocalData: Default + Send + Sync + 'static, +{ + pub async fn start

(self, path: P) -> io::Result + where + P: AsRef + Send, + { + let path = path.as_ref(); + let listener = UnixSocketListener::bind(path).await?; + let path = listener.path().to_path_buf(); + let inner = self.0.start(listener)?; + Ok(UnixSocketServerRef { path, inner }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::Client; + use crate::common::{authentication::DummyAuthHandler, Request}; + use crate::server::ServerCtx; + use async_trait::async_trait; + use tempfile::NamedTempFile; + use test_log::test; + + pub struct TestServerHandler; + + #[async_trait] + impl ServerHandler for TestServerHandler { + type Request = String; + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + // Echo back what we received + ctx.reply + .send(ctx.request.payload.to_string()) + .await + .unwrap(); + } + } + + #[test(tokio::test)] + async fn should_invoke_handler_upon_receiving_a_request() { + // Generate a socket path and delete the file after so there is nothing there + let path = NamedTempFile::new() + .expect("Failed to create socket file") + .path() + .to_path_buf(); + + let server = UnixSocketServerBuilder::default() + .handler(TestServerHandler) + .verifier(Verifier::none()) + .start(path) + .await + .expect("Failed to start Unix socket server"); + + let mut client: Client = Client::unix_socket(server.path()) + .auth_handler(DummyAuthHandler) + .connect() + .await + .expect("Client failed to connect"); + + let response = client + .send(Request::new("hello".to_string())) + .await + .expect("Failed to send message"); + assert_eq!(response.payload, "hello"); + } +} diff --git a/distant-net/src/server/builder/windows.rs b/distant-net/src/server/builder/windows.rs new file mode 100644 index 0000000..96e603a --- /dev/null +++ b/distant-net/src/server/builder/windows.rs @@ -0,0 +1,116 @@ +use crate::common::{authentication::Verifier, WindowsPipeListener}; +use crate::server::{Server, ServerConfig, ServerHandler, 
WindowsPipeServerRef}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + ffi::{OsStr, OsString}, + io, +}; + +pub struct WindowsPipeServerBuilder(Server); + +impl Server { + /// Consume [`Server`] and produce a builder for a Windows pipe variant. + pub fn into_windows_pipe_builder(self) -> WindowsPipeServerBuilder { + WindowsPipeServerBuilder(self) + } +} + +impl Default for WindowsPipeServerBuilder<()> { + fn default() -> Self { + Self(Server::new()) + } +} + +impl WindowsPipeServerBuilder { + pub fn config(self, config: ServerConfig) -> Self { + Self(self.0.config(config)) + } + + pub fn handler(self, handler: U) -> WindowsPipeServerBuilder { + WindowsPipeServerBuilder(self.0.handler(handler)) + } + + pub fn verifier(self, verifier: Verifier) -> Self { + Self(self.0.verifier(verifier)) + } +} + +impl WindowsPipeServerBuilder +where + T: ServerHandler + Sync + 'static, + T::Request: DeserializeOwned + Send + Sync + 'static, + T::Response: Serialize + Send + 'static, + T::LocalData: Default + Send + Sync + 'static, +{ + /// Start a new server at the specified address using the given codec + pub async fn start(self, addr: A) -> io::Result + where + A: AsRef + Send, + { + let a = addr.as_ref(); + let listener = WindowsPipeListener::bind(a)?; + let addr = listener.addr().to_os_string(); + let inner = self.0.start(listener)?; + Ok(WindowsPipeServerRef { addr, inner }) + } + + /// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec + pub async fn start_local(self, name: N) -> io::Result + where + Self: Sized, + N: AsRef + Send, + { + let mut addr = OsString::from(r"\\.\pipe\"); + addr.push(name.as_ref()); + self.start(addr).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::Client; + use crate::common::{authentication::DummyAuthHandler, Request}; + use crate::server::ServerCtx; + use async_trait::async_trait; + use test_log::test; + + pub struct TestServerHandler; + + #[async_trait] + impl 
ServerHandler for TestServerHandler { + type Request = String; + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + // Echo back what we received + ctx.reply + .send(ctx.request.payload.to_string()) + .await + .unwrap(); + } + } + + #[test(tokio::test)] + async fn should_invoke_handler_upon_receiving_a_request() { + let server = WindowsPipeServerBuilder::default() + .handler(TestServerHandler) + .verifier(Verifier::none()) + .start_local(format!("test_pipe_{}", rand::random::())) + .await + .expect("Failed to start Windows pipe server"); + + let mut client: Client = Client::windows_pipe(server.addr()) + .auth_handler(DummyAuthHandler) + .connect() + .await + .expect("Client failed to connect"); + + let response = client + .send(Request::new("hello".to_string())) + .await + .expect("Failed to send message"); + assert_eq!(response.payload, "hello"); + } +} diff --git a/distant-net/src/server/config.rs b/distant-net/src/server/config.rs index 6bc6938..5bdb098 100644 --- a/distant-net/src/server/config.rs +++ b/distant-net/src/server/config.rs @@ -2,13 +2,27 @@ use derive_more::{Display, Error}; use serde::{Deserialize, Serialize}; use std::{num::ParseFloatError, str::FromStr, time::Duration}; +const DEFAULT_CONNECTION_SLEEP: Duration = Duration::from_millis(1); + /// Represents a general-purpose set of properties tied with a server instance -#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct ServerConfig { + /// Time to wait inbetween connection read/write when nothing was read or written on last pass + pub connection_sleep: Duration, + /// Rules for how a server will shutdown automatically pub shutdown: Shutdown, } +impl Default for ServerConfig { + fn default() -> Self { + Self { + connection_sleep: DEFAULT_CONNECTION_SLEEP, + shutdown: Default::default(), + } + } +} + /// Rules for how a server will shut itself down 
automatically #[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] pub enum Shutdown { diff --git a/distant-net/src/server/connection.rs b/distant-net/src/server/connection.rs index f4e17cf..c5e3c11 100644 --- a/distant-net/src/server/connection.rs +++ b/distant-net/src/server/connection.rs @@ -1,51 +1,755 @@ -use crate::ConnectionId; -use tokio::task::JoinHandle; +use super::{ConnectionCtx, ServerCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer}; +use crate::common::{ + authentication::{Keychain, Verifier}, + Backup, Connection, ConnectionId, Interest, Response, Transport, UntypedRequest, +}; +use log::*; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + future::Future, + io, + pin::Pin, + sync::{Arc, Weak}, + task::{Context, Poll}, + time::Duration, +}; +use tokio::{ + sync::{mpsc, oneshot, RwLock}, + task::JoinHandle, +}; + +pub type ServerKeychain = Keychain>; + +/// Time to wait inbetween connection read/write when nothing was read or written on last pass +const SLEEP_DURATION: Duration = Duration::from_millis(1); /// Represents an individual connection on the server -pub struct ServerConnection { +pub struct ConnectionTask { /// Unique identifier tied to the connection - pub id: ConnectionId, + id: ConnectionId, + + /// Task that is processing requests and responses + task: JoinHandle>, +} + +impl ConnectionTask { + /// Starts building a new connection + pub fn build() -> ConnectionTaskBuilder<(), ()> { + let id: ConnectionId = rand::random(); + ConnectionTaskBuilder { + id, + handler: Weak::new(), + state: Weak::new(), + keychain: Keychain::new(), + transport: (), + shutdown_timer: Weak::new(), + sleep_duration: SLEEP_DURATION, + verifier: Weak::new(), + } + } + + /// Returns the id associated with the connection + pub fn id(&self) -> ConnectionId { + self.id + } - /// Task that is processing incoming requests from the connection - pub(crate) reader_task: Option>, + /// Returns true if the task has finished + 
pub fn is_finished(&self) -> bool { + self.task.is_finished() + } - /// Task that is processing outgoing responses to the connection - pub(crate) writer_task: Option>, + /// Aborts the connection + pub fn abort(&self) { + self.task.abort(); + } } -impl Default for ServerConnection { - fn default() -> Self { - Self::new() +impl Future for ConnectionTask { + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Future::poll(Pin::new(&mut self.task), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + Ok(x) => Poll::Ready(x), + Err(x) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, x))), + }, + } } } -impl ServerConnection { - /// Creates a new connection, generating a unique id to represent the connection - pub fn new() -> Self { - Self { - id: rand::random(), - reader_task: None, - writer_task: None, +pub struct ConnectionTaskBuilder { + id: ConnectionId, + handler: Weak, + state: Weak, + keychain: Keychain>, + transport: T, + shutdown_timer: Weak>, + sleep_duration: Duration, + verifier: Weak, +} + +impl ConnectionTaskBuilder { + pub fn handler(self, handler: Weak) -> ConnectionTaskBuilder { + ConnectionTaskBuilder { + id: self.id, + handler, + state: self.state, + keychain: self.keychain, + transport: self.transport, + shutdown_timer: self.shutdown_timer, + sleep_duration: self.sleep_duration, + verifier: self.verifier, } } - /// Returns true if connection is still processing incoming or outgoing messages - pub fn is_active(&self) -> bool { - let reader_active = - self.reader_task.is_some() && !self.reader_task.as_ref().unwrap().is_finished(); - let writer_active = - self.writer_task.is_some() && !self.writer_task.as_ref().unwrap().is_finished(); - reader_active || writer_active + pub fn state(self, state: Weak) -> ConnectionTaskBuilder { + ConnectionTaskBuilder { + id: self.id, + handler: self.handler, + state, + keychain: self.keychain, + transport: self.transport, + 
shutdown_timer: self.shutdown_timer, + sleep_duration: self.sleep_duration, + verifier: self.verifier, + } } - /// Aborts the connection - pub fn abort(&self) { - if let Some(task) = self.reader_task.as_ref() { - task.abort(); + pub fn keychain(self, keychain: ServerKeychain) -> ConnectionTaskBuilder { + ConnectionTaskBuilder { + id: self.id, + handler: self.handler, + state: self.state, + keychain, + transport: self.transport, + shutdown_timer: self.shutdown_timer, + sleep_duration: self.sleep_duration, + verifier: self.verifier, + } + } + + pub fn transport(self, transport: U) -> ConnectionTaskBuilder { + ConnectionTaskBuilder { + id: self.id, + handler: self.handler, + keychain: self.keychain, + state: self.state, + transport, + shutdown_timer: self.shutdown_timer, + sleep_duration: self.sleep_duration, + verifier: self.verifier, + } + } + + pub(crate) fn shutdown_timer( + self, + shutdown_timer: Weak>, + ) -> ConnectionTaskBuilder { + ConnectionTaskBuilder { + id: self.id, + handler: self.handler, + state: self.state, + keychain: self.keychain, + transport: self.transport, + shutdown_timer, + sleep_duration: self.sleep_duration, + verifier: self.verifier, + } + } + + pub fn sleep_duration(self, sleep_duration: Duration) -> ConnectionTaskBuilder { + ConnectionTaskBuilder { + id: self.id, + handler: self.handler, + state: self.state, + keychain: self.keychain, + transport: self.transport, + shutdown_timer: self.shutdown_timer, + sleep_duration, + verifier: self.verifier, } + } - if let Some(task) = self.writer_task.as_ref() { - task.abort(); + pub fn verifier(self, verifier: Weak) -> ConnectionTaskBuilder { + ConnectionTaskBuilder { + id: self.id, + handler: self.handler, + state: self.state, + keychain: self.keychain, + transport: self.transport, + shutdown_timer: self.shutdown_timer, + sleep_duration: self.sleep_duration, + verifier, } } } + +impl ConnectionTaskBuilder +where + H: ServerHandler + Sync + 'static, + H::Request: DeserializeOwned + Send + Sync + 
'static, + H::Response: Serialize + Send + 'static, + H::LocalData: Default + Send + Sync + 'static, + T: Transport + 'static, +{ + pub fn spawn(self) -> ConnectionTask { + let id = self.id; + + ConnectionTask { + id, + task: tokio::spawn(self.run()), + } + } + + async fn run(self) -> io::Result<()> { + let ConnectionTaskBuilder { + id, + handler, + state, + keychain, + transport, + shutdown_timer, + sleep_duration, + verifier, + } = self; + + // Will check if no more connections and restart timer if that's the case + macro_rules! terminate_connection { + // Prints an error message before terminating the connection by panicking + (@error $($msg:tt)+) => { + error!($($msg)+); + terminate_connection!(); + return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+))); + }; + + // Prints a debug message before terminating the connection by cleanly returning + (@debug $($msg:tt)+) => { + debug!($($msg)+); + terminate_connection!(); + return Ok(()); + }; + + // Performs the connection termination by removing it from server state and + // restarting the shutdown timer if it was the last connection + () => { + // Remove the connection from our state if it has closed + if let Some(state) = Weak::upgrade(&state) { + state.connections.write().await.remove(&self.id); + + // If we have no more connections, start the timer + if let Some(timer) = Weak::upgrade(&shutdown_timer) { + if state.connections.read().await.is_empty() { + timer.write().await.restart(); + } + } + } + }; + } + + // Properly establish the connection's transport + debug!("[Conn {id}] Establishing full connection"); + let mut connection = match Weak::upgrade(&verifier) { + Some(verifier) => { + match Connection::server(transport, verifier.as_ref(), keychain).await { + Ok(connection) => connection, + Err(x) => { + terminate_connection!(@error "[Conn {id}] Failed to setup connection: {x}"); + } + } + } + None => { + terminate_connection!(@error "[Conn {id}] Verifier has been dropped"); + } + }; + + // 
Attempt to upgrade our handler for use with the connection going forward + debug!("[Conn {id}] Preparing connection handler"); + let handler = match Weak::upgrade(&handler) { + Some(handler) => handler, + None => { + terminate_connection!(@error "[Conn {id}] Handler has been dropped"); + } + }; + + // Construct a queue of outgoing responses + let (tx, mut rx) = mpsc::channel::>(1); + + // Create local data for the connection and then process it + debug!("[Conn {id}] Officially accepting connection"); + let mut local_data = H::LocalData::default(); + if let Err(x) = handler + .on_accept(ConnectionCtx { + connection_id: id, + local_data: &mut local_data, + }) + .await + { + terminate_connection!(@error "[Conn {id}] Accepting connection failed: {x}"); + } + + let local_data = Arc::new(local_data); + + debug!("[Conn {id}] Beginning read/write loop"); + loop { + let ready = match connection + .ready(Interest::READABLE | Interest::WRITABLE) + .await + { + Ok(ready) => ready, + Err(x) => { + terminate_connection!(@error "[Conn {id}] Failed to examine ready state: {x}"); + } + }; + + // Keep track of whether we read or wrote anything + let mut read_blocked = !ready.is_readable(); + let mut write_blocked = !ready.is_writable(); + + if ready.is_readable() { + match connection.try_read_frame() { + Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) { + Ok(request) => match request.to_typed_request() { + Ok(request) => { + let reply = ServerReply { + origin_id: request.id.clone(), + tx: tx.clone(), + }; + + let ctx = ServerCtx { + connection_id: id, + request, + reply: reply.clone(), + local_data: Arc::clone(&local_data), + }; + + // Spawn a new task to run the request handler so we don't block + // our connection from processing other requests + let handler = Arc::clone(&handler); + tokio::spawn(async move { handler.on_request(ctx).await }); + } + Err(x) => { + if log::log_enabled!(Level::Trace) { + trace!( + "[Conn {id}] Failed receiving {}", + 
String::from_utf8_lossy(&request.payload), + ); + } + + error!("[Conn {id}] Invalid request: {x}"); + } + }, + Err(x) => { + error!("[Conn {id}] Invalid request payload: {x}"); + } + }, + Ok(None) => { + terminate_connection!(@debug "[Conn {id}] Connection closed"); + } + Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true, + Err(x) => { + // NOTE: We do NOT break out of the loop, as this could happen + // if someone sends bad data at any point, but does not + // mean that the reader itself has failed. This can + // happen from getting non-compliant typed data + error!("[Conn {id}] {x}"); + } + } + } + + // If our socket is ready to be written to, we try to get the next item from + // the queue and process it + if ready.is_writable() { + // If we get more data to write, attempt to write it, which will result in writing + // any queued bytes as well. Othewise, we attempt to flush any pending outgoing + // bytes that weren't sent earlier. + if let Ok(response) = rx.try_recv() { + // Log our message as a string, which can be expensive + if log_enabled!(Level::Trace) { + trace!( + "[Conn {id}] Sending {}", + &response + .to_vec() + .map(|x| String::from_utf8_lossy(&x).to_string()) + .unwrap_or_else(|_| "".to_string()) + ); + } + + match response.to_vec() { + Ok(data) => match connection.try_write_frame(data) { + Ok(()) => (), + Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true, + Err(x) => error!("[Conn {id}] Send failed: {x}"), + }, + Err(x) => { + error!("[Conn {id}] Unable to serialize outgoing response: {x}"); + } + } + } else { + // In the case of flushing, there are two scenarios in which we want to + // mark no write occurring: + // + // 1. When flush did not write any bytes, which can happen when the buffer + // is empty + // 2. 
When the call to write bytes blocks + match connection.try_flush() { + Ok(0) => write_blocked = true, + Ok(_) => (), + Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true, + Err(x) => { + error!("[Conn {id}] Failed to flush outgoing data: {x}"); + } + } + } + } + + // If we did not read or write anything, sleep a bit to offload CPU usage + if read_blocked && write_blocked { + tokio::time::sleep(sleep_duration).await; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::authentication::DummyAuthHandler; + use crate::common::{ + HeapSecretKey, InmemoryTransport, Ready, Reconnectable, Request, Response, + }; + use crate::server::Shutdown; + use async_trait::async_trait; + use std::sync::atomic::{AtomicBool, Ordering}; + use test_log::test; + + struct TestServerHandler; + + #[async_trait] + impl ServerHandler for TestServerHandler { + type Request = u16; + type Response = String; + type LocalData = (); + + async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { + Ok(()) + } + + async fn on_request(&self, ctx: ServerCtx) { + // Always send back "hello" + ctx.reply.send("hello".to_string()).await.unwrap(); + } + } + + macro_rules! 
wait_for_termination { + ($task:ident) => {{ + let timeout_millis = 500; + let sleep_millis = 50; + let start = std::time::Instant::now(); + while !$task.is_finished() { + if start.elapsed() > std::time::Duration::from_millis(timeout_millis) { + panic!("Exceeded timeout of {timeout_millis}ms"); + } + tokio::time::sleep(std::time::Duration::from_millis(sleep_millis)).await; + } + }}; + } + + #[test(tokio::test)] + async fn should_terminate_if_fails_access_verifier() { + let handler = Arc::new(TestServerHandler); + let state = Arc::new(ServerState::default()); + let keychain = ServerKeychain::new(); + let (t1, _t2) = InmemoryTransport::pair(100); + let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never))); + + let task = ConnectionTask::build() + .handler(Arc::downgrade(&handler)) + .state(Arc::downgrade(&state)) + .keychain(keychain) + .transport(t1) + .shutdown_timer(Arc::downgrade(&shutdown_timer)) + .verifier(Weak::new()) + .spawn(); + + wait_for_termination!(task); + + let err = task.await.unwrap_err(); + assert!( + err.to_string().contains("Verifier has been dropped"), + "Unexpected error: {err}" + ); + } + + #[test(tokio::test)] + async fn should_terminate_if_fails_to_setup_server_connection() { + let handler = Arc::new(TestServerHandler); + let state = Arc::new(ServerState::default()); + let keychain = ServerKeychain::new(); + let (t1, t2) = InmemoryTransport::pair(100); + let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never))); + + // Create a verifier that wants a key, so we will fail from client-side + let verifier = Arc::new(Verifier::static_key(HeapSecretKey::generate(32).unwrap())); + + let task = ConnectionTask::build() + .handler(Arc::downgrade(&handler)) + .state(Arc::downgrade(&state)) + .keychain(keychain) + .transport(t1) + .shutdown_timer(Arc::downgrade(&shutdown_timer)) + .verifier(Arc::downgrade(&verifier)) + .spawn(); + + // Spawn a task to handle establishing connection from client-side + 
tokio::spawn(async move { + let _client = Connection::client(t2, DummyAuthHandler) + .await + .expect("Fail to establish client-side connection"); + }); + + wait_for_termination!(task); + + let err = task.await.unwrap_err(); + assert!( + err.to_string().contains("Failed to setup connection"), + "Unexpected error: {err}" + ); + } + + #[test(tokio::test)] + async fn should_terminate_if_fails_access_server_handler() { + let state = Arc::new(ServerState::default()); + let keychain = ServerKeychain::new(); + let (t1, t2) = InmemoryTransport::pair(100); + let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never))); + let verifier = Arc::new(Verifier::none()); + + let task = ConnectionTask::build() + .handler(Weak::::new()) + .state(Arc::downgrade(&state)) + .keychain(keychain) + .transport(t1) + .shutdown_timer(Arc::downgrade(&shutdown_timer)) + .verifier(Arc::downgrade(&verifier)) + .spawn(); + + // Spawn a task to handle establishing connection from client-side + tokio::spawn(async move { + let _client = Connection::client(t2, DummyAuthHandler) + .await + .expect("Fail to establish client-side connection"); + }); + + wait_for_termination!(task); + + let err = task.await.unwrap_err(); + assert!( + err.to_string().contains("Handler has been dropped"), + "Unexpected error: {err}" + ); + } + + #[test(tokio::test)] + async fn should_terminate_if_accepting_connection_fails_on_server_handler() { + struct BadAcceptServerHandler; + + #[async_trait] + impl ServerHandler for BadAcceptServerHandler { + type Request = u16; + type Response = String; + type LocalData = (); + + async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { + Err(io::Error::new(io::ErrorKind::Other, "bad accept")) + } + + async fn on_request( + &self, + _: ServerCtx, + ) { + unreachable!(); + } + } + + let handler = Arc::new(BadAcceptServerHandler); + let state = Arc::new(ServerState::default()); + let keychain = ServerKeychain::new(); + let (t1, t2) = 
InmemoryTransport::pair(100); + let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never))); + let verifier = Arc::new(Verifier::none()); + + let task = ConnectionTask::build() + .handler(Arc::downgrade(&handler)) + .state(Arc::downgrade(&state)) + .keychain(keychain) + .transport(t1) + .shutdown_timer(Arc::downgrade(&shutdown_timer)) + .verifier(Arc::downgrade(&verifier)) + .spawn(); + + // Spawn a task to handle establishing connection from client-side, and then closes to + // trigger the server-side to close + tokio::spawn(async move { + let _client = Connection::client(t2, DummyAuthHandler) + .await + .expect("Fail to establish client-side connection"); + }); + + wait_for_termination!(task); + + let err = task.await.unwrap_err(); + assert!( + err.to_string().contains("Accepting connection failed"), + "Unexpected error: {err}" + ); + } + + #[test(tokio::test)] + async fn should_terminate_if_connection_fails_to_become_ready() { + let handler = Arc::new(TestServerHandler); + let state = Arc::new(ServerState::default()); + let keychain = ServerKeychain::new(); + let (t1, t2) = InmemoryTransport::pair(100); + let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never))); + let verifier = Arc::new(Verifier::none()); + + struct FakeTransport { + inner: InmemoryTransport, + fail_ready: Arc, + } + + #[async_trait] + impl Transport for FakeTransport { + fn try_read(&self, buf: &mut [u8]) -> io::Result { + self.inner.try_read(buf) + } + + fn try_write(&self, buf: &[u8]) -> io::Result { + self.inner.try_write(buf) + } + + async fn ready(&self, interest: Interest) -> io::Result { + if self.fail_ready.load(Ordering::Relaxed) { + Err(io::Error::new( + io::ErrorKind::Other, + "targeted ready failure", + )) + } else { + self.inner.ready(interest).await + } + } + } + + #[async_trait] + impl Reconnectable for FakeTransport { + async fn reconnect(&mut self) -> io::Result<()> { + self.inner.reconnect().await + } + } + + let fail_ready = 
Arc::new(AtomicBool::new(false)); + let task = ConnectionTask::build() + .handler(Arc::downgrade(&handler)) + .state(Arc::downgrade(&state)) + .keychain(keychain) + .transport(FakeTransport { + inner: t1, + fail_ready: Arc::clone(&fail_ready), + }) + .shutdown_timer(Arc::downgrade(&shutdown_timer)) + .verifier(Arc::downgrade(&verifier)) + .spawn(); + + // Spawn a task to handle establishing connection from client-side, set ready to fail + // for the server-side after client connection completes, and wait a bit + tokio::spawn(async move { + let _client = Connection::client(t2, DummyAuthHandler) + .await + .expect("Fail to establish client-side connection"); + + // NOTE: Need to sleep for a little bit to hand control back to server to finish + // its side of the connection before toggling ready to fail + tokio::time::sleep(Duration::from_millis(50)).await; + + // Toggle ready to fail and then wait awhile so we fail by ready and not connection + // being dropped + fail_ready.store(true, Ordering::Relaxed); + tokio::time::sleep(Duration::from_secs(1)).await; + }); + + wait_for_termination!(task); + + let err = task.await.unwrap_err(); + assert!( + err.to_string().contains("Failed to examine ready state"), + "Unexpected error: {err}" + ); + } + + #[test(tokio::test)] + async fn should_terminate_if_connection_closes() { + let handler = Arc::new(TestServerHandler); + let state = Arc::new(ServerState::default()); + let keychain = ServerKeychain::new(); + let (t1, t2) = InmemoryTransport::pair(100); + let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never))); + let verifier = Arc::new(Verifier::none()); + + let task = ConnectionTask::build() + .handler(Arc::downgrade(&handler)) + .state(Arc::downgrade(&state)) + .keychain(keychain) + .transport(t1) + .shutdown_timer(Arc::downgrade(&shutdown_timer)) + .verifier(Arc::downgrade(&verifier)) + .spawn(); + + // Spawn a task to handle establishing connection from client-side, and then closes to + // trigger 
the server-side to close + tokio::spawn(async move { + let _client = Connection::client(t2, DummyAuthHandler) + .await + .expect("Fail to establish client-side connection"); + }); + + wait_for_termination!(task); + task.await.unwrap(); + } + + #[test(tokio::test)] + async fn should_invoke_server_handler_to_process_request_in_new_task_and_forward_responses() { + let handler = Arc::new(TestServerHandler); + let state = Arc::new(ServerState::default()); + let keychain = ServerKeychain::new(); + let (t1, t2) = InmemoryTransport::pair(100); + let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never))); + let verifier = Arc::new(Verifier::none()); + + ConnectionTask::build() + .handler(Arc::downgrade(&handler)) + .state(Arc::downgrade(&state)) + .keychain(keychain) + .transport(t1) + .shutdown_timer(Arc::downgrade(&shutdown_timer)) + .verifier(Arc::downgrade(&verifier)) + .spawn(); + + // Spawn a task to handle establishing connection from client-side + let task = tokio::spawn(async move { + let mut client = Connection::client(t2, DummyAuthHandler) + .await + .expect("Fail to establish client-side connection"); + + client.write_frame_for(&Request::new(123u16)).await.unwrap(); + client + .read_frame_as::>() + .await + .unwrap() + .unwrap() + }); + + let response = task.await.unwrap(); + assert_eq!(response.payload, "hello"); + } +} diff --git a/distant-net/src/server/context.rs b/distant-net/src/server/context.rs index 3f3363d..d1ecef6 100644 --- a/distant-net/src/server/context.rs +++ b/distant-net/src/server/context.rs @@ -1,17 +1,27 @@ -use crate::{ConnectionId, Request, ServerReply}; +use super::ServerReply; +use crate::common::{ConnectionId, Request}; use std::sync::Arc; /// Represents contextual information for working with an inbound request -pub struct ServerCtx { +pub struct ServerCtx { /// Unique identifer associated with the connection that sent the request pub connection_id: ConnectionId, /// The request being handled - pub request: 
Request, + pub request: Request, /// Used to send replies back to be sent out by the server - pub reply: ServerReply, + pub reply: ServerReply, /// Reference to the connection's local data - pub local_data: Arc, + pub local_data: Arc, +} + +/// Represents contextual information for working with an inbound connection +pub struct ConnectionCtx<'a, D> { + /// Unique identifer associated with the connection + pub connection_id: ConnectionId, + + /// Reference to the connection's local data + pub local_data: &'a mut D, } diff --git a/distant-net/src/server/ext.rs b/distant-net/src/server/ext.rs deleted file mode 100644 index 52d8f3c..0000000 --- a/distant-net/src/server/ext.rs +++ /dev/null @@ -1,440 +0,0 @@ -use crate::{ - utils::Timer, GenericServerRef, Listener, Request, Response, Server, ServerConnection, - ServerCtx, ServerRef, ServerReply, ServerState, Shutdown, TypedAsyncRead, TypedAsyncWrite, -}; -use log::*; -use serde::{de::DeserializeOwned, Serialize}; -use std::{ - io, - sync::{Arc, Weak}, -}; -use tokio::sync::{mpsc, Mutex}; - -mod tcp; -pub use tcp::*; - -#[cfg(unix)] -mod unix; - -#[cfg(unix)] -pub use unix::*; - -#[cfg(windows)] -mod windows; - -#[cfg(windows)] -pub use windows::*; - -/// Extension trait to provide a reference implementation of starting a server -/// that will listen for new connections (exposed as [`TypedAsyncWrite`] and [`TypedAsyncRead`]) -/// and process them using the [`Server`] implementation -pub trait ServerExt { - type Request; - type Response; - - /// Start a new server using the provided listener - fn start(self, listener: L) -> io::Result> - where - L: Listener + 'static, - R: TypedAsyncRead> + Send + 'static, - W: TypedAsyncWrite> + Send + 'static; -} - -impl ServerExt for S -where - S: Server + Sync + 'static, - Req: DeserializeOwned + Send + Sync + 'static, - Res: Serialize + Send + 'static, - Data: Default + Send + Sync + 'static, -{ - type Request = Req; - type Response = Res; - - fn start(self, listener: L) -> 
io::Result> - where - L: Listener + 'static, - R: TypedAsyncRead> + Send + 'static, - W: TypedAsyncWrite> + Send + 'static, - { - let server = Arc::new(self); - let state = Arc::new(ServerState::new()); - - let task = tokio::spawn(task(server, Arc::clone(&state), listener)); - - Ok(Box::new(GenericServerRef { state, task })) - } -} - -async fn task(server: Arc, state: Arc, mut listener: L) -where - S: Server + Sync + 'static, - Req: DeserializeOwned + Send + Sync + 'static, - Res: Serialize + Send + 'static, - Data: Default + Send + Sync + 'static, - L: Listener + 'static, - R: TypedAsyncRead> + Send + 'static, - W: TypedAsyncWrite> + Send + 'static, -{ - // Grab a copy of our server's configuration so we can leverage it below - let config = server.config(); - - // Create the timer that will be used shutdown the server after duration elapsed - let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1); - - // NOTE: We do a manual map such that the shutdown sender is not captured and dropped when - // there is no shutdown after configured. This is because we need the future for the - // shutdown receiver to last forever in the event that there is no shutdown configured, - // not return immediately, which is what would happen if the sender was dropped. 
- #[allow(clippy::manual_map)] - let mut shutdown_timer = match config.shutdown { - // Create a timer, start it, and drop it so it will always happen - Shutdown::After(duration) => { - Timer::new(duration, async move { - let _ = shutdown_tx.send(()).await; - }) - .start(); - None - } - Shutdown::Lonely(duration) => Some(Timer::new(duration, async move { - let _ = shutdown_tx.send(()).await; - })), - Shutdown::Never => None, - }; - - if let Some(timer) = shutdown_timer.as_mut() { - info!( - "Server shutdown timer configured: {}s", - timer.duration().as_secs_f32() - ); - timer.start(); - } - - let mut shutdown_timer = shutdown_timer.map(|timer| Arc::new(Mutex::new(timer))); - - loop { - let server = Arc::clone(&server); - - // Receive a new connection, exiting if no longer accepting connections or if the shutdown - // signal has been received - let (mut writer, mut reader) = tokio::select! { - result = listener.accept() => { - match result { - Ok(x) => x, - Err(x) => { - error!("Server no longer accepting connections: {x}"); - if let Some(timer) = shutdown_timer.take() { - timer.lock().await.abort(); - } - break; - } - } - } - _ = shutdown_rx.recv() => { - info!( - "Server shutdown triggered after {}s", - config.shutdown.duration().unwrap_or_default().as_secs_f32(), - ); - break; - } - }; - - let mut connection = ServerConnection::new(); - let connection_id = connection.id; - let state = Arc::clone(&state); - - // Ensure that the shutdown timer is cancelled now that we have a connection - if let Some(timer) = shutdown_timer.as_ref() { - timer.lock().await.stop(); - } - - // Create some default data for the new connection and pass it - // to the callback prior to processing new requests - let local_data = { - let mut data = Data::default(); - server.on_accept(&mut data).await; - Arc::new(data) - }; - - // Start a writer task that reads from a channel and forwards all - // data through the writer - let (tx, mut rx) = mpsc::channel::>(1); - connection.writer_task = 
Some(tokio::spawn(async move { - while let Some(data) = rx.recv().await { - // Log our message as a string, which can be expensive - if log_enabled!(Level::Trace) { - trace!( - "[Conn {connection_id}] Sending {}", - &data - .to_vec() - .map(|x| String::from_utf8_lossy(&x).to_string()) - .unwrap_or_else(|_| "".to_string()) - ); - } - - if let Err(x) = writer.write(data).await { - error!("[Conn {connection_id}] Failed to send {x}"); - break; - } - } - })); - - // Start a reader task that reads requests and processes them - // using the provided handler - let weak_state = Arc::downgrade(&state); - let weak_shutdown_timer = shutdown_timer - .as_ref() - .map(Arc::downgrade) - .unwrap_or_default(); - connection.reader_task = Some(tokio::spawn(async move { - loop { - match reader.read().await { - Ok(Some(request)) => { - let reply = ServerReply { - origin_id: request.id.clone(), - tx: tx.clone(), - }; - - let ctx = ServerCtx { - connection_id, - request, - reply: reply.clone(), - local_data: Arc::clone(&local_data), - }; - - server.on_request(ctx).await; - } - Ok(None) => { - debug!("[Conn {connection_id}] Connection closed"); - - // Remove the connection from our state if it has closed - if let Some(state) = Weak::upgrade(&weak_state) { - state.connections.write().await.remove(&connection_id); - - // If we have no more connections, start the timer - if let Some(timer) = Weak::upgrade(&weak_shutdown_timer) { - if state.connections.read().await.is_empty() { - timer.lock().await.start(); - } - } - } - break; - } - Err(x) => { - // NOTE: We do NOT break out of the loop, as this could happen - // if someone sends bad data at any point, but does not - // mean that the reader itself has failed. 
This can - // happen from getting non-compliant typed data - error!("[Conn {connection_id}] {x}"); - } - } - } - })); - - state - .connections - .write() - .await - .insert(connection_id, connection); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - IntoSplit, MpscListener, MpscTransport, MpscTransportReadHalf, MpscTransportWriteHalf, - ServerConfig, - }; - use async_trait::async_trait; - use std::time::Duration; - - pub struct TestServer(ServerConfig); - - #[async_trait] - impl Server for TestServer { - type Request = u16; - type Response = String; - type LocalData = (); - - fn config(&self) -> ServerConfig { - self.0.clone() - } - - async fn on_request(&self, ctx: ServerCtx) { - // Always send back "hello" - ctx.reply.send("hello".to_string()).await.unwrap(); - } - } - - #[allow(clippy::type_complexity)] - fn make_listener( - buffer: usize, - ) -> ( - mpsc::Sender<( - MpscTransportWriteHalf>, - MpscTransportReadHalf>, - )>, - MpscListener<( - MpscTransportWriteHalf>, - MpscTransportReadHalf>, - )>, - ) { - MpscListener::channel(buffer) - } - - #[tokio::test] - async fn should_invoke_handler_upon_receiving_a_request() { - // Create a test listener where we will forward a connection - let (tx, listener) = make_listener(100); - - // Make bounded transport pair and send off one of them to act as our connection - let (mut transport, connection) = - MpscTransport::, Response>::pair(100); - tx.send(connection.into_split()) - .await - .expect("Failed to feed listener a connection"); - - let _server = ServerExt::start(TestServer(ServerConfig::default()), listener) - .expect("Failed to start server"); - - transport - .write(Request::new(123)) - .await - .expect("Failed to send request"); - - let response: Response = transport.read().await.unwrap().unwrap(); - assert_eq!(response.payload, "hello"); - } - - #[tokio::test] - async fn should_lonely_shutdown_if_no_connections_received_after_n_secs_when_config_set() { - let (_tx, listener) = 
make_listener(100); - - let server = ServerExt::start( - TestServer(ServerConfig { - shutdown: Shutdown::Lonely(Duration::from_millis(100)), - }), - listener, - ) - .expect("Failed to start server"); - - // Wait for some time - tokio::time::sleep(Duration::from_millis(300)).await; - - assert!(server.is_finished(), "Server shutdown not triggered!"); - } - - #[tokio::test] - async fn should_lonely_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs( - ) { - // Create a test listener where we will forward a connection - let (tx, listener) = make_listener(100); - - // Make bounded transport pair and send off one of them to act as our connection - let (transport, connection) = MpscTransport::, Response>::pair(100); - tx.send(connection.into_split()) - .await - .expect("Failed to feed listener a connection"); - - let server = ServerExt::start( - TestServer(ServerConfig { - shutdown: Shutdown::Lonely(Duration::from_millis(100)), - }), - listener, - ) - .expect("Failed to start server"); - - // Drop the connection by dropping the transport - drop(transport); - - // Wait for some time - tokio::time::sleep(Duration::from_millis(300)).await; - - assert!(server.is_finished(), "Server shutdown not triggered!"); - } - - #[tokio::test] - async fn should_not_lonely_shutdown_as_long_as_a_connection_exists() { - // Create a test listener where we will forward a connection - let (tx, listener) = make_listener(100); - - // Make bounded transport pair and send off one of them to act as our connection - let (_transport, connection) = MpscTransport::, Response>::pair(100); - tx.send(connection.into_split()) - .await - .expect("Failed to feed listener a connection"); - - let server = ServerExt::start( - TestServer(ServerConfig { - shutdown: Shutdown::Lonely(Duration::from_millis(100)), - }), - listener, - ) - .expect("Failed to start server"); - - // Wait for some time - tokio::time::sleep(Duration::from_millis(300)).await; - - assert!(!server.is_finished(), 
"Server shutdown when it should not!"); - } - - #[tokio::test] - async fn should_shutdown_after_n_seconds_even_with_connections_if_config_set_to_after() { - let (tx, listener) = make_listener(100); - - // Make bounded transport pair and send off one of them to act as our connection - let (_transport, connection) = MpscTransport::, Response>::pair(100); - tx.send(connection.into_split()) - .await - .expect("Failed to feed listener a connection"); - - let server = ServerExt::start( - TestServer(ServerConfig { - shutdown: Shutdown::After(Duration::from_millis(100)), - }), - listener, - ) - .expect("Failed to start server"); - - // Wait for some time - tokio::time::sleep(Duration::from_millis(300)).await; - - assert!(server.is_finished(), "Server shutdown not triggered!"); - } - - #[tokio::test] - async fn should_shutdown_after_n_seconds_if_config_set_to_after() { - let (_tx, listener) = make_listener(100); - - let server = ServerExt::start( - TestServer(ServerConfig { - shutdown: Shutdown::After(Duration::from_millis(100)), - }), - listener, - ) - .expect("Failed to start server"); - - // Wait for some time - tokio::time::sleep(Duration::from_millis(300)).await; - - assert!(server.is_finished(), "Server shutdown not triggered!"); - } - - #[tokio::test] - async fn should_never_shutdown_if_config_set_to_never() { - let (_tx, listener) = make_listener(100); - - let server = ServerExt::start( - TestServer(ServerConfig { - shutdown: Shutdown::Never, - }), - listener, - ) - .expect("Failed to start server"); - - // Wait for some time - tokio::time::sleep(Duration::from_millis(300)).await; - - assert!(!server.is_finished(), "Server shutdown when it should not!"); - } -} diff --git a/distant-net/src/server/ext/tcp.rs b/distant-net/src/server/ext/tcp.rs deleted file mode 100644 index ff764cd..0000000 --- a/distant-net/src/server/ext/tcp.rs +++ /dev/null @@ -1,94 +0,0 @@ -use crate::{ - Codec, FramedTransport, IntoSplit, MappedListener, PortRange, Server, ServerExt, 
TcpListener, - TcpServerRef, -}; -use async_trait::async_trait; -use serde::{de::DeserializeOwned, Serialize}; -use std::{io, net::IpAddr}; - -/// Extension trait to provide a reference implementation of starting a TCP server -/// that will listen for new connections and process them using the [`Server`] implementation -#[async_trait] -pub trait TcpServerExt { - type Request; - type Response; - - /// Start a new server using the provided listener - async fn start(self, addr: IpAddr, port: P, codec: C) -> io::Result - where - P: Into + Send, - C: Codec + Send + Sync + 'static; -} - -#[async_trait] -impl TcpServerExt for S -where - S: Server + Sync + 'static, - Req: DeserializeOwned + Send + Sync + 'static, - Res: Serialize + Send + 'static, - Data: Default + Send + Sync + 'static, -{ - type Request = Req; - type Response = Res; - - async fn start(self, addr: IpAddr, port: P, codec: C) -> io::Result - where - P: Into + Send, - C: Codec + Send + Sync + 'static, - { - let listener = TcpListener::bind(addr, port).await?; - let port = listener.port(); - - let listener = MappedListener::new(listener, move |transport| { - let transport = FramedTransport::new(transport, codec.clone()); - transport.into_split() - }); - let inner = ServerExt::start(self, listener)?; - Ok(TcpServerRef { addr, port, inner }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{Client, PlainCodec, Request, ServerCtx, TcpClientExt}; - use std::net::{Ipv6Addr, SocketAddr}; - - pub struct TestServer; - - #[async_trait] - impl Server for TestServer { - type Request = String; - type Response = String; - type LocalData = (); - - async fn on_request(&self, ctx: ServerCtx) { - // Echo back what we received - ctx.reply - .send(ctx.request.payload.to_string()) - .await - .unwrap(); - } - } - - #[tokio::test] - async fn should_invoke_handler_upon_receiving_a_request() { - let server = - TcpServerExt::start(TestServer, IpAddr::V6(Ipv6Addr::LOCALHOST), 0, PlainCodec) - .await - .expect("Failed 
to start TCP server"); - - let mut client: Client = Client::connect( - SocketAddr::from((server.ip_addr(), server.port())), - PlainCodec, - ) - .await - .expect("Client failed to connect"); - - let response = client - .send(Request::new("hello".to_string())) - .await - .expect("Failed to send message"); - assert_eq!(response.payload, "hello"); - } -} diff --git a/distant-net/src/server/ext/unix.rs b/distant-net/src/server/ext/unix.rs deleted file mode 100644 index 1a2838f..0000000 --- a/distant-net/src/server/ext/unix.rs +++ /dev/null @@ -1,97 +0,0 @@ -use crate::{ - Codec, FramedTransport, IntoSplit, MappedListener, Server, ServerExt, UnixSocketListener, - UnixSocketServerRef, -}; -use async_trait::async_trait; -use serde::{de::DeserializeOwned, Serialize}; -use std::{io, path::Path}; - -/// Extension trait to provide a reference implementation of starting a Unix socket server -/// that will listen for new connections and process them using the [`Server`] implementation -#[async_trait] -pub trait UnixSocketServerExt { - type Request; - type Response; - - /// Start a new server using the provided listener - async fn start(self, path: P, codec: C) -> io::Result - where - P: AsRef + Send, - C: Codec + Send + Sync + 'static; -} - -#[async_trait] -impl UnixSocketServerExt for S -where - S: Server + Sync + 'static, - Req: DeserializeOwned + Send + Sync + 'static, - Res: Serialize + Send + 'static, - Data: Default + Send + Sync + 'static, -{ - type Request = Req; - type Response = Res; - - async fn start(self, path: P, codec: C) -> io::Result - where - P: AsRef + Send, - C: Codec + Send + Sync + 'static, - { - let path = path.as_ref(); - let listener = UnixSocketListener::bind(path).await?; - let path = listener.path().to_path_buf(); - - let listener = MappedListener::new(listener, move |transport| { - let transport = FramedTransport::new(transport, codec.clone()); - transport.into_split() - }); - let inner = ServerExt::start(self, listener)?; - Ok(UnixSocketServerRef { 
path, inner }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{Client, PlainCodec, Request, ServerCtx, UnixSocketClientExt}; - use tempfile::NamedTempFile; - - pub struct TestServer; - - #[async_trait] - impl Server for TestServer { - type Request = String; - type Response = String; - type LocalData = (); - - async fn on_request(&self, ctx: ServerCtx) { - // Echo back what we received - ctx.reply - .send(ctx.request.payload.to_string()) - .await - .unwrap(); - } - } - - #[tokio::test] - async fn should_invoke_handler_upon_receiving_a_request() { - // Generate a socket path and delete the file after so there is nothing there - let path = NamedTempFile::new() - .expect("Failed to create socket file") - .path() - .to_path_buf(); - - let server = UnixSocketServerExt::start(TestServer, path, PlainCodec) - .await - .expect("Failed to start Unix socket server"); - - let mut client: Client = Client::connect(server.path(), PlainCodec) - .await - .expect("Client failed to connect"); - - let response = client - .send(Request::new("hello".to_string())) - .await - .expect("Failed to send message"); - assert_eq!(response.payload, "hello"); - } -} diff --git a/distant-net/src/server/ext/windows.rs b/distant-net/src/server/ext/windows.rs deleted file mode 100644 index d2e8715..0000000 --- a/distant-net/src/server/ext/windows.rs +++ /dev/null @@ -1,109 +0,0 @@ -use crate::{ - Codec, FramedTransport, IntoSplit, MappedListener, Server, ServerExt, WindowsPipeListener, - WindowsPipeServerRef, -}; -use async_trait::async_trait; -use serde::{de::DeserializeOwned, Serialize}; -use std::{ - ffi::{OsStr, OsString}, - io, -}; - -/// Extension trait to provide a reference implementation of starting a Windows pipe server -/// that will listen for new connections and process them using the [`Server`] implementation -#[async_trait] -pub trait WindowsPipeServerExt { - type Request; - type Response; - - /// Start a new server at the specified address using the given codec - async 
fn start(self, addr: A, codec: C) -> io::Result - where - A: AsRef + Send, - C: Codec + Send + Sync + 'static; - - /// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec - async fn start_local(self, name: N, codec: C) -> io::Result - where - Self: Sized, - N: AsRef + Send, - C: Codec + Send + Sync + 'static, - { - let mut addr = OsString::from(r"\\.\pipe\"); - addr.push(name.as_ref()); - self.start(addr, codec).await - } -} - -#[async_trait] -impl WindowsPipeServerExt for S -where - S: Server + Sync + 'static, - Req: DeserializeOwned + Send + Sync + 'static, - Res: Serialize + Send + 'static, - Data: Default + Send + Sync + 'static, -{ - type Request = Req; - type Response = Res; - - async fn start(self, addr: A, codec: C) -> io::Result - where - A: AsRef + Send, - C: Codec + Send + Sync + 'static, - { - let a = addr.as_ref(); - let listener = WindowsPipeListener::bind(a)?; - let addr = listener.addr().to_os_string(); - - let listener = MappedListener::new(listener, move |transport| { - let transport = FramedTransport::new(transport, codec.clone()); - transport.into_split() - }); - let inner = ServerExt::start(self, listener)?; - Ok(WindowsPipeServerRef { addr, inner }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{Client, PlainCodec, Request, ServerCtx, WindowsPipeClientExt}; - - pub struct TestServer; - - #[async_trait] - impl Server for TestServer { - type Request = String; - type Response = String; - type LocalData = (); - - async fn on_request(&self, ctx: ServerCtx) { - // Echo back what we received - ctx.reply - .send(ctx.request.payload.to_string()) - .await - .unwrap(); - } - } - - #[tokio::test] - async fn should_invoke_handler_upon_receiving_a_request() { - let server = WindowsPipeServerExt::start_local( - TestServer, - format!("test_pip_{}", rand::random::()), - PlainCodec, - ) - .await - .expect("Failed to start Windows pipe server"); - - let mut client: Client = Client::connect(server.addr(), 
PlainCodec) - .await - .expect("Client failed to connect"); - - let response = client - .send(Request::new("hello".to_string())) - .await - .expect("Failed to send message"); - assert_eq!(response.payload, "hello"); - } -} diff --git a/distant-net/src/server/ref.rs b/distant-net/src/server/ref.rs index 5359ece..a693d19 100644 --- a/distant-net/src/server/ref.rs +++ b/distant-net/src/server/ref.rs @@ -1,4 +1,5 @@ -use crate::{AsAny, ServerState}; +use super::ServerState; +use crate::common::AsAny; use log::*; use std::{ future::Future, @@ -12,9 +13,6 @@ use tokio::task::{JoinError, JoinHandle}; /// Interface to engage with a server instance pub trait ServerRef: AsAny + Send { - /// Returns a reference to the state of the server - fn state(&self) -> &ServerState; - /// Returns true if the server is no longer running fn is_finished(&self) -> bool; @@ -54,6 +52,14 @@ impl dyn ServerRef { ) -> Result, Box> { self.into_any().downcast::() } + + /// Waits for the server to complete by continuously polling the finished state. 
+ pub async fn polling_wait(&self) -> io::Result<()> { + while !self.is_finished() { + tokio::time::sleep(Duration::from_millis(100)).await; + } + Ok(()) + } } /// Represents a generic reference to a server @@ -64,10 +70,6 @@ pub struct GenericServerRef { /// Runtime-specific implementation of [`ServerRef`] for a [`tokio::task::JoinHandle`] impl ServerRef for GenericServerRef { - fn state(&self) -> &ServerState { - &self.state - } - fn is_finished(&self) -> bool { self.task.is_finished() } diff --git a/distant-net/src/server/ref/tcp.rs b/distant-net/src/server/ref/tcp.rs index f042d92..80f4552 100644 --- a/distant-net/src/server/ref/tcp.rs +++ b/distant-net/src/server/ref/tcp.rs @@ -1,4 +1,4 @@ -use crate::{ServerRef, ServerState}; +use super::ServerRef; use std::net::IpAddr; /// Reference to a TCP server instance @@ -25,10 +25,6 @@ impl TcpServerRef { } impl ServerRef for TcpServerRef { - fn state(&self) -> &ServerState { - self.inner.state() - } - fn is_finished(&self) -> bool { self.inner.is_finished() } diff --git a/distant-net/src/server/ref/unix.rs b/distant-net/src/server/ref/unix.rs index 3e762a6..8642cea 100644 --- a/distant-net/src/server/ref/unix.rs +++ b/distant-net/src/server/ref/unix.rs @@ -1,4 +1,4 @@ -use crate::{ServerRef, ServerState}; +use super::ServerRef; use std::path::{Path, PathBuf}; /// Reference to a unix socket server instance @@ -24,10 +24,6 @@ impl UnixSocketServerRef { } impl ServerRef for UnixSocketServerRef { - fn state(&self) -> &ServerState { - self.inner.state() - } - fn is_finished(&self) -> bool { self.inner.is_finished() } diff --git a/distant-net/src/server/ref/windows.rs b/distant-net/src/server/ref/windows.rs index 6d0ee77..4c29762 100644 --- a/distant-net/src/server/ref/windows.rs +++ b/distant-net/src/server/ref/windows.rs @@ -1,4 +1,4 @@ -use crate::{ServerRef, ServerState}; +use super::ServerRef; use std::ffi::{OsStr, OsString}; /// Reference to a unix socket server instance @@ -24,10 +24,6 @@ impl WindowsPipeServerRef { 
} impl ServerRef for WindowsPipeServerRef { - fn state(&self) -> &ServerState { - self.inner.state() - } - fn is_finished(&self) -> bool { self.inner.is_finished() } diff --git a/distant-net/src/server/reply.rs b/distant-net/src/server/reply.rs index 0756ab0..eaa422f 100644 --- a/distant-net/src/server/reply.rs +++ b/distant-net/src/server/reply.rs @@ -1,4 +1,4 @@ -use crate::{Id, Response}; +use crate::common::{Id, Response}; use std::{future::Future, io, pin::Pin, sync::Arc}; use tokio::sync::{mpsc, Mutex}; diff --git a/distant-net/src/server/shutdown_timer.rs b/distant-net/src/server/shutdown_timer.rs new file mode 100644 index 0000000..fc83639 --- /dev/null +++ b/distant-net/src/server/shutdown_timer.rs @@ -0,0 +1,96 @@ +use super::Shutdown; +use crate::common::utils::Timer; +use log::*; +use std::time::Duration; +use tokio::sync::watch; + +/// Cloneable notification for when a [`ShutdownTimer`] has completed. +#[derive(Clone)] +pub(crate) struct ShutdownNotification(watch::Receiver<()>); + +impl ShutdownNotification { + /// Waits to receive a notification that the shutdown timer has concluded + pub async fn wait(&mut self) { + let _ = self.0.changed().await; + } +} + +/// Wrapper around [`Timer`] to support shutdown-specific notifications. +pub(crate) struct ShutdownTimer { + timer: Timer<()>, + watcher: ShutdownNotification, + shutdown: Shutdown, +} + +impl ShutdownTimer { + // Creates and starts the timer that will be used shutdown the server after duration elapsed. 
+ pub fn start(shutdown: Shutdown) -> Self { + let (tx, rx) = watch::channel(()); + let mut timer = match shutdown { + // Create a timer that will complete after `duration`, dropping it to ensure that it + // will always happen no matter if stop/abort is called + Shutdown::After(duration) => { + info!( + "Server shutdown timer configured: terminate after {}s", + duration.as_secs_f32() + ); + Timer::new(duration, async move { + let _ = tx.send(()); + }) + } + + // Create a timer that will complete after `duration` + Shutdown::Lonely(duration) => { + info!( + "Server shutdown timer configured: terminate after no activity for {}s", + duration.as_secs_f32() + ); + Timer::new(duration, async move { + let _ = tx.send(()); + }) + } + + // Create a timer that will never complete (max timeout possible) so we hold on to the + // sender to avoid the receiver from completing + Shutdown::Never => { + info!("Server shutdown timer configured: never terminate"); + Timer::new(Duration::MAX, async move { + let _ = tx.send(()); + }) + } + }; + + timer.start(); + + Self { + timer, + watcher: ShutdownNotification(rx), + shutdown, + } + } + + /// Restarts the timer, doing nothing if the timer is already running + pub fn restart(&mut self) { + if !self.timer.is_running() { + self.timer.start(); + } + } + + /// Stops the timer, only applying if the timer is `lonely` + pub fn stop(&self) { + if let Shutdown::Lonely(_) = self.shutdown { + self.timer.stop(); + } + } + + /// Stops the timer completely by killing the internal callback task, meaning it can never be + /// started again + pub fn abort(&self) { + self.timer.abort(); + } + + /// Clones the notification + pub fn clone_notification(&self) -> ShutdownNotification { + self.watcher.clone() + } +} diff --git a/distant-net/src/server/state.rs b/distant-net/src/server/state.rs index 54553ea..bbd1a07 100644 --- a/distant-net/src/server/state.rs +++ b/distant-net/src/server/state.rs @@ -1,19 +1,33 @@ -use crate::{ConnectionId, 
ServerConnection}; +use super::ConnectionTask; +use crate::common::{authentication::Keychain, Backup, ConnectionId}; use std::collections::HashMap; -use tokio::sync::RwLock; +use tokio::sync::{oneshot, RwLock}; /// Contains all top-level state for the server pub struct ServerState { /// Mapping of connection ids to their transports - pub connections: RwLock>, + pub connections: RwLock>, + + /// Mapping of connection ids to (OTP, backup) + pub keychain: Keychain>, } impl ServerState { pub fn new() -> Self { Self { connections: RwLock::new(HashMap::new()), + keychain: Keychain::new(), } } + + /// Returns true if there is at least one active connection + pub async fn has_active_connections(&self) -> bool { + self.connections + .read() + .await + .values() + .any(|task| !task.is_finished()) + } } impl Default for ServerState { diff --git a/distant-net/src/transport.rs b/distant-net/src/transport.rs deleted file mode 100644 index acd7109..0000000 --- a/distant-net/src/transport.rs +++ /dev/null @@ -1,112 +0,0 @@ -use async_trait::async_trait; -use std::{io, marker::Unpin}; -use tokio::io::{AsyncRead, AsyncWrite}; - -/// Interface to split something into writing and reading halves -pub trait IntoSplit { - type Write; - type Read; - - fn into_split(self) -> (Self::Write, Self::Read); -} - -impl IntoSplit for (W, R) { - type Write = W; - type Read = R; - - fn into_split(self) -> (Self::Write, Self::Read) { - (self.0, self.1) - } -} - -/// Interface representing a transport of raw bytes into and out of the system -pub trait RawTransport: RawTransportRead + RawTransportWrite {} - -/// Interface representing a transport of raw bytes into the system -pub trait RawTransportRead: AsyncRead + Send + Unpin {} - -/// Interface representing a transport of raw bytes out of the system -pub trait RawTransportWrite: AsyncWrite + Send + Unpin {} - -/// Interface representing a transport of typed data into and out of the system -pub trait TypedTransport: TypedAsyncRead + TypedAsyncWrite 
{} - -/// Interface to read some structured data asynchronously -#[async_trait] -pub trait TypedAsyncRead { - /// Reads some data, returning `Some(T)` if available or `None` if the reader - /// has closed and no longer is providing data - async fn read(&mut self) -> io::Result>; -} - -#[async_trait] -impl TypedAsyncRead for (W, R) -where - W: Send, - R: TypedAsyncRead + Send, -{ - async fn read(&mut self) -> io::Result> { - self.1.read().await - } -} - -#[async_trait] -impl TypedAsyncRead for Box + Send> { - async fn read(&mut self) -> io::Result> { - (**self).read().await - } -} - -/// Interface to write some structured data asynchronously -#[async_trait] -pub trait TypedAsyncWrite { - async fn write(&mut self, data: T) -> io::Result<()>; -} - -#[async_trait] -impl TypedAsyncWrite for (W, R) -where - W: TypedAsyncWrite + Send, - R: Send, - T: Send + 'static, -{ - async fn write(&mut self, data: T) -> io::Result<()> { - self.0.write(data).await - } -} - -#[async_trait] -impl TypedAsyncWrite for Box + Send> { - async fn write(&mut self, data: T) -> io::Result<()> { - (**self).write(data).await - } -} - -mod router; - -mod framed; -pub use framed::*; - -mod inmemory; -pub use inmemory::*; - -mod mpsc; -pub use mpsc::*; - -mod tcp; -pub use tcp::*; - -#[cfg(unix)] -mod unix; - -#[cfg(unix)] -pub use unix::*; - -mod untyped; -pub use untyped::*; - -#[cfg(windows)] -mod windows; - -#[cfg(windows)] -pub use windows::*; diff --git a/distant-net/src/transport/framed.rs b/distant-net/src/transport/framed.rs deleted file mode 100644 index c6a99fa..0000000 --- a/distant-net/src/transport/framed.rs +++ /dev/null @@ -1,215 +0,0 @@ -use crate::{ - utils, Codec, IntoSplit, RawTransport, RawTransportRead, RawTransportWrite, UntypedTransport, - UntypedTransportRead, UntypedTransportWrite, -}; -use async_trait::async_trait; -use futures::{SinkExt, StreamExt}; -use log::*; -use serde::{de::DeserializeOwned, Serialize}; -use std::io; -use tokio_util::codec::{Framed, FramedRead, 
FramedWrite}; - -#[cfg(test)] -mod test; - -#[cfg(test)] -pub use test::*; - -mod read; -pub use read::*; - -mod write; -pub use write::*; - -/// Represents [`TypedTransport`] of data across the network using frames in order to support -/// typed messages instead of arbitrary bytes being sent across the wire. -/// -/// Note that this type does **not** implement [`RawTransport`] and instead acts as a wrapper -/// around a transport to provide a higher-level interface -#[derive(Debug)] -pub struct FramedTransport(Framed) -where - T: RawTransport, - C: Codec; - -impl FramedTransport -where - T: RawTransport, - C: Codec, -{ - /// Creates a new instance of the transport, wrapping the stream in a `Framed` - pub fn new(transport: T, codec: C) -> Self { - Self(Framed::new(transport, codec)) - } -} - -impl UntypedTransport for FramedTransport -where - T: RawTransport, - C: Codec + Send, -{ -} - -impl IntoSplit for FramedTransport -where - T: RawTransport + IntoSplit, - ::Read: RawTransportRead, - ::Write: RawTransportWrite, - C: Codec + Send, -{ - type Read = FramedTransportReadHalf<::Read, C>; - type Write = FramedTransportWriteHalf<::Write, C>; - - fn into_split(self) -> (Self::Write, Self::Read) { - let parts = self.0.into_parts(); - let (write_half, read_half) = parts.io.into_split(); - - // Create our split read half and populate its buffer with existing contents - let mut f_read = FramedRead::new(read_half, parts.codec.clone()); - *f_read.read_buffer_mut() = parts.read_buf; - - // Create our split write half and populate its buffer with existing contents - let mut f_write = FramedWrite::new(write_half, parts.codec); - *f_write.write_buffer_mut() = parts.write_buf; - - let read_half = FramedTransportReadHalf(f_read); - let write_half = FramedTransportWriteHalf(f_write); - - (write_half, read_half) - } -} - -#[async_trait] -impl UntypedTransportWrite for FramedTransport -where - T: RawTransport + Send, - C: Codec + Send, -{ - async fn write(&mut self, data: D) -> 
io::Result<()> - where - D: Serialize + Send + 'static, - { - // Serialize data into a byte stream - // NOTE: Cannot used packed implementation for now due to issues with deserialization - let data = utils::serialize_to_vec(&data)?; - - // Use underlying codec to send data (may encrypt, sign, etc.) - self.0.send(&data).await - } -} - -#[async_trait] -impl UntypedTransportRead for FramedTransport -where - T: RawTransport + Send, - C: Codec + Send, -{ - async fn read(&mut self) -> io::Result> - where - D: DeserializeOwned, - { - // Use underlying codec to receive data (may decrypt, validate, etc.) - if let Some(data) = self.0.next().await { - let data = data?; - - // Deserialize byte stream into our expected type - match utils::deserialize_from_slice(&data) { - Ok(data) => Ok(Some(data)), - Err(x) => { - error!("Invalid data: {}", String::from_utf8_lossy(&data)); - Err(x) - } - } - } else { - Ok(None) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{InmemoryTransport, PlainCodec}; - use serde::{Deserialize, Serialize}; - - #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] - pub struct TestData { - name: String, - value: usize, - } - - #[tokio::test] - async fn send_should_convert_data_into_byte_stream_and_send_through_stream() { - let (_tx, mut rx, stream) = InmemoryTransport::make(1); - let mut transport = FramedTransport::new(stream, PlainCodec::new()); - - let data = TestData { - name: String::from("test"), - value: 123, - }; - - let bytes = utils::serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - transport.write(data).await.unwrap(); - - let outgoing = rx.recv().await.unwrap(); - assert_eq!(outgoing, frame); - } - - #[tokio::test] - async fn receive_should_return_none_if_stream_is_closed() { - let (_, _, stream) = InmemoryTransport::make(1); - let mut transport = FramedTransport::new(stream, 
PlainCodec::new()); - - let result = transport.read::().await; - match result { - Ok(None) => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn receive_should_fail_if_unable_to_convert_to_type() { - let (tx, _rx, stream) = InmemoryTransport::make(1); - let mut transport = FramedTransport::new(stream, PlainCodec::new()); - - #[derive(Serialize, Deserialize)] - struct OtherTestData(usize); - - let data = OtherTestData(123); - let bytes = utils::serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - tx.send(frame).await.unwrap(); - let result = transport.read::().await; - assert!(result.is_err(), "Unexpectedly succeeded") - } - - #[tokio::test] - async fn receive_should_return_some_instance_of_type_when_coming_into_stream() { - let (tx, _rx, stream) = InmemoryTransport::make(1); - let mut transport = FramedTransport::new(stream, PlainCodec::new()); - - let data = TestData { - name: String::from("test"), - value: 123, - }; - - let bytes = utils::serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - tx.send(frame).await.unwrap(); - let received_data = transport.read::().await.unwrap().unwrap(); - assert_eq!(received_data, data); - } -} diff --git a/distant-net/src/transport/framed/read.rs b/distant-net/src/transport/framed/read.rs deleted file mode 100644 index 7d86c8d..0000000 --- a/distant-net/src/transport/framed/read.rs +++ /dev/null @@ -1,115 +0,0 @@ -use crate::{transport::framed::utils, Codec, UntypedTransportRead}; -use async_trait::async_trait; -use futures::StreamExt; -use log::*; -use serde::de::DeserializeOwned; -use std::io; -use tokio::io::AsyncRead; -use tokio_util::codec::FramedRead; - -/// Represents a transport of inbound data from the network using frames in order to support 
-/// typed messages instead of arbitrary bytes being sent across the wire. -/// -/// Note that this type does **not** implement [`AsyncRead`] and instead acts as a -/// wrapper to provide a higher-level interface -pub struct FramedTransportReadHalf(pub(super) FramedRead) -where - R: AsyncRead, - C: Codec; - -#[async_trait] -impl UntypedTransportRead for FramedTransportReadHalf -where - R: AsyncRead + Send + Unpin, - C: Codec + Send, -{ - async fn read(&mut self) -> io::Result> - where - D: DeserializeOwned, - { - // Use underlying codec to receive data (may decrypt, validate, etc.) - if let Some(data) = self.0.next().await { - let data = data?; - - // Deserialize byte stream into our expected type - match utils::deserialize_from_slice(&data) { - Ok(data) => Ok(Some(data)), - Err(x) => { - error!("Invalid data: {}", String::from_utf8_lossy(&data)); - Err(x) - } - } - } else { - Ok(None) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{FramedTransport, InmemoryTransport, IntoSplit, PlainCodec}; - use serde::{Deserialize, Serialize}; - - #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] - pub struct TestData { - name: String, - value: usize, - } - - #[tokio::test] - async fn receive_should_return_none_if_stream_is_closed() { - let (_, _, stream) = InmemoryTransport::make(1); - let transport = FramedTransport::new(stream, PlainCodec::new()); - let (_, mut reader) = transport.into_split(); - - let result = reader.read::().await; - match result { - Ok(None) => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn receive_should_fail_if_unable_to_convert_to_type() { - let (tx, _rx, stream) = InmemoryTransport::make(1); - let transport = FramedTransport::new(stream, PlainCodec::new()); - let (_, mut reader) = transport.into_split(); - - #[derive(Serialize, Deserialize)] - struct OtherTestData(usize); - - let data = OtherTestData(123); - let bytes = utils::serialize_to_vec(&data).unwrap(); - let len = 
(bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - tx.send(frame).await.unwrap(); - let result = reader.read::().await; - assert!(result.is_err(), "Unexpectedly succeeded"); - } - - #[tokio::test] - async fn receive_should_return_some_instance_of_type_when_coming_into_stream() { - let (tx, _rx, stream) = InmemoryTransport::make(1); - let transport = FramedTransport::new(stream, PlainCodec::new()); - let (_, mut reader) = transport.into_split(); - - let data = TestData { - name: String::from("test"), - value: 123, - }; - - let bytes = utils::serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - tx.send(frame).await.unwrap(); - let received_data = reader.read::().await.unwrap().unwrap(); - assert_eq!(received_data, data); - } -} diff --git a/distant-net/src/transport/framed/test.rs b/distant-net/src/transport/framed/test.rs deleted file mode 100644 index cfae093..0000000 --- a/distant-net/src/transport/framed/test.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::{FramedTransport, InmemoryTransport, PlainCodec}; - -#[cfg(test)] -impl FramedTransport { - /// Makes a connected pair of framed inmemory transports with plain codec for testing purposes - pub fn make_test_pair() -> ( - FramedTransport, - FramedTransport, - ) { - Self::pair(100) - } -} diff --git a/distant-net/src/transport/framed/write.rs b/distant-net/src/transport/framed/write.rs deleted file mode 100644 index 56b3ab3..0000000 --- a/distant-net/src/transport/framed/write.rs +++ /dev/null @@ -1,72 +0,0 @@ -use crate::{transport::framed::utils, Codec, UntypedTransportWrite}; -use async_trait::async_trait; -use futures::SinkExt; -use serde::Serialize; -use std::io; -use tokio::io::AsyncWrite; -use tokio_util::codec::FramedWrite; - -/// Represents a transport of outbound data to the network using frames in order to 
support -/// typed messages instead of arbitrary bytes being sent across the wire. -/// -/// Note that this type does **not** implement [`AsyncWrite`] and instead acts as a -/// wrapper to provide a higher-level interface -pub struct FramedTransportWriteHalf(pub(super) FramedWrite) -where - W: AsyncWrite, - C: Codec; - -#[async_trait] -impl UntypedTransportWrite for FramedTransportWriteHalf -where - W: AsyncWrite + Send + Unpin, - C: Codec + Send, -{ - async fn write(&mut self, data: D) -> io::Result<()> - where - D: Serialize + Send + 'static, - { - // Serialize data into a byte stream - // NOTE: Cannot used packed implementation for now due to issues with deserialization - let data = utils::serialize_to_vec(&data)?; - - // Use underlying codec to send data (may encrypt, sign, etc.) - self.0.send(&data).await - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{FramedTransport, InmemoryTransport, IntoSplit, PlainCodec}; - use serde::{Deserialize, Serialize}; - - #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] - pub struct TestData { - name: String, - value: usize, - } - - #[tokio::test] - async fn send_should_convert_data_into_byte_stream_and_send_through_stream() { - let (_tx, mut rx, stream) = InmemoryTransport::make(1); - let transport = FramedTransport::new(stream, PlainCodec::new()); - let (mut wh, _) = transport.into_split(); - - let data = TestData { - name: String::from("test"), - value: 123, - }; - - let bytes = utils::serialize_to_vec(&data).unwrap(); - let len = (bytes.len() as u64).to_be_bytes(); - let mut frame = Vec::new(); - frame.extend(len.iter().copied()); - frame.extend(bytes); - - wh.write(data).await.unwrap(); - - let outgoing = rx.recv().await.unwrap(); - assert_eq!(outgoing, frame); - } -} diff --git a/distant-net/src/transport/inmemory.rs b/distant-net/src/transport/inmemory.rs deleted file mode 100644 index 81169b6..0000000 --- a/distant-net/src/transport/inmemory.rs +++ /dev/null @@ -1,225 +0,0 @@ -use crate::{ 
- FramedTransport, IntoSplit, PlainCodec, RawTransport, RawTransportRead, RawTransportWrite, -}; -use std::{ - io, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite, ReadBuf}, - sync::mpsc, -}; - -mod read; -pub use read::*; - -mod write; -pub use write::*; - -/// Represents a [`RawTransport`] comprised of two inmemory channels -#[derive(Debug)] -pub struct InmemoryTransport { - incoming: InmemoryTransportReadHalf, - outgoing: InmemoryTransportWriteHalf, -} - -impl InmemoryTransport { - pub fn new(incoming: mpsc::Receiver>, outgoing: mpsc::Sender>) -> Self { - Self { - incoming: InmemoryTransportReadHalf::new(incoming), - outgoing: InmemoryTransportWriteHalf::new(outgoing), - } - } - - /// Returns (incoming_tx, outgoing_rx, transport) - pub fn make(buffer: usize) -> (mpsc::Sender>, mpsc::Receiver>, Self) { - let (incoming_tx, incoming_rx) = mpsc::channel(buffer); - let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer); - - ( - incoming_tx, - outgoing_rx, - Self::new(incoming_rx, outgoing_tx), - ) - } - - /// Returns pair of transports that are connected such that one sends to the other and - /// vice versa - pub fn pair(buffer: usize) -> (Self, Self) { - let (tx, rx, transport) = Self::make(buffer); - (transport, Self::new(rx, tx)) - } -} - -impl RawTransport for InmemoryTransport {} -impl RawTransportRead for InmemoryTransport {} -impl RawTransportWrite for InmemoryTransport {} -impl IntoSplit for InmemoryTransport { - type Read = InmemoryTransportReadHalf; - type Write = InmemoryTransportWriteHalf; - - fn into_split(self) -> (Self::Write, Self::Read) { - (self.outgoing, self.incoming) - } -} - -impl AsyncRead for InmemoryTransport { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.incoming).poll_read(cx, buf) - } -} - -impl AsyncWrite for InmemoryTransport { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> 
Poll> { - Pin::new(&mut self.outgoing).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.outgoing).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.outgoing).poll_shutdown(cx) - } -} - -impl FramedTransport { - /// Produces a pair of inmemory transports that are connected to each other using - /// a standard codec - /// - /// Sets the buffer for message passing for each underlying transport to the given buffer size - pub fn pair( - buffer: usize, - ) -> ( - FramedTransport, - FramedTransport, - ) { - let (a, b) = InmemoryTransport::pair(buffer); - let a = FramedTransport::new(a, PlainCodec::new()); - let b = FramedTransport::new(b, PlainCodec::new()); - (a, b) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - #[tokio::test] - async fn make_should_return_sender_that_sends_data_to_transport() { - let (tx, _, mut transport) = InmemoryTransport::make(3); - - tx.send(b"test msg 1".to_vec()).await.unwrap(); - tx.send(b"test msg 2".to_vec()).await.unwrap(); - tx.send(b"test msg 3".to_vec()).await.unwrap(); - - // Should get data matching a singular message - let mut buf = [0; 256]; - let len = transport.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 1"); - - // Next call would get the second message - let len = transport.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 2"); - - // When the last of the senders is dropped, we should still get - // the rest of the data that was sent first before getting - // an indicator that there is no more data - drop(tx); - - let len = transport.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 3"); - - let len = transport.read(&mut buf).await.unwrap(); - assert_eq!(len, 0, "Unexpectedly got more data"); - } - - #[tokio::test] - async fn 
make_should_return_receiver_that_receives_data_from_transport() { - let (_, mut rx, mut transport) = InmemoryTransport::make(3); - - transport.write_all(b"test msg 1").await.unwrap(); - transport.write_all(b"test msg 2").await.unwrap(); - transport.write_all(b"test msg 3").await.unwrap(); - - // Should get data matching a singular message - assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec())); - - // Next call would get the second message - assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec())); - - // When the transport is dropped, we should still get - // the rest of the data that was sent first before getting - // an indicator that there is no more data - drop(transport); - - assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec())); - - assert_eq!(rx.recv().await, None, "Unexpectedly got more data"); - } - - #[tokio::test] - async fn into_split_should_provide_a_read_half_that_receives_from_sender() { - let (tx, _, transport) = InmemoryTransport::make(3); - let (_, mut read_half) = transport.into_split(); - - tx.send(b"test msg 1".to_vec()).await.unwrap(); - tx.send(b"test msg 2".to_vec()).await.unwrap(); - tx.send(b"test msg 3".to_vec()).await.unwrap(); - - // Should get data matching a singular message - let mut buf = [0; 256]; - let len = read_half.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 1"); - - // Next call would get the second message - let len = read_half.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 2"); - - // When the last of the senders is dropped, we should still get - // the rest of the data that was sent first before getting - // an indicator that there is no more data - drop(tx); - - let len = read_half.read(&mut buf).await.unwrap(); - assert_eq!(&buf[..len], b"test msg 3"); - - let len = read_half.read(&mut buf).await.unwrap(); - assert_eq!(len, 0, "Unexpectedly got more data"); - } - - #[tokio::test] - async fn into_split_should_provide_a_write_half_that_sends_to_receiver() { - let (_, 
mut rx, transport) = InmemoryTransport::make(3); - let (mut write_half, _) = transport.into_split(); - - write_half.write_all(b"test msg 1").await.unwrap(); - write_half.write_all(b"test msg 2").await.unwrap(); - write_half.write_all(b"test msg 3").await.unwrap(); - - // Should get data matching a singular message - assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec())); - - // Next call would get the second message - assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec())); - - // When the transport is dropped, we should still get - // the rest of the data that was sent first before getting - // an indicator that there is no more data - drop(write_half); - - assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec())); - - assert_eq!(rx.recv().await, None, "Unexpectedly got more data"); - } -} diff --git a/distant-net/src/transport/inmemory/read.rs b/distant-net/src/transport/inmemory/read.rs deleted file mode 100644 index a05e0f9..0000000 --- a/distant-net/src/transport/inmemory/read.rs +++ /dev/null @@ -1,249 +0,0 @@ -use crate::RawTransportRead; -use futures::ready; -use std::{ - io, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::{ - io::{AsyncRead, ReadBuf}, - sync::mpsc, -}; - -/// Read portion of an inmemory channel -#[derive(Debug)] -pub struct InmemoryTransportReadHalf { - rx: mpsc::Receiver>, - overflow: Vec, -} - -impl InmemoryTransportReadHalf { - pub fn new(rx: mpsc::Receiver>) -> Self { - Self { - rx, - overflow: Vec::new(), - } - } -} - -impl RawTransportRead for InmemoryTransportReadHalf {} - -impl AsyncRead for InmemoryTransportReadHalf { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - // If we cannot fit any more into the buffer at the moment, we wait - if buf.remaining() == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "Cannot poll as buf.remaining() == 0", - ))); - } - - // If we have overflow from the last poll, put that in the buffer - if 
!self.overflow.is_empty() { - if self.overflow.len() > buf.remaining() { - let extra = self.overflow.split_off(buf.remaining()); - buf.put_slice(&self.overflow); - self.overflow = extra; - } else { - buf.put_slice(&self.overflow); - self.overflow.clear(); - } - - return Poll::Ready(Ok(())); - } - - // Otherwise, we poll for the next batch to read in - match ready!(self.rx.poll_recv(cx)) { - Some(mut x) => { - if x.len() > buf.remaining() { - self.overflow = x.split_off(buf.remaining()); - } - buf.put_slice(&x); - Poll::Ready(Ok(())) - } - None => Poll::Ready(Ok(())), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{InmemoryTransport, IntoSplit}; - use tokio::io::AsyncReadExt; - - #[tokio::test] - async fn read_half_should_fail_if_buf_has_no_space_remaining() { - let (_tx, _rx, transport) = InmemoryTransport::make(1); - let (_t_write, mut t_read) = transport.into_split(); - - let mut buf = [0u8; 0]; - match t_read.read(&mut buf).await { - Err(x) if x.kind() == io::ErrorKind::Other => {} - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_update_buf_with_all_overflow_from_last_read_if_it_all_fits() { - let (tx, _rx, transport) = InmemoryTransport::make(1); - let (_t_write, mut t_read) = transport.into_split(); - - tx.send(vec![1, 2, 3]).await.expect("Failed to send"); - - let mut buf = [0u8; 2]; - - // First, read part of the data (first two bytes) - match t_read.read(&mut buf).await { - Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]), - x => panic!("Unexpected result: {:?}", x), - } - - // Second, we send more data because the last message was placed in overflow - tx.send(vec![4, 5, 6]).await.expect("Failed to send"); - - // Third, read remainder of the overflow from first message (third byte) - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[3]), - x => panic!("Unexpected result: {:?}", x), - } - - // Fourth, verify that we start to receive the next overflow - 
match t_read.read(&mut buf).await { - Ok(n) if n == 2 => assert_eq!(&buf[..n], &[4, 5]), - x => panic!("Unexpected result: {:?}", x), - } - - // Fifth, verify that we get the last bit of overflow - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[6]), - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_update_buf_with_some_of_overflow_that_can_fit() { - let (tx, _rx, transport) = InmemoryTransport::make(1); - let (_t_write, mut t_read) = transport.into_split(); - - tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); - - let mut buf = [0u8; 2]; - - // First, read part of the data (first two bytes) - match t_read.read(&mut buf).await { - Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]), - x => panic!("Unexpected result: {:?}", x), - } - - // Second, we send more data because the last message was placed in overflow - tx.send(vec![6]).await.expect("Failed to send"); - - // Third, read next chunk of the overflow from first message (next two byte) - match t_read.read(&mut buf).await { - Ok(n) if n == 2 => assert_eq!(&buf[..n], &[3, 4]), - x => panic!("Unexpected result: {:?}", x), - } - - // Fourth, read last chunk of the overflow from first message (fifth byte) - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[5]), - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_update_buf_with_all_of_inner_channel_when_it_fits() { - let (tx, _rx, transport) = InmemoryTransport::make(1); - let (_t_write, mut t_read) = transport.into_split(); - - let mut buf = [0u8; 5]; - - tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); - - // First, read all of data that fits exactly - match t_read.read(&mut buf).await { - Ok(n) if n == 5 => assert_eq!(&buf[..n], &[1, 2, 3, 4, 5]), - x => panic!("Unexpected result: {:?}", x), - } - - tx.send(vec![6, 7, 8]).await.expect("Failed to send"); - - // Second, read data that 
fits within buf - match t_read.read(&mut buf).await { - Ok(n) if n == 3 => assert_eq!(&buf[..n], &[6, 7, 8]), - x => panic!("Unexpected result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_update_buf_with_some_of_inner_channel_that_can_fit_and_add_rest_to_overflow( - ) { - let (tx, _rx, transport) = InmemoryTransport::make(1); - let (_t_write, mut t_read) = transport.into_split(); - - let mut buf = [0u8; 1]; - - tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); - - // Attempt a read that places more in overflow - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[1]), - x => panic!("Unexpected result: {:?}", x), - } - - // Verify overflow contains the rest - assert_eq!(&t_read.overflow, &[2, 3, 4, 5]); - - // Queue up extra data that will not be read until overflow is finished - tx.send(vec![6, 7, 8]).await.expect("Failed to send"); - - // Read next data point - match t_read.read(&mut buf).await { - Ok(n) if n == 1 => assert_eq!(&buf[..n], &[2]), - x => panic!("Unexpected result: {:?}", x), - } - - // Verify overflow contains the rest without having added extra data - assert_eq!(&t_read.overflow, &[3, 4, 5]); - } - - #[tokio::test] - async fn read_half_should_yield_pending_if_no_data_available_on_inner_channel() { - let (_tx, _rx, transport) = InmemoryTransport::make(1); - let (_t_write, mut t_read) = transport.into_split(); - - let mut buf = [0u8; 1]; - - // Attempt a read that should yield ok with no change, which is what should - // happen when nothing is read into buf - let f = t_read.read(&mut buf); - tokio::pin!(f); - match futures::poll!(f) { - Poll::Pending => {} - x => panic!("Unexpected poll result: {:?}", x), - } - } - - #[tokio::test] - async fn read_half_should_not_update_buf_if_inner_channel_closed() { - let (tx, _rx, transport) = InmemoryTransport::make(1); - let (_t_write, mut t_read) = transport.into_split(); - - let mut buf = [0u8; 1]; - - // Drop the channel that would be sending data 
to the transport - drop(tx); - - // Attempt a read that should yield ok with no change, which is what should - // happen when nothing is read into buf - match t_read.read(&mut buf).await { - Ok(n) if n == 0 => assert_eq!(&buf, &[0]), - x => panic!("Unexpected result: {:?}", x), - } - } -} diff --git a/distant-net/src/transport/inmemory/write.rs b/distant-net/src/transport/inmemory/write.rs deleted file mode 100644 index ae96e74..0000000 --- a/distant-net/src/transport/inmemory/write.rs +++ /dev/null @@ -1,147 +0,0 @@ -use crate::RawTransportWrite; -use futures::ready; -use std::{ - fmt, - future::Future, - io, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::{io::AsyncWrite, sync::mpsc}; - -/// Write portion of an inmemory channel -pub struct InmemoryTransportWriteHalf { - tx: Option>>, - task: Option> + Send + Sync + 'static>>>, -} - -impl InmemoryTransportWriteHalf { - pub fn new(tx: mpsc::Sender>) -> Self { - Self { - tx: Some(tx), - task: None, - } - } -} - -impl fmt::Debug for InmemoryTransportWriteHalf { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("InmemoryTransportWrite") - .field("tx", &self.tx) - .field( - "task", - &if self.tx.is_some() { - "Some(...)" - } else { - "None" - }, - ) - .finish() - } -} - -impl RawTransportWrite for InmemoryTransportWriteHalf {} - -impl AsyncWrite for InmemoryTransportWriteHalf { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - match self.task.as_mut() { - Some(task) => { - let res = ready!(task.as_mut().poll(cx)); - self.task.take(); - return Poll::Ready(res); - } - None => match self.tx.as_mut() { - Some(tx) => { - let n = buf.len(); - let tx_2 = tx.clone(); - let data = buf.to_vec(); - let task = - Box::pin(async move { tx_2.send(data).await.map(|_| n).or(Ok(0)) }); - self.task.replace(task); - } - None => return Poll::Ready(Ok(0)), - }, - } - } - } - - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - 
Poll::Ready(Ok(())) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - self.tx.take(); - self.task.take(); - Poll::Ready(Ok(())) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{InmemoryTransport, IntoSplit}; - use tokio::io::AsyncWriteExt; - - #[tokio::test] - async fn write_half_should_return_buf_len_if_can_send_immediately() { - let (_tx, mut rx, transport) = InmemoryTransport::make(1); - let (mut t_write, _t_read) = transport.into_split(); - - // Write that is not waiting should always succeed with full contents - let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); - assert_eq!(n, 3, "Unexpected byte count returned"); - - // Verify we actually had the data sent - let data = rx.try_recv().expect("Failed to recv data"); - assert_eq!(data, &[1, 2, 3]); - } - - #[tokio::test] - async fn write_half_should_return_support_eventually_sending_by_retrying_when_not_ready() { - let (_tx, mut rx, transport) = InmemoryTransport::make(1); - let (mut t_write, _t_read) = transport.into_split(); - - // Queue a write already so that we block on the next one - let _ = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); - - // Verify that the next write is pending - let f = t_write.write(&[4, 5]); - tokio::pin!(f); - match futures::poll!(&mut f) { - Poll::Pending => {} - x => panic!("Unexpected poll result: {:?}", x), - } - - // Consume first batch of data so future of second can continue - let data = rx.try_recv().expect("Failed to recv data"); - assert_eq!(data, &[1, 2, 3]); - - // Verify that poll now returns success - match futures::poll!(f) { - Poll::Ready(Ok(n)) if n == 2 => {} - x => panic!("Unexpected poll result: {:?}", x), - } - - // Consume second batch of data - let data = rx.try_recv().expect("Failed to recv data"); - assert_eq!(data, &[4, 5]); - } - - #[tokio::test] - async fn write_half_should_zero_if_inner_channel_closed() { - let (_tx, rx, transport) = InmemoryTransport::make(1); - let (mut 
t_write, _t_read) = transport.into_split(); - - // Drop receiving end that transport would talk to - drop(rx); - - // Channel is dropped, so return 0 to indicate no bytes sent - let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); - assert_eq!(n, 0, "Unexpected byte count returned"); - } -} diff --git a/distant-net/src/transport/mpsc.rs b/distant-net/src/transport/mpsc.rs deleted file mode 100644 index bd50473..0000000 --- a/distant-net/src/transport/mpsc.rs +++ /dev/null @@ -1,66 +0,0 @@ -use crate::{IntoSplit, TypedAsyncRead, TypedAsyncWrite, TypedTransport}; -use async_trait::async_trait; -use std::io; -use tokio::sync::mpsc; - -mod read; -pub use read::*; - -mod write; -pub use write::*; - -/// Represents a [`TypedTransport`] of data across the network that uses [`mpsc::Sender`] and -/// [`mpsc::Receiver`] underneath. -#[derive(Debug)] -pub struct MpscTransport { - outbound: MpscTransportWriteHalf, - inbound: MpscTransportReadHalf, -} - -impl MpscTransport { - pub fn new(outbound: mpsc::Sender, inbound: mpsc::Receiver) -> Self { - Self { - outbound: MpscTransportWriteHalf::new(outbound), - inbound: MpscTransportReadHalf::new(inbound), - } - } - - /// Creates a pair of connected transports using `buffer` as maximum - /// channel capacity for each - pub fn pair(buffer: usize) -> (MpscTransport, MpscTransport) { - let (t_tx, t_rx) = mpsc::channel(buffer); - let (u_tx, u_rx) = mpsc::channel(buffer); - ( - MpscTransport::new(t_tx, u_rx), - MpscTransport::new(u_tx, t_rx), - ) - } -} - -impl TypedTransport for MpscTransport {} - -#[async_trait] -impl TypedAsyncWrite for MpscTransport { - async fn write(&mut self, data: T) -> io::Result<()> { - self.outbound - .write(data) - .await - .map_err(|x| io::Error::new(io::ErrorKind::Other, x)) - } -} - -#[async_trait] -impl TypedAsyncRead for MpscTransport { - async fn read(&mut self) -> io::Result> { - self.inbound.read().await - } -} - -impl IntoSplit for MpscTransport { - type Read = MpscTransportReadHalf; - 
type Write = MpscTransportWriteHalf; - - fn into_split(self) -> (Self::Write, Self::Read) { - (self.outbound, self.inbound) - } -} diff --git a/distant-net/src/transport/mpsc/read.rs b/distant-net/src/transport/mpsc/read.rs deleted file mode 100644 index da7e41f..0000000 --- a/distant-net/src/transport/mpsc/read.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::TypedAsyncRead; -use async_trait::async_trait; -use std::io; -use tokio::sync::mpsc; - -#[derive(Debug)] -pub struct MpscTransportReadHalf { - rx: mpsc::Receiver, -} - -impl MpscTransportReadHalf { - pub fn new(rx: mpsc::Receiver) -> Self { - Self { rx } - } -} - -#[async_trait] -impl TypedAsyncRead for MpscTransportReadHalf { - async fn read(&mut self) -> io::Result> { - Ok(self.rx.recv().await) - } -} diff --git a/distant-net/src/transport/mpsc/write.rs b/distant-net/src/transport/mpsc/write.rs deleted file mode 100644 index 7801268..0000000 --- a/distant-net/src/transport/mpsc/write.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::TypedAsyncWrite; -use async_trait::async_trait; -use std::io; -use tokio::sync::mpsc; - -#[derive(Debug)] -pub struct MpscTransportWriteHalf { - tx: mpsc::Sender, -} - -impl MpscTransportWriteHalf { - pub fn new(tx: mpsc::Sender) -> Self { - Self { tx } - } -} - -#[async_trait] -impl TypedAsyncWrite for MpscTransportWriteHalf { - async fn write(&mut self, data: T) -> io::Result<()> { - self.tx - .send(data) - .await - .map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string())) - } -} diff --git a/distant-net/src/transport/router.rs b/distant-net/src/transport/router.rs deleted file mode 100644 index 4f1e4c4..0000000 --- a/distant-net/src/transport/router.rs +++ /dev/null @@ -1,370 +0,0 @@ -/// Creates a new struct around a [`UntypedTransport`](crate::UntypedTransport) that routes incoming -/// and outgoing messages to different transports, enabling the ability to transform a singular -/// transport into multiple typed transports that can be combined with [`Client`](crate::Client) 
-/// and [`Server`](crate::Server) to mix having a variety of clients and servers available on the -/// same underlying [`UntypedTransport`](crate::UntypedTransport). -/// -/// ```no_run -/// use distant_net::router; -/// -/// # // To send, the data needs to be serializable -/// # // To receive, the data needs to be deserializable -/// # #[derive(serde::Serialize, serde::Deserialize)] -/// # struct CustomData(u8, u8); -/// -/// // Create a router that produces three transports from one: -/// // 1. `Transport` - receives `String` and sends `u8` -/// // 2. `Transport` - receives `CustomData` and sends `bool` -/// // 3. `Transport, u8>` - receives `u8` and sends `Option` -/// router!(TestRouter { -/// one: String => u8, -/// two: CustomData => bool, -/// three: u8 => Option, -/// }); -/// -/// router!( -/// #[router(inbound = 10, outbound = 20)] -/// TestRouterWithCustomBounds { -/// one: String => u8, -/// two: CustomData => bool, -/// three: u8 => Option, -/// } -/// ); -/// -/// # let (transport, _) = distant_net::FramedTransport::pair(1); -/// -/// let router = TestRouter::new(transport); -/// -/// let one = router.one; // MpscTransport -/// let two = router.two; // MpscTransport -/// let three = router.three; // MpscTransport, u8> -/// ``` -#[macro_export] -macro_rules! router { - ( - $(#[router($($mname:ident = $mvalue:literal),*)])? - $vis:vis $name:ident { - $($transport:ident : $res_ty:ty => $req_ty:ty),+ $(,)? - } - ) => { - $crate::paste::paste! 
{ - #[doc = "Implements a message router for splitting out transport messages"] - #[allow(dead_code)] - $vis struct $name { - reader_task: tokio::task::JoinHandle<()>, - writer_task: tokio::task::JoinHandle<()>, - $( - pub $transport: $crate::MpscTransport<$req_ty, $res_ty>, - )+ - } - - #[allow(dead_code)] - impl $name { - /// Returns the size of the inbound buffer used by this router - pub const fn inbound_buffer_size() -> usize { - Self::buffer_sizes().0 - } - - /// Returns the size of the outbound buffer used by this router - pub const fn outbound_buffer_size() -> usize { - Self::buffer_sizes().1 - } - - /// Returns the size of the inbound and outbound buffers used by this router - /// in the form of `(inbound, outbound)` - pub const fn buffer_sizes() -> (usize, usize) { - // Set defaults for inbound and outbound buffer sizes - let _inbound = 10000; - let _outbound = 10000; - - $($( - let [<_ $mname:snake>] = $mvalue; - )*)? - - (_inbound, _outbound) - } - - #[doc = "Creates a new instance of [`" $name "`]"] - pub fn new(split: T) -> Self - where - T: $crate::IntoSplit, - W: $crate::UntypedTransportWrite + 'static, - R: $crate::UntypedTransportRead + 'static, - { - let (writer, reader) = split.into_split(); - Self::from_writer_and_reader(writer, reader) - } - - #[doc = "Creates a new instance of [`" $name "`] from the given writer and reader"] - pub fn from_writer_and_reader(mut writer: W, mut reader: R) -> Self - where - W: $crate::UntypedTransportWrite + 'static, - R: $crate::UntypedTransportRead + 'static, - { - - $( - let ( - [<$transport:snake _inbound_tx>], - [<$transport:snake _inbound_rx>] - ) = tokio::sync::mpsc::channel(Self::inbound_buffer_size()); - let ( - [<$transport:snake _outbound_tx>], - mut [<$transport:snake _outbound_rx>] - ) = tokio::sync::mpsc::channel(Self::outbound_buffer_size()); - let [<$transport:snake>]: $crate::MpscTransport<$req_ty, $res_ty> = - $crate::MpscTransport::new( - [<$transport:snake _outbound_tx>], - [<$transport:snake 
_inbound_rx>] - ); - )+ - - #[derive(serde::Deserialize)] - #[serde(untagged)] - enum [<$name:camel In>] { - $([<$transport:camel>]($res_ty)),+ - } - - let reader_task = tokio::spawn(async move { - loop { - match $crate::UntypedTransportRead::read(&mut reader).await { - $( - Ok(Some([<$name:camel In>]::[<$transport:camel>](x))) => { - if let Err(x) = [<$transport:snake _inbound_tx>].send(x).await { - $crate::log::error!( - "Failed to forward received data from {} of {}: {}", - std::stringify!($transport), - std::stringify!($name), - x - ); - } - } - )+ - - // Quit if the reader no longer has data - // NOTE: Compiler says this is unreachable, but it is? - #[allow(unreachable_patterns)] - Ok(None) => { - $crate::log::trace!( - "Router {} has closed", - std::stringify!($name), - ); - break; - } - - // Drop any received data that does not map to something - // NOTE: Compiler says this is unreachable, but it is? - #[allow(unreachable_patterns)] - Err(x) => { - $crate::log::error!( - "Failed to read from any transport of {}: {}", - std::stringify!($name), - x - ); - continue; - } - } - } - }); - - let writer_task = tokio::spawn(async move { - loop { - tokio::select! 
{ - $( - Some(x) = [<$transport:snake _outbound_rx>].recv() => { - if let Err(x) = $crate::UntypedTransportWrite::write( - &mut writer, - x, - ).await { - $crate::log::error!( - "Failed to write to {} of {}: {}", - std::stringify!($transport), - std::stringify!($name), - x - ); - } - } - )+ - else => break, - } - } - }); - - Self { - reader_task, - writer_task, - $([<$transport:snake>]),+ - } - } - - pub fn abort(&self) { - self.reader_task.abort(); - self.writer_task.abort(); - } - - pub fn is_finished(&self) -> bool { - self.reader_task.is_finished() && self.writer_task.is_finished() - } - } - } - }; -} - -#[cfg(test)] -mod tests { - use crate::{FramedTransport, TypedAsyncRead, TypedAsyncWrite}; - use serde::{Deserialize, Serialize}; - - // NOTE: Must implement deserialize for our router, - // but we also need serialize to send for our test - #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] - struct CustomData(u8, String); - - // Creates a private `TestRouter` implementation - // - // 1. Transport receiving `CustomData` and sending `String` - // 2. Transport receiving `String` and sending `u8` - // 3. Transport receiving `bool` and sending `bool` - // 4. 
Transport receiving `Result` and sending `Option` - router!(TestRouter { - one: CustomData => String, - two: String => u8, - three: bool => bool, - should_compile: Result => Option, - }); - - #[test] - fn router_buffer_sizes_should_support_being_overridden() { - router!(DefaultSizes { data: u8 => u8 }); - router!(#[router(inbound = 5)] CustomInboundSize { data: u8 => u8 }); - router!(#[router(outbound = 5)] CustomOutboundSize { data: u8 => u8 }); - router!(#[router(inbound = 5, outbound = 6)] CustomSizes { data: u8 => u8 }); - - assert_eq!(DefaultSizes::buffer_sizes(), (10000, 10000)); - assert_eq!(DefaultSizes::inbound_buffer_size(), 10000); - assert_eq!(DefaultSizes::outbound_buffer_size(), 10000); - - assert_eq!(CustomInboundSize::buffer_sizes(), (5, 10000)); - assert_eq!(CustomInboundSize::inbound_buffer_size(), 5); - assert_eq!(CustomInboundSize::outbound_buffer_size(), 10000); - - assert_eq!(CustomOutboundSize::buffer_sizes(), (10000, 5)); - assert_eq!(CustomOutboundSize::inbound_buffer_size(), 10000); - assert_eq!(CustomOutboundSize::outbound_buffer_size(), 5); - - assert_eq!(CustomSizes::buffer_sizes(), (5, 6)); - assert_eq!(CustomSizes::inbound_buffer_size(), 5); - assert_eq!(CustomSizes::outbound_buffer_size(), 6); - } - - #[tokio::test] - async fn router_should_wire_transports_to_distinguish_incoming_data() { - let (t1, mut t2) = FramedTransport::make_test_pair(); - let TestRouter { - mut one, - mut two, - mut three, - .. 
- } = TestRouter::new(t1); - - // Send some data of different types that these transports expect - t2.write(false).await.unwrap(); - t2.write("hello world".to_string()).await.unwrap(); - t2.write(CustomData(123, "goodbye world".to_string())) - .await - .unwrap(); - - // Get that data through the appropriate transport - let data = one.read().await.unwrap().unwrap(); - assert_eq!( - data, - CustomData(123, "goodbye world".to_string()), - "string_custom_data_transport got unexpected result" - ); - - let data = two.read().await.unwrap().unwrap(); - assert_eq!( - data, "hello world", - "u8_string_transport got unexpected result" - ); - - let data = three.read().await.unwrap().unwrap(); - assert!(!data, "bool_bool_transport got unexpected result"); - } - - #[tokio::test] - async fn router_should_wire_transports_to_ignore_unknown_incoming_data() { - let (t1, mut t2) = FramedTransport::make_test_pair(); - let TestRouter { - mut one, mut two, .. - } = TestRouter::new(t1); - - #[derive(Serialize, Deserialize)] - struct UnknownData(char, u8); - - // Send some known and unknown data - t2.write("hello world".to_string()).await.unwrap(); - t2.write(UnknownData('a', 99)).await.unwrap(); - t2.write(CustomData(123, "goodbye world".to_string())) - .await - .unwrap(); - - // Get that data through the appropriate transport - let data = one.read().await.unwrap().unwrap(); - assert_eq!( - data, - CustomData(123, "goodbye world".to_string()), - "string_custom_data_transport got unexpected result" - ); - - let data = two.read().await.unwrap().unwrap(); - assert_eq!( - data, "hello world", - "u8_string_transport got unexpected result" - ); - } - - #[tokio::test] - async fn router_should_wire_transports_to_relay_outgoing_data() { - let (t1, mut t2) = FramedTransport::make_test_pair(); - let TestRouter { - mut one, - mut two, - mut three, - .. 
- } = TestRouter::new(t1); - - // NOTE: Introduce a sleep between each send, otherwise we are - // resolving futures in a way where the ordering may - // get mixed up on the way out - async fn wait() { - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - } - - // Send some data of different types that these transports expect - three.write(true).await.unwrap(); - wait().await; - two.write(123).await.unwrap(); - wait().await; - one.write("hello world".to_string()).await.unwrap(); - - // All of that data should funnel through our primary transport, - // but the order is NOT guaranteed! So we need to store - let data: bool = t2.read().await.unwrap().unwrap(); - assert!( - data, - "Unexpected data received from bool_bool_transport output" - ); - - let data: u8 = t2.read().await.unwrap().unwrap(); - assert_eq!( - data, 123, - "Unexpected data received from u8_string_transport output" - ); - - let data: String = t2.read().await.unwrap().unwrap(); - assert_eq!( - data, "hello world", - "Unexpected data received from string_custom_data_transport output" - ); - } -} diff --git a/distant-net/src/transport/tcp.rs b/distant-net/src/transport/tcp.rs deleted file mode 100644 index 909e1e2..0000000 --- a/distant-net/src/transport/tcp.rs +++ /dev/null @@ -1,196 +0,0 @@ -use crate::{IntoSplit, RawTransport, RawTransportRead, RawTransportWrite}; -use std::{ - fmt, io, - net::IpAddr, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite, ReadBuf}, - net::{ - tcp::{OwnedReadHalf, OwnedWriteHalf}, - TcpStream, ToSocketAddrs, - }, -}; - -/// Represents a [`RawTransport`] that leverages a TCP stream -pub struct TcpTransport { - pub(crate) addr: IpAddr, - pub(crate) port: u16, - pub(crate) inner: TcpStream, -} - -impl TcpTransport { - /// Creates a new stream by connecting to a remote machine at the specified - /// IP address and port - pub async fn connect(addrs: impl ToSocketAddrs) -> io::Result { - let stream = 
TcpStream::connect(addrs).await?; - let addr = stream.peer_addr()?; - Ok(Self { - addr: addr.ip(), - port: addr.port(), - inner: stream, - }) - } - - /// Returns the IP address that the stream is connected to - pub fn ip_addr(&self) -> IpAddr { - self.addr - } - - /// Returns the port that the stream is connected to - pub fn port(&self) -> u16 { - self.port - } -} - -impl fmt::Debug for TcpTransport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TcpTransport") - .field("addr", &self.addr) - .field("port", &self.port) - .finish() - } -} - -impl RawTransport for TcpTransport {} -impl RawTransportRead for TcpTransport {} -impl RawTransportWrite for TcpTransport {} - -impl RawTransportRead for OwnedReadHalf {} -impl RawTransportWrite for OwnedWriteHalf {} - -impl IntoSplit for TcpTransport { - type Read = OwnedReadHalf; - type Write = OwnedWriteHalf; - - fn into_split(self) -> (Self::Write, Self::Read) { - let (r, w) = TcpStream::into_split(self.inner); - (w, r) - } -} - -impl AsyncRead for TcpTransport { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_read(cx, buf) - } -} - -impl AsyncWrite for TcpTransport { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::net::{Ipv6Addr, SocketAddr}; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpListener, - sync::oneshot, - task::JoinHandle, - }; - - async fn find_ephemeral_addr() -> SocketAddr { - // Start a listener on a distinct port, get its port, and kill it - // NOTE: 
This is a race condition as something else could bind to - // this port inbetween us killing it and us attempting to - // connect to it. We're willing to take that chance - let addr = IpAddr::V6(Ipv6Addr::LOCALHOST); - - let listener = TcpListener::bind((addr, 0)) - .await - .expect("Failed to bind on an ephemeral port"); - - let port = listener - .local_addr() - .expect("Failed to look up ephemeral port") - .port(); - - SocketAddr::from((addr, port)) - } - - #[tokio::test] - async fn should_fail_to_connect_if_nothing_listening() { - let addr = find_ephemeral_addr().await; - - // Now this should fail as we've stopped what was listening - TcpTransport::connect(addr).await.expect_err(&format!( - "Unexpectedly succeeded in connecting to ghost address: {}", - addr - )); - } - - #[tokio::test] - async fn should_be_able_to_send_and_receive_data() { - let (tx, rx) = oneshot::channel(); - - // Spawn a task that will wait for a connection, send data, - // and receive data that it will return in the task - let task: JoinHandle> = tokio::spawn(async move { - let addr = find_ephemeral_addr().await; - - // Start listening at the distinct address - let listener = TcpListener::bind(addr).await?; - - // Send the address back to our main test thread - tx.send(addr) - .map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string()))?; - - // Get the connection - let (mut conn, _) = listener.accept().await?; - - // Send some data to the connection (10 bytes) - conn.write_all(b"hello conn").await?; - - // Receive some data from the connection (12 bytes) - let mut buf: [u8; 12] = [0; 12]; - let _ = conn.read_exact(&mut buf).await?; - assert_eq!(&buf, b"hello server"); - - Ok(()) - }); - - // Wait for the server to be ready - let addr = rx.await.expect("Failed to get server server address"); - - // Connect to the socket, send some bytes, and get some bytes - let mut buf: [u8; 10] = [0; 10]; - - let mut conn = TcpTransport::connect(&addr) - .await - .expect("Conn failed to connect"); - 
conn.read_exact(&mut buf) - .await - .expect("Conn failed to read"); - assert_eq!(&buf, b"hello conn"); - - conn.write_all(b"hello server") - .await - .expect("Conn failed to write"); - - // Verify that the task has completed by waiting on it - let _ = task.await.expect("Server task failed unexpectedly"); - } -} diff --git a/distant-net/src/transport/unix.rs b/distant-net/src/transport/unix.rs deleted file mode 100644 index a20142d..0000000 --- a/distant-net/src/transport/unix.rs +++ /dev/null @@ -1,187 +0,0 @@ -use crate::{IntoSplit, RawTransport, RawTransportRead, RawTransportWrite}; -use std::{ - fmt, io, - path::{Path, PathBuf}, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite, ReadBuf}, - net::{ - unix::{OwnedReadHalf, OwnedWriteHalf}, - UnixStream, - }, -}; - -/// Represents a [`RawTransport`] that leverages a Unix socket -pub struct UnixSocketTransport { - pub(crate) path: PathBuf, - pub(crate) inner: UnixStream, -} - -impl UnixSocketTransport { - /// Creates a new stream by connecting to the specified path - pub async fn connect(path: impl AsRef) -> io::Result { - let stream = UnixStream::connect(path.as_ref()).await?; - Ok(Self { - path: path.as_ref().to_path_buf(), - inner: stream, - }) - } - - /// Returns the path to the socket - pub fn path(&self) -> &Path { - &self.path - } -} - -impl fmt::Debug for UnixSocketTransport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("UnixSocketTransport") - .field("path", &self.path) - .finish() - } -} - -impl RawTransport for UnixSocketTransport {} -impl RawTransportRead for UnixSocketTransport {} -impl RawTransportWrite for UnixSocketTransport {} - -impl RawTransportRead for OwnedReadHalf {} -impl RawTransportWrite for OwnedWriteHalf {} - -impl IntoSplit for UnixSocketTransport { - type Read = OwnedReadHalf; - type Write = OwnedWriteHalf; - - fn into_split(self) -> (Self::Write, Self::Read) { - let (r, w) = UnixStream::into_split(self.inner); - (w, r) 
- } -} - -impl AsyncRead for UnixSocketTransport { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_read(cx, buf) - } -} - -impl AsyncWrite for UnixSocketTransport { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tempfile::NamedTempFile; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::UnixListener, - sync::oneshot, - task::JoinHandle, - }; - - #[tokio::test] - async fn should_fail_to_connect_if_socket_does_not_exist() { - // Generate a socket path and delete the file after so there is nothing there - let path = NamedTempFile::new() - .expect("Failed to create socket file") - .path() - .to_path_buf(); - - // Now this should fail as we're already bound to the name - UnixSocketTransport::connect(&path) - .await - .expect_err("Unexpectedly succeeded in connecting to missing socket"); - } - - #[tokio::test] - async fn should_fail_to_connect_if_path_is_not_a_socket() { - // Generate a regular file - let path = NamedTempFile::new() - .expect("Failed to create socket file") - .into_temp_path(); - - // Now this should fail as this file is not a socket - UnixSocketTransport::connect(&path) - .await - .expect_err("Unexpectedly succeeded in connecting to regular file"); - } - - #[tokio::test] - async fn should_be_able_to_send_and_receive_data() { - let (tx, rx) = oneshot::channel(); - - // Spawn a task that will wait for a connection, send data, - // and receive data that it will return in the task - let task: JoinHandle> = tokio::spawn(async move { - // Generate 
a socket path and delete the file after so there is nothing there - let path = NamedTempFile::new() - .expect("Failed to create socket file") - .path() - .to_path_buf(); - - // Start listening at the socket path - let socket = UnixListener::bind(&path)?; - - // Send the path back to our main test thread - tx.send(path) - .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?; - - // Get the connection - let (mut conn, _) = socket.accept().await?; - - // Send some data to the connection (10 bytes) - conn.write_all(b"hello conn").await?; - - // Receive some data from the connection (12 bytes) - let mut buf: [u8; 12] = [0; 12]; - let _ = conn.read_exact(&mut buf).await?; - assert_eq!(&buf, b"hello server"); - - Ok(()) - }); - - // Wait for the server to be ready - let path = rx.await.expect("Failed to get server socket path"); - - // Connect to the socket, send some bytes, and get some bytes - let mut buf: [u8; 10] = [0; 10]; - - let mut conn = UnixSocketTransport::connect(&path) - .await - .expect("Conn failed to connect"); - conn.read_exact(&mut buf) - .await - .expect("Conn failed to read"); - assert_eq!(&buf, b"hello conn"); - - conn.write_all(b"hello server") - .await - .expect("Conn failed to write"); - - // Verify that the task has completed by waiting on it - let _ = task.await.expect("Server task failed unexpectedly"); - } -} diff --git a/distant-net/src/transport/untyped.rs b/distant-net/src/transport/untyped.rs deleted file mode 100644 index dfe871a..0000000 --- a/distant-net/src/transport/untyped.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::{TypedAsyncRead, TypedAsyncWrite, TypedTransport}; -use async_trait::async_trait; -use serde::{de::DeserializeOwned, Serialize}; -use std::io; - -/// Interface representing a transport that uses [`serde`] to serialize and deserialize data -/// as it is sent and received -pub trait UntypedTransport: UntypedTransportRead + UntypedTransportWrite {} - -/// Interface representing a transport's read 
half that uses [`serde`] to deserialize data as it is -/// received -#[async_trait] -pub trait UntypedTransportRead: Send + Unpin { - /// Attempts to read some data as `T`, returning [`io::Error`] if unable to deserialize - /// or some other error occurs. `Some(T)` is returned if successful. `None` is - /// returned if no more data is available. - async fn read(&mut self) -> io::Result> - where - T: DeserializeOwned; -} - -/// Interface representing a transport's write half that uses [`serde`] to serialize data as it is -/// sent -#[async_trait] -pub trait UntypedTransportWrite: Send + Unpin { - /// Attempts to write some data of type `T`, returning [`io::Error`] if unable to serialize - /// or some other error occurs. - async fn write(&mut self, data: T) -> io::Result<()> - where - T: Serialize + Send + 'static; -} - -impl TypedTransport for T -where - T: UntypedTransport + Send, - W: Serialize + Send + 'static, - R: DeserializeOwned, -{ -} - -#[async_trait] -impl TypedAsyncWrite for W -where - W: UntypedTransportWrite + Send, - T: Serialize + Send + 'static, -{ - async fn write(&mut self, data: T) -> io::Result<()> { - W::write(self, data).await - } -} - -#[async_trait] -impl TypedAsyncRead for R -where - R: UntypedTransportRead + Send, - T: DeserializeOwned, -{ - async fn read(&mut self) -> io::Result> { - R::read(self).await - } -} diff --git a/distant-net/src/transport/windows.rs b/distant-net/src/transport/windows.rs deleted file mode 100644 index e4d87f8..0000000 --- a/distant-net/src/transport/windows.rs +++ /dev/null @@ -1,202 +0,0 @@ -use crate::{IntoSplit, RawTransport, RawTransportRead, RawTransportWrite}; -use std::{ - ffi::{OsStr, OsString}, - fmt, io, - pin::Pin, - task::{Context, Poll}, - time::Duration, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf}, - net::windows::named_pipe::ClientOptions, -}; - -// Equivalent to winapi::shared::winerror::ERROR_PIPE_BUSY -// DWORD -> c_uLong -> u32 -const ERROR_PIPE_BUSY: u32 = 
231; - -// Time between attempts to connect to a busy pipe -const BUSY_PIPE_SLEEP_MILLIS: u64 = 50; - -mod pipe; -pub use pipe::NamedPipe; - -/// Represents a [`RawTransport`] that leverages a named Windows pipe (client or server) -pub struct WindowsPipeTransport { - pub(crate) addr: OsString, - pub(crate) inner: NamedPipe, -} - -impl WindowsPipeTransport { - /// Establishes a connection to the pipe with the specified name, using the - /// name for a local pipe address in the form of `\\.\pipe\my_pipe_name` where - /// `my_pipe_name` is provided to this function - pub async fn connect_local(name: impl AsRef) -> io::Result { - let mut addr = OsString::from(r"\\.\pipe\"); - addr.push(name.as_ref()); - Self::connect(addr).await - } - - /// Establishes a connection to the pipe at the specified address - /// - /// Address may be something like `\.\pipe\my_pipe_name` - pub async fn connect(addr: impl Into) -> io::Result { - let addr = addr.into(); - - let pipe = loop { - match ClientOptions::new().open(&addr) { - Ok(client) => break client, - Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), - Err(e) => return Err(e), - } - - tokio::time::sleep(Duration::from_millis(BUSY_PIPE_SLEEP_MILLIS)).await; - }; - - Ok(Self { - addr, - inner: NamedPipe::from(pipe), - }) - } - - /// Returns the addr that the listener is bound to - pub fn addr(&self) -> &OsStr { - &self.addr - } -} - -impl fmt::Debug for WindowsPipeTransport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("WindowsPipeTransport") - .field("addr", &self.addr) - .finish() - } -} - -impl RawTransport for WindowsPipeTransport {} -impl RawTransportRead for WindowsPipeTransport {} -impl RawTransportWrite for WindowsPipeTransport {} - -impl RawTransportRead for ReadHalf {} -impl RawTransportWrite for WriteHalf {} - -impl IntoSplit for WindowsPipeTransport { - type Read = ReadHalf; - type Write = WriteHalf; - - fn into_split(self) -> (Self::Write, Self::Read) { - let (reader, 
writer) = tokio::io::split(self); - (writer, reader) - } -} - -impl AsyncRead for WindowsPipeTransport { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_read(cx, buf) - } -} - -impl AsyncWrite for WindowsPipeTransport { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::windows::named_pipe::ServerOptions, - sync::oneshot, - task::JoinHandle, - }; - - #[tokio::test] - async fn should_fail_to_connect_if_pipe_does_not_exist() { - // Generate a pipe name - let name = format!("test_pipe_{}", rand::random::()); - - // Now this should fail as we're already bound to the name - WindowsPipeTransport::connect_local(&name) - .await - .expect_err("Unexpectedly succeeded in connecting to missing pipe"); - } - - #[tokio::test] - async fn should_be_able_to_send_and_receive_data() { - let (tx, rx) = oneshot::channel(); - - // Spawn a task that will wait for a connection, send data, - // and receive data that it will return in the task - let task: JoinHandle> = tokio::spawn(async move { - // Generate a pipe address (not just a name) - let addr = format!(r"\\.\pipe\test_pipe_{}", rand::random::()); - - // Listen at the pipe - let pipe = ServerOptions::new() - .first_pipe_instance(true) - .create(&addr)?; - - // Send the address back to our main test thread - tx.send(addr) - .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; - - // Get the connection - let mut conn = { - pipe.connect().await?; - pipe - }; - - // Send some data 
to the connection (10 bytes) - conn.write_all(b"hello conn").await?; - - // Receive some data from the connection (12 bytes) - let mut buf: [u8; 12] = [0; 12]; - let _ = conn.read_exact(&mut buf).await?; - assert_eq!(&buf, b"hello server"); - - Ok(()) - }); - - // Wait for the server to be ready - let address = rx.await.expect("Failed to get server address"); - - // Connect to the pipe, send some bytes, and get some bytes - let mut buf: [u8; 10] = [0; 10]; - - let mut conn = WindowsPipeTransport::connect(&address) - .await - .expect("Conn failed to connect"); - conn.read_exact(&mut buf) - .await - .expect("Conn failed to read"); - assert_eq!(&buf, b"hello conn"); - - conn.write_all(b"hello server") - .await - .expect("Conn failed to write"); - - // Verify that the task has completed by waiting on it - let _ = task.await.expect("Server task failed unexpectedly"); - } -} diff --git a/distant-net/src/transport/windows/pipe.rs b/distant-net/src/transport/windows/pipe.rs deleted file mode 100644 index 532d0ed..0000000 --- a/distant-net/src/transport/windows/pipe.rs +++ /dev/null @@ -1,101 +0,0 @@ -use derive_more::{From, TryInto}; -use std::{ - pin::Pin, - task::{Context, Poll}, -}; -use tokio::{ - io::{self, AsyncRead, AsyncWrite, ReadBuf}, - net::windows::named_pipe::{NamedPipeClient, NamedPipeServer}, -}; - -#[derive(From, TryInto)] -pub enum NamedPipe { - Client(NamedPipeClient), - Server(NamedPipeServer), -} - -impl NamedPipe { - pub fn as_client(&self) -> Option<&NamedPipeClient> { - match self { - Self::Client(x) => Some(x), - _ => None, - } - } - - pub fn as_mut_client(&mut self) -> Option<&mut NamedPipeClient> { - match self { - Self::Client(x) => Some(x), - _ => None, - } - } - - pub fn into_client(self) -> Option { - match self { - Self::Client(x) => Some(x), - _ => None, - } - } - - pub fn as_server(&self) -> Option<&NamedPipeServer> { - match self { - Self::Server(x) => Some(x), - _ => None, - } - } - - pub fn as_mut_server(&mut self) -> Option<&mut 
NamedPipeServer> { - match self { - Self::Server(x) => Some(x), - _ => None, - } - } - - pub fn into_server(self) -> Option { - match self { - Self::Server(x) => Some(x), - _ => None, - } - } -} -impl AsyncRead for NamedPipe { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match Pin::get_mut(self) { - Self::Client(x) => Pin::new(x).poll_read(cx, buf), - Self::Server(x) => Pin::new(x).poll_read(cx, buf), - } - } -} - -impl AsyncWrite for NamedPipe { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match Pin::get_mut(self) { - Self::Client(x) => Pin::new(x).poll_write(cx, buf), - Self::Server(x) => Pin::new(x).poll_write(cx, buf), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::get_mut(self) { - Self::Client(x) => Pin::new(x).poll_flush(cx), - Self::Server(x) => Pin::new(x).poll_flush(cx), - } - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - match Pin::get_mut(self) { - Self::Client(x) => Pin::new(x).poll_shutdown(cx), - Self::Server(x) => Pin::new(x).poll_shutdown(cx), - } - } -} diff --git a/distant-net/tests/auth.rs b/distant-net/tests/auth.rs deleted file mode 100644 index e798446..0000000 --- a/distant-net/tests/auth.rs +++ /dev/null @@ -1,169 +0,0 @@ -use distant_net::{ - AuthClient, AuthErrorKind, AuthQuestion, AuthRequest, AuthServer, AuthVerifyKind, Client, - IntoSplit, MpscListener, MpscTransport, ServerExt, -}; -use std::collections::HashMap; -use tokio::sync::mpsc; - -/// Spawns a server and client connected together, returning the client -fn setup() -> (AuthClient, mpsc::Receiver) { - // Make a pair of inmemory transports that we can use to test client and server connected - let (t1, t2) = MpscTransport::pair(100); - - // Create the client - let (writer, reader) = t1.into_split(); - let client = AuthClient::from(Client::new(writer, 
reader).unwrap()); - - // Prepare a channel where we can pass back out whatever request we get - let (tx, rx) = mpsc::channel(100); - - let tx_2 = tx.clone(); - let tx_3 = tx.clone(); - let tx_4 = tx.clone(); - - // Make a server that echos questions back as answers and only verifies the text "yes" - let server = AuthServer { - on_challenge: move |questions, options| { - let questions_2 = questions.clone(); - tx.try_send(AuthRequest::Challenge { questions, options }) - .unwrap(); - questions_2.into_iter().map(|x| x.text).collect() - }, - on_verify: move |kind, text| { - let valid = text == "yes"; - tx_2.try_send(AuthRequest::Verify { kind, text }).unwrap(); - valid - }, - on_info: move |text| { - tx_3.try_send(AuthRequest::Info { text }).unwrap(); - }, - on_error: move |kind, text| { - tx_4.try_send(AuthRequest::Error { kind, text }).unwrap(); - }, - }; - - // Spawn the server to listen for our client to connect - tokio::spawn(async move { - let (writer, reader) = t2.into_split(); - let (tx, listener) = MpscListener::channel(1); - tx.send((writer, reader)).await.unwrap(); - let _server = server.start(listener).unwrap(); - }); - - (client, rx) -} - -#[tokio::test] -async fn client_should_be_able_to_challenge_against_server() { - let (mut client, mut rx) = setup(); - - // Gotta start with the handshake first - client.handshake().await.unwrap(); - - // Now do the challenge - assert_eq!( - client - .challenge( - vec![AuthQuestion::new("hello".to_string())], - Default::default() - ) - .await - .unwrap(), - vec!["hello".to_string()] - ); - - // Verify that the server received the request - let request = rx.recv().await.unwrap(); - match request { - AuthRequest::Challenge { questions, options } => { - assert_eq!(questions.len(), 1); - assert_eq!(questions[0].text, "hello"); - assert_eq!(questions[0].options, HashMap::new()); - - assert_eq!(options, HashMap::new()); - } - x => panic!("Unexpected request received by server: {:?}", x), - } -} - -#[tokio::test] -async fn 
client_should_be_able_to_verify_against_server() { - let (mut client, mut rx) = setup(); - - // Gotta start with the handshake first - client.handshake().await.unwrap(); - - // "no" will yield false - assert!(!client - .verify(AuthVerifyKind::Host, "no".to_string()) - .await - .unwrap()); - - // Verify that the server received the request - let request = rx.recv().await.unwrap(); - match request { - AuthRequest::Verify { kind, text } => { - assert_eq!(kind, AuthVerifyKind::Host); - assert_eq!(text, "no"); - } - x => panic!("Unexpected request received by server: {:?}", x), - } - - // "yes" will yield true - assert!(client - .verify(AuthVerifyKind::Host, "yes".to_string()) - .await - .unwrap()); - - // Verify that the server received the request - let request = rx.recv().await.unwrap(); - match request { - AuthRequest::Verify { kind, text } => { - assert_eq!(kind, AuthVerifyKind::Host); - assert_eq!(text, "yes"); - } - x => panic!("Unexpected request received by server: {:?}", x), - } -} - -#[tokio::test] -async fn client_should_be_able_to_send_info_to_server() { - let (mut client, mut rx) = setup(); - - // Gotta start with the handshake first - client.handshake().await.unwrap(); - - // Send some information - client.info(String::from("hello, world")).await.unwrap(); - - // Verify that the server received the request - let request = rx.recv().await.unwrap(); - match request { - AuthRequest::Info { text } => assert_eq!(text, "hello, world"), - x => panic!("Unexpected request received by server: {:?}", x), - } -} - -#[tokio::test] -async fn client_should_be_able_to_send_error_to_server() { - let (mut client, mut rx) = setup(); - - // Gotta start with the handshake first - client.handshake().await.unwrap(); - - // Send some error - client - .error(AuthErrorKind::Unknown, String::from("hello, world")) - .await - .unwrap(); - - // Verify that the server received the request - let request = rx.recv().await.unwrap(); - match request { - AuthRequest::Error { kind, text } => 
{ - assert_eq!(kind, AuthErrorKind::Unknown); - assert_eq!(text, "hello, world"); - } - x => panic!("Unexpected request received by server: {:?}", x), - } -} diff --git a/distant-net/tests/lib.rs b/distant-net/tests/lib.rs deleted file mode 100644 index 12bc9de..0000000 --- a/distant-net/tests/lib.rs +++ /dev/null @@ -1 +0,0 @@ -mod auth; diff --git a/distant-net/tests/manager_tests.rs b/distant-net/tests/manager_tests.rs new file mode 100644 index 0000000..e45f74d --- /dev/null +++ b/distant-net/tests/manager_tests.rs @@ -0,0 +1,125 @@ +use async_trait::async_trait; +use distant_net::boxed_connect_handler; +use distant_net::client::{Client, ReconnectStrategy}; +use distant_net::common::authentication::{DummyAuthHandler, Verifier}; +use distant_net::common::{Destination, InmemoryTransport, Map, OneshotListener}; +use distant_net::manager::{Config, ManagerClient, ManagerServer}; +use distant_net::server::{Server, ServerCtx, ServerHandler}; +use log::*; +use std::io; +use test_log::test; + +struct TestServerHandler; + +#[async_trait] +impl ServerHandler for TestServerHandler { + type Request = String; + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + ctx.reply + .send(format!("echo {}", ctx.request.payload)) + .await + .expect("Failed to send response") + } +} + +#[test(tokio::test)] +async fn should_be_able_to_establish_a_single_connection_and_communicate_with_a_manager() { + let (t1, t2) = InmemoryTransport::pair(100); + + let mut config = Config::default(); + config.connect_handlers.insert( + "scheme".to_string(), + boxed_connect_handler!(|_a, _b, _c| { + let (t1, t2) = InmemoryTransport::pair(100); + + // Spawn a server on one end and connect to it on the other + let _ = Server::new() + .handler(TestServerHandler) + .verifier(Verifier::none()) + .start(OneshotListener::from_value(t2))?; + + let client = Client::build() + .auth_handler(DummyAuthHandler) + .reconnect_strategy(ReconnectStrategy::Fail) + 
.connector(t1) + .connect_untyped() + .await?; + + Ok(client) + }), + ); + + info!("Starting manager"); + let _manager_ref = ManagerServer::new(config) + .verifier(Verifier::none()) + .start(OneshotListener::from_value(t2)) + .expect("Failed to start manager server"); + + info!("Connecting to manager"); + let mut client: ManagerClient = Client::build() + .auth_handler(DummyAuthHandler) + .reconnect_strategy(ReconnectStrategy::Fail) + .connector(t1) + .connect() + .await + .expect("Failed to connect to manager"); + + // Test establishing a connection to some remote server + info!("Submitting server connection request to manager"); + let id = client + .connect( + "scheme://host".parse::().unwrap(), + "key=value".parse::().unwrap(), + DummyAuthHandler, + ) + .await + .expect("Failed to connect to a remote server"); + + // Test retrieving list of connections + info!("Submitting connection list request to manager"); + let list = client + .list() + .await + .expect("Failed to get list of connections"); + assert_eq!(list.len(), 1); + assert_eq!(list.get(&id).unwrap().to_string(), "scheme://host"); + + // Test retrieving information + info!("Submitting connection info request to manager"); + let info = client + .info(id) + .await + .expect("Failed to get info about connection"); + assert_eq!(info.id, id); + assert_eq!(info.destination.to_string(), "scheme://host"); + assert_eq!(info.options, "key=value".parse::().unwrap()); + + // Create a new channel and request some data + info!("Submitting server channel open request to manager"); + let mut channel_client: Client = client + .open_raw_channel(id) + .await + .expect("Failed to open channel") + .into_client(); + + info!("Verifying server channel can send and receive data"); + let res = channel_client + .send("hello".to_string()) + .await + .expect("Failed to send request to server"); + assert_eq!(res.payload, "echo hello", "Invalid response payload"); + + // Test killing a connection + info!("Submitting connection kill 
request to manager"); + client.kill(id).await.expect("Failed to kill connection"); + + // Test getting an error to ensure that serialization of that data works, + // which we do by trying to access a connection that no longer exists + info!("Verifying server connection held by manager has terminated"); + let err = client.info(id).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!(err.to_string(), "No connection found"); +} diff --git a/distant-net/tests/typed_tests.rs b/distant-net/tests/typed_tests.rs new file mode 100644 index 0000000..e2503ff --- /dev/null +++ b/distant-net/tests/typed_tests.rs @@ -0,0 +1,70 @@ +use async_trait::async_trait; +use distant_net::client::{Client, ReconnectStrategy}; +use distant_net::common::authentication::{DummyAuthHandler, Verifier}; +use distant_net::common::{InmemoryTransport, OneshotListener}; +use distant_net::server::{Server, ServerCtx, ServerHandler}; +use log::*; +use test_log::test; + +struct TestServerHandler; + +#[async_trait] +impl ServerHandler for TestServerHandler { + type Request = (u8, String); + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + let (cnt, msg) = ctx.request.payload; + + for i in 0..cnt { + ctx.reply + .send(format!("echo {i} {msg}")) + .await + .expect("Failed to send response"); + } + } +} + +#[test(tokio::test)] +async fn should_be_able_to_send_and_receive_typed_payloads_between_client_and_server() { + let (t1, t2) = InmemoryTransport::pair(100); + + let _ = Server::new() + .handler(TestServerHandler) + .verifier(Verifier::none()) + .start(OneshotListener::from_value(t2)) + .expect("Failed to start server"); + + let mut client: Client<(u8, String), String> = Client::build() + .auth_handler(DummyAuthHandler) + .reconnect_strategy(ReconnectStrategy::Fail) + .connector(t1) + .connect() + .await + .expect("Failed to connect to server"); + + info!("Mailing a message from the client, and waiting for 3 responses"); + let mut 
mailbox = client + .mail((3, "hello".to_string())) + .await + .expect("Failed to mail message"); + + assert_eq!(mailbox.next().await.unwrap().payload, "echo 0 hello"); + assert_eq!(mailbox.next().await.unwrap().payload, "echo 1 hello"); + assert_eq!(mailbox.next().await.unwrap().payload, "echo 2 hello"); + + info!("Sending a message from the client, and waiting for a response"); + let response = client + .send((1, "hello".to_string())) + .await + .expect("Failed to send message"); + + assert_eq!(response.payload, "echo 0 hello"); + + info!("Firing off a message from the client"); + client + .fire((1, "hello".to_string())) + .await + .expect("Failed to fire message"); +} diff --git a/distant-net/tests/untyped_tests.rs b/distant-net/tests/untyped_tests.rs new file mode 100644 index 0000000..e8526f7 --- /dev/null +++ b/distant-net/tests/untyped_tests.rs @@ -0,0 +1,112 @@ +use async_trait::async_trait; +use distant_net::client::{Client, ReconnectStrategy}; +use distant_net::common::authentication::{DummyAuthHandler, Verifier}; +use distant_net::common::{InmemoryTransport, OneshotListener, Request}; +use distant_net::server::{Server, ServerCtx, ServerHandler}; +use log::*; +use test_log::test; + +struct TestServerHandler; + +#[async_trait] +impl ServerHandler for TestServerHandler { + type Request = (u8, String); + type Response = String; + type LocalData = (); + + async fn on_request(&self, ctx: ServerCtx) { + let (cnt, msg) = ctx.request.payload; + + for i in 0..cnt { + ctx.reply + .send(format!("echo {i} {msg}")) + .await + .expect("Failed to send response"); + } + } +} + +#[test(tokio::test)] +async fn should_be_able_to_send_and_receive_untyped_payloads_between_client_and_server() { + let (t1, t2) = InmemoryTransport::pair(100); + + let _ = Server::new() + .handler(TestServerHandler) + .verifier(Verifier::none()) + .start(OneshotListener::from_value(t2)) + .expect("Failed to start server"); + + let mut client = Client::build() + .auth_handler(DummyAuthHandler) + 
.reconnect_strategy(ReconnectStrategy::Fail) + .connector(t1) + .connect_untyped() + .await + .expect("Failed to connect to server"); + + info!("Mailing a message from the client, and waiting for 3 responses"); + let mut mailbox = client + .mail( + Request::new((3, "hello".to_string())) + .to_untyped_request() + .unwrap(), + ) + .await + .expect("Failed to mail message"); + + assert_eq!( + mailbox + .next() + .await + .unwrap() + .to_typed_response::() + .unwrap() + .payload, + "echo 0 hello" + ); + assert_eq!( + mailbox + .next() + .await + .unwrap() + .to_typed_response::() + .unwrap() + .payload, + "echo 1 hello" + ); + assert_eq!( + mailbox + .next() + .await + .unwrap() + .to_typed_response::() + .unwrap() + .payload, + "echo 2 hello" + ); + + info!("Sending a message from the client, and waiting for a response"); + let response = client + .send( + Request::new((1, "hello".to_string())) + .to_untyped_request() + .unwrap(), + ) + .await + .expect("Failed to send message"); + + assert_eq!( + response.to_typed_response::().unwrap().payload, + "echo 0 hello" + ); + + info!("Firing off a message from the client"); + client + .fire( + Request::new((1, "hello".to_string())) + .to_untyped_request() + .unwrap(), + ) + .await + .expect("Failed to fire message"); +} diff --git a/distant-ssh2/Cargo.toml b/distant-ssh2/Cargo.toml index d9959ed..e1c89c7 100644 --- a/distant-ssh2/Cargo.toml +++ b/distant-ssh2/Cargo.toml @@ -2,7 +2,7 @@ name = "distant-ssh2" description = "Library to enable native ssh-2 protocol for use with distant sessions" categories = ["network-programming"] -version = "0.19.0" +version = "0.20.0" authors = ["Chip Senkbeil "] edition = "2021" homepage = "https://github.com/chipsenkbeil/distant" @@ -20,7 +20,7 @@ async-compat = "0.2.1" async-once-cell = "0.4.2" async-trait = "0.1.57" derive_more = { version = "0.99.17", default-features = false, features = ["display", "error"] } -distant-core = { version = "=0.19.0", path = "../distant-core" } 
+distant-core = { version = "=0.20.0", path = "../distant-core" } futures = "0.3.21" hex = "0.4.3" log = "0.4.17" @@ -40,10 +40,11 @@ serde = { version = "1.0.142", features = ["derive"], optional = true } anyhow = "1.0.60" assert_fs = "1.0.7" dunce = "1.0.2" -flexi_logger = "0.23.0" +env_logger = "0.9.1" indoc = "1.0.7" once_cell = "1.13.0" predicates = "2.1.1" rstest = "0.15.0" +test-log = "0.2.11" which = "4.2.5" whoami = "1.2.1" diff --git a/distant-ssh2/src/api.rs b/distant-ssh2/src/api.rs index 3a37e6e..409b3ed 100644 --- a/distant-ssh2/src/api.rs +++ b/distant-ssh2/src/api.rs @@ -10,6 +10,7 @@ use distant_core::{ Capabilities, CapabilityKind, DirEntry, Environment, FileType, Metadata, ProcessId, PtySize, SystemInfo, UnixMetadata, }, + net::server::ConnectionCtx, DistantApi, DistantCtx, }; use log::*; @@ -75,8 +76,9 @@ impl SshDistantApi { impl DistantApi for SshDistantApi { type LocalData = ConnectionState; - async fn on_accept(&self, local_data: &mut Self::LocalData) { - local_data.global_processes = Arc::downgrade(&self.processes); + async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { + ctx.local_data.global_processes = Arc::downgrade(&self.processes); + Ok(()) } async fn capabilities(&self, ctx: DistantCtx) -> io::Result { diff --git a/distant-ssh2/src/lib.rs b/distant-ssh2/src/lib.rs index 47b420f..9124dcc 100644 --- a/distant-ssh2/src/lib.rs +++ b/distant-ssh2/src/lib.rs @@ -7,11 +7,12 @@ use async_trait::async_trait; use distant_core::{ data::Environment, net::{ - FramedTransport, IntoSplit, OneshotListener, ServerExt, ServerRef, TcpClientExt, - XChaCha20Poly1305Codec, + client::{Client, ReconnectStrategy}, + common::authentication::{AuthHandlerMap, DummyAuthHandler, Verifier}, + common::{InmemoryTransport, OneshotListener}, + server::{Server, ServerRef}, }, - BoxedDistantReader, BoxedDistantWriter, BoxedDistantWriterReader, DistantApiServer, - DistantChannelExt, DistantClient, DistantSingleKeyCredentials, + 
DistantApiServerHandler, DistantChannelExt, DistantClient, DistantSingleKeyCredentials, }; use log::*; use smol::channel::Receiver as SmolReceiver; @@ -565,14 +566,18 @@ impl Ssh { let credentials = self.launch(opts).await?; let key = credentials.key; - let codec = XChaCha20Poly1305Codec::from(key); // Try each IP address with the same port to see if one works let mut err = None; for ip in candidate_ips { let addr = SocketAddr::new(ip, credentials.port); debug!("Attempting to connect to distant server @ {}", addr); - match DistantClient::connect_timeout(addr, codec.clone(), timeout).await { + match Client::tcp(addr) + .auth_handler(AuthHandlerMap::new().with_static_key(key.clone())) + .timeout(timeout) + .connect() + .await + { Ok(client) => return Ok(client), Err(x) => err = Some(x), } @@ -684,28 +689,13 @@ impl Ssh { } /// Consume [`Ssh`] and produce a [`DistantClient`] that is powered by an ssh client - /// underneath + /// underneath. pub async fn into_distant_client(self) -> io::Result { Ok(self.into_distant_pair().await?.0) } - /// Consume [`Ssh`] and produce a [`BoxedDistantWriterReader`] that is powered by an ssh client - /// underneath - pub async fn into_distant_writer_reader(self) -> io::Result { - Ok(self.into_writer_reader_and_server().await?.0) - } - - /// Consumes [`Ssh`] and produces a [`DistantClient`] and [`DistantApiServer`] pair + /// Consumes [`Ssh`] and produces a [`DistantClient`] and [`ServerRef`] pair. 
pub async fn into_distant_pair(self) -> io::Result<(DistantClient, Box)> { - let ((writer, reader), server) = self.into_writer_reader_and_server().await?; - let client = DistantClient::new(writer, reader)?; - Ok((client, server)) - } - - /// Consumes [`Ssh`] and produces a [`DistantClient`] and [`DistantApiServer`] pair - async fn into_writer_reader_and_server( - self, - ) -> io::Result<(BoxedDistantWriterReader, Box)> { // Exit early if not authenticated as this is a requirement if !self.authenticated { return Err(io::Error::new( @@ -714,24 +704,24 @@ impl Ssh { )); } - let (t1, t2) = FramedTransport::pair(1); - - // Spawn a bridge client that is directly connected to our server - let (writer, reader) = t1.into_split(); - let writer: BoxedDistantWriter = Box::new(writer); - let reader: BoxedDistantReader = Box::new(reader); - - // Spawn a bridge server that is directly connected to our client - let server = { - let Self { - session: wez_session, - .. - } = self; - let (writer, reader) = t2.into_split(); - DistantApiServer::new(SshDistantApi::new(wez_session)) - .start(OneshotListener::from_value((writer, reader)))? - }; - - Ok(((writer, reader), server)) + let Self { + session: wez_session, + .. 
+ } = self; + + let (t1, t2) = InmemoryTransport::pair(1); + let server = Server::new() + .handler(DistantApiServerHandler::new(SshDistantApi::new( + wez_session, + ))) + .verifier(Verifier::none()) + .start(OneshotListener::from_value(t2))?; + let client = Client::build() + .auth_handler(DummyAuthHandler) + .connector(t1) + .reconnect_strategy(ReconnectStrategy::Fail) + .connect() + .await?; + Ok((client, server)) } } diff --git a/distant-ssh2/src/process.rs b/distant-ssh2/src/process.rs index 29882f1..6a0a465 100644 --- a/distant-ssh2/src/process.rs +++ b/distant-ssh2/src/process.rs @@ -1,7 +1,7 @@ use async_compat::CompatExt; use distant_core::{ data::{DistantResponseData, Environment, ProcessId, PtySize}, - net::Reply, + net::server::Reply, }; use log::*; use std::{ @@ -16,7 +16,7 @@ use wezterm_ssh::{ }; const MAX_PIPE_CHUNK_SIZE: usize = 8192; -const THREAD_PAUSE_MILLIS: u64 = 50; +const THREAD_PAUSE_MILLIS: u64 = 1; /// Result of spawning a process, containing means to send stdin, means to kill the process, /// and the initialization function to use to start processing stdin, stdout, and stderr diff --git a/distant-ssh2/tests/ssh2/client.rs b/distant-ssh2/tests/ssh2/client.rs index 042b7c9..6d91a3c 100644 --- a/distant-ssh2/tests/ssh2/client.rs +++ b/distant-ssh2/tests/ssh2/client.rs @@ -8,6 +8,7 @@ use once_cell::sync::Lazy; use predicates::prelude::*; use rstest::*; use std::{io, path::Path, time::Duration}; +use test_log::test; const SETUP_DIR_TIMEOUT: Duration = Duration::from_secs(1); const SETUP_DIR_POLL: Duration = Duration::from_millis(50); @@ -71,7 +72,7 @@ static DOES_NOT_EXIST_BIN: Lazy = Lazy::new(|| TEMP_SCRIPT_DIR.child("does_not_exist_bin")); #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn read_file_should_fail_if_file_missing(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -81,7 +82,7 @@ async fn read_file_should_fail_if_file_missing(#[future] client: Ctx) { let mut client = 
client.await; @@ -94,7 +95,7 @@ async fn read_file_should_send_blob_with_file_contents(#[future] client: Ctx, ) { @@ -107,7 +108,7 @@ async fn read_file_text_should_send_error_if_fails_to_read_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn read_file_text_should_send_text_with_file_contents(#[future] client: Ctx) { let mut client = client.await; @@ -123,7 +124,7 @@ async fn read_file_text_should_send_text_with_file_contents(#[future] client: Ct } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn write_file_should_send_error_if_fails_to_write_file(#[future] client: Ctx) { let mut client = client.await; @@ -142,7 +143,7 @@ async fn write_file_should_send_error_if_fails_to_write_file(#[future] client: C } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn write_file_should_send_ok_when_successful(#[future] client: Ctx) { let mut client = client.await; @@ -162,7 +163,7 @@ async fn write_file_should_send_ok_when_successful(#[future] client: Ctx, ) { @@ -183,7 +184,7 @@ async fn write_file_text_should_send_error_if_fails_to_write_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn write_file_text_should_send_ok_when_successful(#[future] client: Ctx) { let mut client = client.await; @@ -203,7 +204,7 @@ async fn write_file_text_should_send_ok_when_successful(#[future] client: Ctx, ) { @@ -224,7 +225,7 @@ async fn append_file_should_send_error_if_fails_to_create_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn append_file_should_create_file_if_missing(#[future] client: Ctx) { let mut client = client.await; @@ -246,7 +247,7 @@ async fn append_file_should_create_file_if_missing(#[future] client: Ctx) { let mut client = client.await; @@ -268,7 +269,7 @@ async fn append_file_should_send_ok_when_successful(#[future] client: Ctx, ) { @@ -289,7 +290,7 @@ async fn append_file_text_should_send_error_if_fails_to_create_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn 
append_file_text_should_create_file_if_missing(#[future] client: Ctx) { let mut client = client.await; @@ -311,7 +312,7 @@ async fn append_file_text_should_create_file_if_missing(#[future] client: Ctx) { let mut client = client.await; @@ -333,7 +334,7 @@ async fn append_file_text_should_send_ok_when_successful(#[future] client: Ctx, ) { @@ -401,7 +402,7 @@ async fn setup_dir() -> assert_fs::TempDir { // NOTE: CI fails this on Windows, but it's running Windows with bash and strange paths, so ignore // it only for the CI #[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(all(windows, ci), ignore)] async fn dir_read_should_support_depth_limits(#[future] client: Ctx) { let mut client = client.await; @@ -438,7 +439,7 @@ async fn dir_read_should_support_depth_limits(#[future] client: Ctx) { let mut client = client.await; @@ -478,7 +479,7 @@ async fn dir_read_should_support_unlimited_depth_using_zero(#[future] client: Ct // NOTE: This is failing on windows as canonicalization of root path is not correct! #[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn dir_read_should_support_including_directory_in_returned_entries( #[future] client: Ctx, @@ -524,7 +525,7 @@ async fn dir_read_should_support_including_directory_in_returned_entries( // NOTE: This is failing on windows as canonicalization of root path is not correct! #[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn dir_read_should_support_returning_absolute_paths(#[future] client: Ctx) { let mut client = client.await; @@ -561,7 +562,7 @@ async fn dir_read_should_support_returning_absolute_paths(#[future] client: Ctx< // NOTE: This is failing on windows as the symlink does not get resolved! 
#[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn dir_read_should_support_returning_canonicalized_paths( #[future] client: Ctx, @@ -600,7 +601,7 @@ async fn dir_read_should_support_returning_canonicalized_paths( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn create_dir_should_send_error_if_fails(#[future] client: Ctx) { let mut client = client.await; @@ -619,7 +620,7 @@ async fn create_dir_should_send_error_if_fails(#[future] client: Ctx) { let mut client = client.await; let root_dir = setup_dir().await; @@ -635,7 +636,7 @@ async fn create_dir_should_send_ok_when_successful(#[future] client: Ctx, ) { @@ -653,7 +654,7 @@ async fn create_dir_should_support_creating_multiple_dir_components( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn remove_should_send_error_on_failure(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -669,7 +670,7 @@ async fn remove_should_send_error_on_failure(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -686,7 +687,7 @@ async fn remove_should_support_deleting_a_directory(#[future] client: Ctx, ) { @@ -706,7 +707,7 @@ async fn remove_should_delete_nonempty_directory_if_force_is_true( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn remove_should_support_deleting_a_single_file(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -723,7 +724,7 @@ async fn remove_should_support_deleting_a_single_file(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -740,7 +741,7 @@ async fn copy_should_send_error_on_failure(#[future] client: Ctx) } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn copy_should_support_copying_an_entire_directory(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -766,7 +767,7 @@ 
async fn copy_should_support_copying_an_entire_directory(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -785,7 +786,7 @@ async fn copy_should_support_copying_an_empty_directory(#[future] client: Ctx, ) { @@ -813,7 +814,7 @@ async fn copy_should_support_copying_a_directory_that_only_contains_directories( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn copy_should_support_copying_a_single_file(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -832,7 +833,7 @@ async fn copy_should_support_copying_a_single_file(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -849,7 +850,7 @@ async fn rename_should_fail_if_path_missing(#[future] client: Ctx } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn rename_should_support_renaming_an_entire_directory(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -875,7 +876,7 @@ async fn rename_should_support_renaming_an_entire_directory(#[future] client: Ct } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn rename_should_support_renaming_a_single_file(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -894,7 +895,7 @@ async fn rename_should_support_renaming_a_single_file(#[future] client: Ctx) { // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. 
let mut client = client.await; @@ -917,7 +918,7 @@ async fn watch_should_fail_as_unsupported(#[future] client: Ctx) } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn exists_should_send_true_if_path_exists(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -929,7 +930,7 @@ async fn exists_should_send_true_if_path_exists(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -940,7 +941,7 @@ async fn exists_should_send_false_if_path_does_not_exist(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -957,7 +958,7 @@ async fn metadata_should_send_error_on_failure(#[future] client: Ctx, ) { @@ -993,7 +994,7 @@ async fn metadata_should_send_back_metadata_on_file_if_exists( #[cfg(unix)] #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_include_unix_specific_metadata_on_unix_platform( #[future] client: Ctx, ) { @@ -1025,7 +1026,7 @@ async fn metadata_should_include_unix_specific_metadata_on_unix_platform( #[cfg(windows)] #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_not_include_windows_as_ssh_cannot_retrieve_that_information( #[future] client: Ctx, ) { @@ -1061,7 +1062,7 @@ async fn metadata_should_not_include_windows_as_ssh_cannot_retrieve_that_informa } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_send_back_metadata_on_dir_if_exists(#[future] client: Ctx) { let mut client = client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -1093,7 +1094,7 @@ async fn metadata_should_send_back_metadata_on_dir_if_exists(#[future] client: C } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_send_back_metadata_on_symlink_if_exists( #[future] client: Ctx, ) { @@ -1130,7 +1131,7 @@ async fn metadata_should_send_back_metadata_on_symlink_if_exists( } #[rstest] -#[tokio::test] +#[test(tokio::test)] 
#[cfg_attr(windows, ignore)] async fn metadata_should_include_canonicalized_path_if_flag_specified( #[future] client: Ctx, @@ -1169,7 +1170,7 @@ async fn metadata_should_include_canonicalized_path_if_flag_specified( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified( #[future] client: Ctx, ) { @@ -1204,7 +1205,7 @@ async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn proc_spawn_should_not_fail_even_if_process_not_found( #[future] client: Ctx, ) { @@ -1224,7 +1225,7 @@ async fn proc_spawn_should_not_fail_even_if_process_not_found( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn proc_spawn_should_return_id_of_spawned_process(#[future] client: Ctx) { let mut client = client.await; @@ -1249,7 +1250,7 @@ async fn proc_spawn_should_return_id_of_spawned_process(#[future] client: Ctx, @@ -1285,7 +1286,7 @@ async fn proc_spawn_should_send_back_stdout_periodically_when_available( // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ #[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_send_back_stderr_periodically_when_available( #[future] client: Ctx, @@ -1321,7 +1322,7 @@ async fn proc_spawn_should_send_back_stderr_periodically_when_available( // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ #[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_send_done_signal_when_completed(#[future] client: Ctx) { let mut client = client.await; @@ -1342,7 +1343,7 @@ async fn proc_spawn_should_send_done_signal_when_completed(#[future] client: Ctx } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn proc_spawn_should_clear_process_from_state_when_killed( #[future] 
client: Ctx, ) { @@ -1369,7 +1370,7 @@ async fn proc_spawn_should_clear_process_from_state_when_killed( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn proc_kill_should_fail_if_process_not_running(#[future] client: Ctx) { let mut client = client.await; @@ -1397,7 +1398,7 @@ async fn proc_kill_should_fail_if_process_not_running(#[future] client: Ctx) { let mut client = client.await; @@ -1427,7 +1428,7 @@ async fn proc_stdin_should_fail_if_process_not_running(#[future] client: Ctx) { let mut client = client.await; @@ -1465,7 +1466,7 @@ async fn proc_stdin_should_send_stdin_to_process(#[future] client: Ctx, ) { diff --git a/distant-ssh2/tests/ssh2/launched.rs b/distant-ssh2/tests/ssh2/launched.rs index 2679598..75d52a6 100644 --- a/distant-ssh2/tests/ssh2/launched.rs +++ b/distant-ssh2/tests/ssh2/launched.rs @@ -8,6 +8,7 @@ use once_cell::sync::Lazy; use predicates::prelude::*; use rstest::*; use std::{path::Path, time::Duration}; +use test_log::test; static TEMP_SCRIPT_DIR: Lazy = Lazy::new(|| TempDir::new().unwrap()); static SCRIPT_RUNNER: Lazy = Lazy::new(|| String::from("bash")); @@ -68,7 +69,7 @@ static DOES_NOT_EXIST_BIN: Lazy = Lazy::new(|| TEMP_SCRIPT_DIR.child("does_not_exist_bin")); #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn read_file_should_fail_if_file_missing(#[future] launched_client: Ctx) { let mut client = launched_client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -78,7 +79,7 @@ async fn read_file_should_fail_if_file_missing(#[future] launched_client: Ctx, ) { @@ -93,7 +94,7 @@ async fn read_file_should_send_blob_with_file_contents( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn read_file_text_should_send_error_if_fails_to_read_file( #[future] launched_client: Ctx, ) { @@ -106,7 +107,7 @@ async fn read_file_text_should_send_error_if_fails_to_read_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn read_file_text_should_send_text_with_file_contents( #[future] launched_client: Ctx, ) { @@ 
-124,7 +125,7 @@ async fn read_file_text_should_send_text_with_file_contents( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn write_file_should_send_error_if_fails_to_write_file( #[future] launched_client: Ctx, ) { @@ -145,7 +146,7 @@ async fn write_file_should_send_error_if_fails_to_write_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn write_file_should_send_ok_when_successful(#[future] launched_client: Ctx) { let mut client = launched_client.await; @@ -165,7 +166,7 @@ async fn write_file_should_send_ok_when_successful(#[future] launched_client: Ct } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn write_file_text_should_send_error_if_fails_to_write_file( #[future] launched_client: Ctx, ) { @@ -186,7 +187,7 @@ async fn write_file_text_should_send_error_if_fails_to_write_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn write_file_text_should_send_ok_when_successful( #[future] launched_client: Ctx, ) { @@ -208,7 +209,7 @@ async fn write_file_text_should_send_ok_when_successful( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn append_file_should_send_error_if_fails_to_create_file( #[future] launched_client: Ctx, ) { @@ -229,7 +230,7 @@ async fn append_file_should_send_error_if_fails_to_create_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn append_file_should_create_file_if_missing(#[future] launched_client: Ctx) { let mut client = launched_client.await; @@ -251,7 +252,7 @@ async fn append_file_should_create_file_if_missing(#[future] launched_client: Ct } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn append_file_should_send_ok_when_successful(#[future] launched_client: Ctx) { let mut client = launched_client.await; @@ -273,7 +274,7 @@ async fn append_file_should_send_ok_when_successful(#[future] launched_client: C } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn append_file_text_should_send_error_if_fails_to_create_file( #[future] launched_client: Ctx, ) { @@ -294,7 
+295,7 @@ async fn append_file_text_should_send_error_if_fails_to_create_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn append_file_text_should_create_file_if_missing( #[future] launched_client: Ctx, ) { @@ -318,7 +319,7 @@ async fn append_file_text_should_create_file_if_missing( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn append_file_text_should_send_ok_when_successful( #[future] launched_client: Ctx, ) { @@ -342,7 +343,7 @@ async fn append_file_text_should_send_ok_when_successful( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn dir_read_should_send_error_if_directory_does_not_exist( #[future] launched_client: Ctx, ) { @@ -385,7 +386,7 @@ async fn setup_dir() -> assert_fs::TempDir { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn dir_read_should_support_depth_limits(#[future] launched_client: Ctx) { let mut client = launched_client.await; @@ -419,7 +420,7 @@ async fn dir_read_should_support_depth_limits(#[future] launched_client: Ctx, ) { @@ -459,7 +460,7 @@ async fn dir_read_should_support_unlimited_depth_using_zero( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn dir_read_should_support_including_directory_in_returned_entries( #[future] launched_client: Ctx, ) { @@ -503,7 +504,7 @@ async fn dir_read_should_support_including_directory_in_returned_entries( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn dir_read_should_support_returning_absolute_paths( #[future] launched_client: Ctx, ) { @@ -540,7 +541,7 @@ async fn dir_read_should_support_returning_absolute_paths( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn dir_read_should_support_returning_canonicalized_paths( #[future] launched_client: Ctx, ) { @@ -578,7 +579,7 @@ async fn dir_read_should_support_returning_canonicalized_paths( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn create_dir_should_send_error_if_fails(#[future] launched_client: Ctx) { let mut client = launched_client.await; @@ -597,7 +598,7 @@ async 
fn create_dir_should_send_error_if_fails(#[future] launched_client: Ctx) { let mut client = launched_client.await; let root_dir = setup_dir().await; @@ -613,7 +614,7 @@ async fn create_dir_should_send_ok_when_successful(#[future] launched_client: Ct } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn create_dir_should_support_creating_multiple_dir_components( #[future] launched_client: Ctx, ) { @@ -631,7 +632,7 @@ async fn create_dir_should_support_creating_multiple_dir_components( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn remove_should_send_error_on_failure(#[future] launched_client: Ctx) { let mut client = launched_client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -647,7 +648,7 @@ async fn remove_should_send_error_on_failure(#[future] launched_client: Ctx) { let mut client = launched_client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -664,7 +665,7 @@ async fn remove_should_support_deleting_a_directory(#[future] launched_client: C } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn remove_should_delete_nonempty_directory_if_force_is_true( #[future] launched_client: Ctx, ) { @@ -684,7 +685,7 @@ async fn remove_should_delete_nonempty_directory_if_force_is_true( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn remove_should_support_deleting_a_single_file( #[future] launched_client: Ctx, ) { @@ -703,7 +704,7 @@ async fn remove_should_support_deleting_a_single_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn copy_should_send_error_on_failure(#[future] launched_client: Ctx) { let mut client = launched_client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -720,7 +721,7 @@ async fn copy_should_send_error_on_failure(#[future] launched_client: Ctx, ) { @@ -748,7 +749,7 @@ async fn copy_should_support_copying_an_entire_directory( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn copy_should_support_copying_an_empty_directory( #[future] launched_client: Ctx, ) { @@ -769,7 
+770,7 @@ async fn copy_should_support_copying_an_empty_directory( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn copy_should_support_copying_a_directory_that_only_contains_directories( #[future] launched_client: Ctx, ) { @@ -797,7 +798,7 @@ async fn copy_should_support_copying_a_directory_that_only_contains_directories( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn copy_should_support_copying_a_single_file(#[future] launched_client: Ctx) { let mut client = launched_client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -816,7 +817,7 @@ async fn copy_should_support_copying_a_single_file(#[future] launched_client: Ct } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn rename_should_fail_if_path_missing(#[future] launched_client: Ctx) { let mut client = launched_client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -833,7 +834,7 @@ async fn rename_should_fail_if_path_missing(#[future] launched_client: Ctx, ) { @@ -861,7 +862,7 @@ async fn rename_should_support_renaming_an_entire_directory( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn rename_should_support_renaming_a_single_file( #[future] launched_client: Ctx, ) { @@ -882,7 +883,7 @@ async fn rename_should_support_renaming_a_single_file( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn watch_should_succeed(#[future] launched_client: Ctx) { // NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc. 
let mut client = launched_client.await; @@ -903,7 +904,7 @@ async fn watch_should_succeed(#[future] launched_client: Ctx) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn exists_should_send_true_if_path_exists(#[future] launched_client: Ctx) { let mut client = launched_client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -915,7 +916,7 @@ async fn exists_should_send_true_if_path_exists(#[future] launched_client: Ctx, ) { @@ -928,7 +929,7 @@ async fn exists_should_send_false_if_path_does_not_exist( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_send_error_on_failure(#[future] launched_client: Ctx) { let mut client = launched_client.await; let temp = assert_fs::TempDir::new().unwrap(); @@ -945,7 +946,7 @@ async fn metadata_should_send_error_on_failure(#[future] launched_client: Ctx, ) { @@ -981,7 +982,7 @@ async fn metadata_should_send_back_metadata_on_file_if_exists( #[cfg(unix)] #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_include_unix_specific_metadata_on_unix_platform( #[future] launched_client: Ctx, ) { @@ -1013,7 +1014,7 @@ async fn metadata_should_include_unix_specific_metadata_on_unix_platform( #[cfg(windows)] #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_include_windows_specific_metadata_on_windows_platform( #[future] launched_client: Ctx, ) { @@ -1044,7 +1045,7 @@ async fn metadata_should_include_windows_specific_metadata_on_windows_platform( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_send_back_metadata_on_dir_if_exists( #[future] launched_client: Ctx, ) { @@ -1078,7 +1079,7 @@ async fn metadata_should_send_back_metadata_on_dir_if_exists( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_send_back_metadata_on_symlink_if_exists( #[future] launched_client: Ctx, ) { @@ -1115,7 +1116,7 @@ async fn metadata_should_send_back_metadata_on_symlink_if_exists( } #[rstest] -#[tokio::test] +#[test(tokio::test)] 
async fn metadata_should_include_canonicalized_path_if_flag_specified( #[future] launched_client: Ctx, ) { @@ -1152,7 +1153,7 @@ async fn metadata_should_include_canonicalized_path_if_flag_specified( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified( #[future] launched_client: Ctx, ) { @@ -1187,7 +1188,7 @@ async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn proc_spawn_should_fail_if_process_not_found( #[future] launched_client: Ctx, ) { @@ -1206,7 +1207,7 @@ async fn proc_spawn_should_fail_if_process_not_found( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn proc_spawn_should_return_id_of_spawned_process( #[future] launched_client: Ctx, ) { @@ -1233,7 +1234,7 @@ async fn proc_spawn_should_return_id_of_spawned_process( // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ #[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_send_back_stdout_periodically_when_available( #[future] launched_client: Ctx, @@ -1269,7 +1270,7 @@ async fn proc_spawn_should_send_back_stdout_periodically_when_available( // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ #[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_send_back_stderr_periodically_when_available( #[future] launched_client: Ctx, @@ -1305,7 +1306,7 @@ async fn proc_spawn_should_send_back_stderr_periodically_when_available( // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ #[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_spawn_should_send_done_signal_when_completed( #[future] 
launched_client: Ctx, @@ -1328,7 +1329,7 @@ async fn proc_spawn_should_send_done_signal_when_completed( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn proc_spawn_should_clear_process_from_state_when_killed( #[future] launched_client: Ctx, ) { @@ -1355,7 +1356,7 @@ async fn proc_spawn_should_clear_process_from_state_when_killed( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn proc_kill_should_fail_if_process_not_running( #[future] launched_client: Ctx, ) { @@ -1385,7 +1386,7 @@ async fn proc_kill_should_fail_if_process_not_running( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn proc_stdin_should_fail_if_process_not_running( #[future] launched_client: Ctx, ) { @@ -1417,7 +1418,7 @@ async fn proc_stdin_should_fail_if_process_not_running( // NOTE: Ignoring on windows because it's using WSL which wants a Linux path // with / but thinks it's on windows and is providing \ #[rstest] -#[tokio::test] +#[test(tokio::test)] #[cfg_attr(windows, ignore)] async fn proc_stdin_should_send_stdin_to_process(#[future] launched_client: Ctx) { let mut client = launched_client.await; @@ -1455,7 +1456,7 @@ async fn proc_stdin_should_send_stdin_to_process(#[future] launched_client: Ctx< } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn system_info_should_return_system_info_based_on_binary( #[future] launched_client: Ctx, ) { diff --git a/distant-ssh2/tests/ssh2/ssh.rs b/distant-ssh2/tests/ssh2/ssh.rs index 463b678..2990f50 100644 --- a/distant-ssh2/tests/ssh2/ssh.rs +++ b/distant-ssh2/tests/ssh2/ssh.rs @@ -1,9 +1,10 @@ use crate::sshd::*; use distant_ssh2::{Ssh, SshFamily}; use rstest::*; +use test_log::test; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn detect_family_should_return_windows_if_sshd_on_windows(#[future] ssh: Ctx) { let ssh = ssh.await; let family = ssh.detect_family().await.expect("Failed to detect family"); diff --git a/distant-ssh2/tests/sshd/mod.rs b/distant-ssh2/tests/sshd/mod.rs index 61073bf..a9738a2 100644 --- 
a/distant-ssh2/tests/sshd/mod.rs +++ b/distant-ssh2/tests/sshd/mod.rs @@ -6,7 +6,7 @@ use derive_more::Display; use derive_more::{Deref, DerefMut}; use distant_core::DistantClient; use distant_ssh2::{DistantLaunchOpts, Ssh, SshAuthEvent, SshAuthHandler, SshOpts}; -use once_cell::sync::{Lazy, OnceCell}; +use once_cell::sync::Lazy; use rstest::*; use std::{ collections::HashMap, @@ -61,8 +61,8 @@ impl SshKeygen { passphrase: impl AsRef, ) -> anyhow::Result { let res = Command::new("ssh-keygen") - .args(&["-m", "PEM"]) - .args(&["-t", "ed25519"]) + .args(["-m", "PEM"]) + .args(["-t", "ed25519"]) .arg("-f") .arg(path.as_ref()) .arg("-N") @@ -537,19 +537,6 @@ impl Drop for Sshd { } } -#[fixture] -pub fn logger() -> &'static flexi_logger::LoggerHandle { - static LOGGER: OnceCell = OnceCell::new(); - - LOGGER.get_or_init(|| { - // flexi_logger::Logger::try_with_str("off, distant_core=trace, distant_ssh2=trace") - flexi_logger::Logger::try_with_str("off, distant_core=warn, distant_ssh2=warn") - .expect("Failed to load env") - .start() - .expect("Failed to start logger") - }) -} - /// Mocked version of [`SshAuthHandler`] pub struct MockSshAuthHandler; @@ -614,7 +601,7 @@ pub fn sshd() -> Sshd { /// Fixture to establish a client to an SSH server #[fixture] -pub async fn client(sshd: Sshd, _logger: &'_ flexi_logger::LoggerHandle) -> Ctx { +pub async fn client(sshd: Sshd) -> Ctx { let ssh_client = load_ssh_client(&sshd).await; let client = ssh_client .into_distant_client() @@ -629,10 +616,7 @@ pub async fn client(sshd: Sshd, _logger: &'_ flexi_logger::LoggerHandle) -> Ctx< /// Fixture to establish a client to a launched server #[fixture] -pub async fn launched_client( - sshd: Sshd, - _logger: &'_ flexi_logger::LoggerHandle, -) -> Ctx { +pub async fn launched_client(sshd: Sshd) -> Ctx { let binary = std::env::var("DISTANT_PATH").unwrap_or_else(|_| String::from("distant")); eprintln!("Setting path to distant binary as {binary}"); diff --git a/src/cli/cache.rs b/src/cli/cache.rs 
index b642fe9..1c7d4db 100644 --- a/src/cli/cache.rs +++ b/src/cli/cache.rs @@ -1,6 +1,6 @@ use crate::paths::user::CACHE_FILE_PATH; use anyhow::Context; -use distant_core::ConnectionId; +use distant_core::net::common::ConnectionId; use serde::{Deserialize, Serialize}; use std::{ io, diff --git a/src/cli/client.rs b/src/cli/client.rs index 6d12b19..25175c8 100644 --- a/src/cli/client.rs +++ b/src/cli/client.rs @@ -1,128 +1,73 @@ use crate::config::NetworkConfig; -use anyhow::Context; -use distant_core::{ - net::{AuthRequest, AuthResponse, FramedTransport, PlainCodec}, - DistantManagerClient, DistantManagerClientConfig, +use async_trait::async_trait; +use distant_core::net::client::{Client as NetClient, ReconnectStrategy}; +use distant_core::net::common::authentication::msg::*; +use distant_core::net::common::authentication::{ + AuthHandler, AuthMethodHandler, PromptAuthMethodHandler, SingleAuthHandler, }; +use distant_core::net::manager::ManagerClient; use log::*; +use std::io; +use std::time::Duration; mod msg; pub use msg::*; -pub struct Client { - config: DistantManagerClientConfig, +pub struct Client { network: NetworkConfig, + auth_handler: T, } -impl Client { +impl Client<()> { pub fn new(network: NetworkConfig) -> Self { - let config = DistantManagerClientConfig::with_prompts( - |prompt| rpassword::prompt_password(prompt), - |prompt| { - use std::io::Write; - eprint!("{}", prompt); - std::io::stderr().lock().flush()?; - - let mut answer = String::new(); - std::io::stdin().read_line(&mut answer)?; - Ok(answer) - }, - ); - Self { config, network } - } - - /// Configure client to talk over stdin and stdout using messages - pub fn using_msg_stdin_stdout(self) -> Self { - self.using_msg(MsgSender::from_stdout(), MsgReceiver::from_stdin()) - } - - /// Configure client to use a pair of msg sender and receiver - pub fn using_msg(mut self, tx: MsgSender, rx: MsgReceiver) -> Self { - self.config = DistantManagerClientConfig { - on_challenge: { - let tx = tx.clone(); - 
let rx = rx.clone(); - Box::new(move |questions, options| { - let question_cnt = questions.len(); - - if let Err(x) = tx.send_blocking(&AuthRequest::Challenge { questions, options }) - { - error!("{}", x); - return (0..question_cnt) - .into_iter() - .map(|_| "".to_string()) - .collect(); - } + Self { + network, + auth_handler: (), + } + } +} - match rx.recv_blocking() { - Ok(AuthResponse::Challenge { answers }) => answers, - Ok(x) => { - error!("Invalid response received: {:?}", x); - (0..question_cnt) - .into_iter() - .map(|_| "".to_string()) - .collect() - } - Err(x) => { - error!("{}", x); - (0..question_cnt) - .into_iter() - .map(|_| "".to_string()) - .collect() - } - } - }) - }, - on_info: { - let tx = tx.clone(); - Box::new(move |text| { - let _ = tx.send_blocking(&AuthRequest::Info { text }); - }) - }, - on_verify: { - let tx = tx.clone(); - Box::new(move |kind, text| { - if let Err(x) = tx.send_blocking(&AuthRequest::Verify { kind, text }) { - error!("{}", x); - return false; - } +impl Client { + pub fn using_json_auth_handler(self) -> Client { + Client { + network: self.network, + auth_handler: JsonAuthHandler::default(), + } + } - match rx.recv_blocking() { - Ok(AuthResponse::Verify { valid }) => valid, - Ok(x) => { - error!("Invalid response received: {:?}", x); - false - } - Err(x) => { - error!("{}", x); - false - } - } - }) - }, - on_error: { - Box::new(move |kind, text| { - let _ = tx.send_blocking(&AuthRequest::Error { kind, text }); - }) - }, - }; - self + pub fn using_prompt_auth_handler(self) -> Client { + Client { + network: self.network, + auth_handler: PromptAuthHandler::new(), + } } +} +impl Client { /// Connect to the manager listening on the socket or windows pipe based on /// the [`NetworkConfig`] provided to the client earlier. 
Will return a new instance - /// of the [`DistantManagerClient`] upon successful connection - pub async fn connect(self) -> anyhow::Result { + /// of the [`ManagerClient`] upon successful connection + pub async fn connect(self) -> anyhow::Result { #[cfg(unix)] - let transport = { - use distant_core::net::UnixSocketTransport; - let mut maybe_transport = None; + { + let mut maybe_client = None; let mut error: Option = None; for path in self.network.to_unix_socket_path_candidates() { - match UnixSocketTransport::connect(path).await { - Ok(transport) => { + match NetClient::unix_socket(path) + .auth_handler(self.auth_handler.clone()) + .reconnect_strategy(ReconnectStrategy::ExponentialBackoff { + base: Duration::from_secs(1), + factor: 2.0, + max_duration: None, + max_retries: None, + timeout: None, + }) + .connect() + .await + { + Ok(client) => { info!("Connected to unix socket @ {:?}", path); - maybe_transport = Some(FramedTransport::new(transport, PlainCodec)); + maybe_client = Some(client); break; } Err(x) => { @@ -137,21 +82,31 @@ impl Client { } } - maybe_transport.ok_or_else(|| { + Ok(maybe_client.ok_or_else(|| { error.unwrap_or_else(|| anyhow::anyhow!("No unix socket candidate available")) - })? - }; + })?) 
+ } #[cfg(windows)] - let transport = { - use distant_core::net::WindowsPipeTransport; - let mut maybe_transport = None; + { + let mut maybe_client = None; let mut error: Option = None; for name in self.network.to_windows_pipe_name_candidates() { - match WindowsPipeTransport::connect_local(name).await { - Ok(transport) => { - info!("Connected to named windows socket @ {:?}", name); - maybe_transport = Some(FramedTransport::new(transport, PlainCodec)); + match NetClient::local_windows_pipe(name) + .auth_handler(self.auth_handler.clone()) + .reconnect_strategy(ReconnectStrategy::ExponentialBackoff { + base: Duration::from_secs(1), + factor: 2.0, + max_duration: None, + max_retries: None, + timeout: None, + }) + .connect() + .await + { + Ok(client) => { + info!("Connected to named windows pipe @ {:?}", name); + maybe_client = Some(client); break; } Err(x) => { @@ -166,12 +121,176 @@ impl Client { } } - maybe_transport.ok_or_else(|| { + Ok(maybe_client.ok_or_else(|| { error.unwrap_or_else(|| anyhow::anyhow!("No windows pipe candidate available")) - })? - }; + })?) + } + } +} + +/// Implementation of [`AuthHandler`] that communicates over JSON. 
+#[derive(Clone)] +pub struct JsonAuthHandler { + tx: MsgSender, + rx: MsgReceiver, +} + +impl JsonAuthHandler { + pub fn new(tx: MsgSender, rx: MsgReceiver) -> Self { + Self { tx, rx } + } +} + +impl Default for JsonAuthHandler { + fn default() -> Self { + Self::new(MsgSender::from_stdout(), MsgReceiver::from_stdin()) + } +} + +#[async_trait] +impl AuthHandler for JsonAuthHandler { + async fn on_initialization( + &mut self, + initialization: Initialization, + ) -> io::Result { + self.tx + .send_blocking(&Authentication::Initialization(initialization))?; + let response = self.rx.recv_blocking::()?; + + match response { + AuthenticationResponse::Initialization(x) => Ok(x), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Unexpected response: {x:?}"), + )), + } + } + + async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> { + self.tx + .send_blocking(&Authentication::StartMethod(start_method))?; + Ok(()) + } + + async fn on_finished(&mut self) -> io::Result<()> { + self.tx.send_blocking(&Authentication::Finished)?; + Ok(()) + } +} + +#[async_trait] +impl AuthMethodHandler for JsonAuthHandler { + async fn on_challenge(&mut self, challenge: Challenge) -> io::Result { + self.tx + .send_blocking(&Authentication::Challenge(challenge))?; + let response = self.rx.recv_blocking::()?; + + match response { + AuthenticationResponse::Challenge(x) => Ok(x), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Unexpected response: {x:?}"), + )), + } + } + + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result { + self.tx + .send_blocking(&Authentication::Verification(verification))?; + let response = self.rx.recv_blocking::()?; + + match response { + AuthenticationResponse::Verification(x) => Ok(x), + x => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Unexpected response: {x:?}"), + )), + } + } + + async fn on_info(&mut self, info: Info) -> io::Result<()> { + 
self.tx.send_blocking(&Authentication::Info(info))?; + Ok(()) + } + + async fn on_error(&mut self, error: Error) -> io::Result<()> { + self.tx.send_blocking(&Authentication::Error(error))?; + Ok(()) + } +} + +/// Implementation of [`AuthHandler`] that uses prompts to perform authentication requests and +/// notification of different information. +pub struct PromptAuthHandler(Box); + +impl PromptAuthHandler { + pub fn new() -> Self { + Self(Box::new(SingleAuthHandler::new( + PromptAuthMethodHandler::new( + |prompt: &str| { + eprintln!("{prompt}"); + let mut line = String::new(); + std::io::stdin().read_line(&mut line)?; + Ok(line) + }, + |prompt: &str| rpassword::prompt_password(prompt), + ), + ))) + } +} + +impl Clone for PromptAuthHandler { + /// Clones a new copy of the handler. + /// + /// ### Note + /// + /// This is a hack so we can use this handler elsewhere. Because this handler only has a new + /// method that creates a new instance, we treat it like a clone and just create an entirely + /// new prompt auth handler since there is no actual state to clone. 
+ fn clone(&self) -> Self { + Self::new() + } +} + +#[async_trait] +impl AuthHandler for PromptAuthHandler { + async fn on_initialization( + &mut self, + initialization: Initialization, + ) -> io::Result { + self.0.on_initialization(initialization).await + } + + async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> { + self.0.on_start_method(start_method).await + } + + async fn on_finished(&mut self) -> io::Result<()> { + self.0.on_finished().await + } +} + +#[async_trait] +impl AuthMethodHandler for PromptAuthHandler { + async fn on_challenge(&mut self, challenge: Challenge) -> io::Result { + self.0.on_challenge(challenge).await + } + + async fn on_verification( + &mut self, + verification: Verification, + ) -> io::Result { + self.0.on_verification(verification).await + } + + async fn on_info(&mut self, info: Info) -> io::Result<()> { + self.0.on_info(info).await + } - DistantManagerClient::new(self.config, transport) - .context("Failed to create client for manager") + async fn on_error(&mut self, error: Error) -> io::Result<()> { + self.0.on_error(error).await } } diff --git a/src/cli/commands/client.rs b/src/cli/commands/client.rs index 459aa9d..324c0d1 100644 --- a/src/cli/commands/client.rs +++ b/src/cli/commands/client.rs @@ -1,6 +1,6 @@ use crate::{ cli::{ - client::{MsgReceiver, MsgSender}, + client::{JsonAuthHandler, MsgReceiver, MsgSender, PromptAuthHandler}, Cache, Client, }, config::{ @@ -15,9 +15,9 @@ use clap::{Subcommand, ValueHint}; use dialoguer::{console::Term, theme::ColorfulTheme, Select}; use distant_core::{ data::{ChangeKindSet, Environment}, - net::{IntoSplit, Request, Response, TypedAsyncRead, TypedAsyncWrite}, - ConnectionId, Destination, DistantManagerClient, DistantMsg, DistantRequestData, - DistantResponseData, Host, Map, RemoteCommand, Searcher, Watcher, + net::common::{ConnectionId, Destination, Host, Map, Request, Response}, + net::manager::ManagerClient, + DistantMsg, DistantRequestData, 
DistantResponseData, RemoteCommand, Searcher, Watcher, }; use log::*; use serde_json::{json, Value}; @@ -26,6 +26,7 @@ use std::{ path::{Path, PathBuf}, time::Duration, }; +use tokio::sync::mpsc; mod buf; mod format; @@ -40,6 +41,8 @@ use link::RemoteProcessLink; use lsp::Lsp; use shell::Shell; +const SLEEP_DURATION: Duration = Duration::from_millis(1); + #[derive(Debug, Subcommand)] pub enum ClientSubcommand { /// Performs some action on a remote machine @@ -257,6 +260,7 @@ impl ClientSubcommand { let network = network.merge(config.network); debug!("Connecting to manager"); let mut client = Client::new(network) + .using_prompt_auth_handler() .connect() .await .context("Failed to connect to manager")?; @@ -265,9 +269,12 @@ impl ClientSubcommand { use_or_lookup_connection_id(&mut cache, connection, &mut client).await?; debug!("Opening channel to connection {}", connection_id); - let mut channel = client.open_channel(connection_id).await.with_context(|| { - format!("Failed to open channel to connection {connection_id}") - })?; + let channel = client + .open_raw_channel(connection_id) + .await + .with_context(|| { + format!("Failed to open channel to connection {connection_id}") + })?; let timeout = action_config.timeout.or(config.action.timeout); @@ -296,7 +303,7 @@ impl ClientSubcommand { .current_dir(current_dir) .persist(persist) .pty(pty) - .spawn(channel, cmd.as_str()) + .spawn(channel.into_client().into_channel(), cmd.as_str()) .await .with_context(|| format!("Failed to spawn {cmd}"))?; @@ -322,9 +329,10 @@ impl ClientSubcommand { } DistantRequestData::Search { query } => { debug!("Special request creating searcher for {:?}", query); - let mut searcher = Searcher::search(channel, query) - .await - .context("Failed to start search")?; + let mut searcher = + Searcher::search(channel.into_client().into_channel(), query) + .await + .context("Failed to start search")?; // Continue to receive and process matches while let Some(m) = searcher.next().await { @@ -348,7 
+356,7 @@ impl ClientSubcommand { } => { debug!("Special request creating watcher for {:?}", path); let mut watcher = Watcher::watch( - channel, + channel.into_client().into_channel(), path.as_path(), recursive, only.into_iter().collect::(), @@ -370,6 +378,8 @@ impl ClientSubcommand { } request => { let response = channel + .into_client() + .into_channel() .send_timeout( DistantMsg::Single(request), timeout @@ -409,15 +419,17 @@ impl ClientSubcommand { } => { let network = network.merge(config.network); debug!("Connecting to manager"); - let mut client = { - let client = match format { - Format::Shell => Client::new(network), - Format::Json => Client::new(network).using_msg_stdin_stdout(), - }; - client + let mut client = match format { + Format::Shell => Client::new(network) + .using_prompt_auth_handler() + .connect() + .await + .context("Failed to connect to manager")?, + Format::Json => Client::new(network) + .using_json_auth_handler() .connect() .await - .context("Failed to connect to manager")? 
+ .context("Failed to connect to manager")?, }; // Merge our connect configs, overwriting anything in the config file with our cli @@ -427,10 +439,16 @@ impl ClientSubcommand { // Trigger our manager to connect to the launched server debug!("Connecting to server at {} with {}", destination, options); - let id = client - .connect(*destination, options) - .await - .context("Failed to connect to server")?; + let id = match format { + Format::Shell => client + .connect(*destination, options, PromptAuthHandler::new()) + .await + .context("Failed to connect to server")?, + Format::Json => client + .connect(*destination, options, JsonAuthHandler::default()) + .await + .context("Failed to connect to server")?, + }; // Mark the server's id as the new default debug!("Updating selected connection id in cache to {}", id); @@ -458,15 +476,17 @@ impl ClientSubcommand { } => { let network = network.merge(config.network); debug!("Connecting to manager"); - let mut client = { - let client = match format { - Format::Shell => Client::new(network), - Format::Json => Client::new(network).using_msg_stdin_stdout(), - }; - client + let mut client = match format { + Format::Shell => Client::new(network) + .using_prompt_auth_handler() .connect() .await - .context("Failed to connect to manager")? 
+ .context("Failed to connect to manager")?, + Format::Json => Client::new(network) + .using_json_auth_handler() + .connect() + .await + .context("Failed to connect to manager")?, }; // Merge our launch configs, overwriting anything in the config file @@ -488,10 +508,16 @@ impl ClientSubcommand { // Start the server using our manager debug!("Launching server at {} with {}", destination, options); - let mut new_destination = client - .launch(*destination, options) - .await - .context("Failed to launch server")?; + let mut new_destination = match format { + Format::Shell => client + .launch(*destination, options, PromptAuthHandler::new()) + .await + .context("Failed to launch server")?, + Format::Json => client + .launch(*destination, options, JsonAuthHandler::default()) + .await + .context("Failed to launch server")?, + }; // Update the new destination with our previously-used host if the // new host is not globally-accessible @@ -511,10 +537,16 @@ impl ClientSubcommand { // Trigger our manager to connect to the launched server debug!("Connecting to server at {}", new_destination); - let id = client - .connect(new_destination, Map::new()) - .await - .context("Failed to connect to server")?; + let id = match format { + Format::Shell => client + .connect(new_destination, Map::new(), PromptAuthHandler::new()) + .await + .context("Failed to connect to server")?, + Format::Json => client + .connect(new_destination, Map::new(), JsonAuthHandler::default()) + .await + .context("Failed to connect to server")?, + }; // Mark the server's id as the new default debug!("Updating selected connection id in cache to {}", id); @@ -544,6 +576,7 @@ impl ClientSubcommand { let network = network.merge(config.network); debug!("Connecting to manager"); let mut client = Client::new(network) + .using_prompt_auth_handler() .connect() .await .context("Failed to connect to manager")?; @@ -552,15 +585,20 @@ impl ClientSubcommand { use_or_lookup_connection_id(&mut cache, connection, &mut 
client).await?; debug!("Opening channel to connection {}", connection_id); - let channel = client.open_channel(connection_id).await.with_context(|| { - format!("Failed to open channel to connection {connection_id}") - })?; + let channel = client + .open_raw_channel(connection_id) + .await + .with_context(|| { + format!("Failed to open channel to connection {connection_id}") + })?; debug!( "Spawning LSP server (persist = {}, pty = {}): {}", persist, pty, cmd ); - Lsp::new(channel).spawn(cmd, persist, pty).await?; + Lsp::new(channel.into_client().into_channel()) + .spawn(cmd, persist, pty) + .await?; } Self::Repl { config: repl_config, @@ -569,10 +607,17 @@ impl ClientSubcommand { format, .. } => { + // TODO: Support shell format? + if !format.is_json() { + return Err(CliError::Error(anyhow::anyhow!( + "Only JSON format is supported" + ))); + } + let network = network.merge(config.network); debug!("Connecting to manager"); let mut client = Client::new(network) - .using_msg_stdin_stdout() + .using_json_auth_handler() .connect() .await .context("Failed to connect to manager")?; @@ -583,12 +628,13 @@ impl ClientSubcommand { let timeout = repl_config.timeout.or(config.repl.timeout); debug!("Opening raw channel to connection {}", connection_id); - let channel = client - .open_raw_channel(connection_id) - .await - .with_context(|| { - format!("Failed to open raw channel to connection {connection_id}") - })?; + let mut channel = + client + .open_raw_channel(connection_id) + .await + .with_context(|| { + format!("Failed to open raw channel to connection {connection_id}") + })?; debug!( "Timeout configured to be {}", @@ -598,32 +644,18 @@ impl ClientSubcommand { } ); - // TODO: Support shell format? 
- if !format.is_json() { - return Err(CliError::Error(anyhow::anyhow!( - "Only JSON format is supported" - ))); - } - debug!("Starting repl using format {:?}", format); - let (mut writer, mut reader) = channel.transport.into_split(); - let response_task = tokio::task::spawn(async move { - let tx = MsgSender::from_stdout(); - while let Some(response) = reader.read().await? { - debug!("Received response {:?}", response); - tx.send_blocking(&response)?; - } - io::Result::Ok(()) - }); - + let (msg_tx, mut msg_rx) = mpsc::channel(1); let request_task = tokio::spawn(async move { let mut rx = MsgReceiver::from_stdin() .into_rx::>>(); loop { match rx.recv().await { Some(Ok(request)) => { - debug!("Sending request {:?}", request); - writer.write(request).await?; + if let Err(x) = msg_tx.send(request).await { + error!("Failed to forward request: {x}"); + break; + } } Some(Err(x)) => error!("{}", x), None => { @@ -634,8 +666,62 @@ impl ClientSubcommand { } io::Result::Ok(()) }); + let channel_task = tokio::task::spawn(async move { + let tx = MsgSender::from_stdout(); - let (r1, r2) = tokio::join!(request_task, response_task); + loop { + let ready = channel.readable_or_writeable().await?; + + // Keep track of whether we read or wrote anything + let mut read_blocked = !ready.is_readable(); + let mut write_blocked = !ready.is_writable(); + + if ready.is_readable() { + match channel + .try_read_frame_as::>>() + { + Ok(Some(msg)) => tx.send_blocking(&msg)?, + Ok(None) => break, + Err(x) if x.kind() == io::ErrorKind::WouldBlock => { + read_blocked = true; + } + Err(x) => return Err(x), + } + } + + if ready.is_writable() { + if let Ok(msg) = msg_rx.try_recv() { + match channel.try_write_frame_for(&msg) { + Ok(_) => (), + Err(x) if x.kind() == io::ErrorKind::WouldBlock => { + write_blocked = true + } + Err(x) => return Err(x), + } + } else { + match channel.try_flush() { + Ok(0) => write_blocked = true, + Ok(_) => (), + Err(x) if x.kind() == io::ErrorKind::WouldBlock => { + 
write_blocked = true + } + Err(x) => { + error!("Failed to flush outgoing data: {x}"); + } + } + } + } + + // If we did not read or write anything, sleep a bit to offload CPU usage + if read_blocked && write_blocked { + tokio::time::sleep(SLEEP_DURATION).await; + } + } + + io::Result::Ok(()) + }); + + let (r1, r2) = tokio::join!(request_task, channel_task); match r1 { Err(x) => error!("{}", x), Ok(Err(x)) => error!("{}", x), @@ -662,10 +748,18 @@ impl ClientSubcommand { None => { let network = network.merge(config.network); debug!("Connecting to manager"); - let mut client = Client::new(network) - .connect() - .await - .context("Failed to connect to manager")?; + let mut client = match format { + Format::Json => Client::new(network) + .using_json_auth_handler() + .connect() + .await + .context("Failed to connect to manager")?, + Format::Shell => Client::new(network) + .using_prompt_auth_handler() + .connect() + .await + .context("Failed to connect to manager")?, + }; let list = client .list() .await @@ -782,6 +876,7 @@ impl ClientSubcommand { let network = network.merge(config.network); debug!("Connecting to manager"); let mut client = Client::new(network) + .using_prompt_auth_handler() .connect() .await .context("Failed to connect to manager")?; @@ -790,9 +885,12 @@ impl ClientSubcommand { use_or_lookup_connection_id(&mut cache, connection, &mut client).await?; debug!("Opening channel to connection {}", connection_id); - let channel = client.open_channel(connection_id).await.with_context(|| { - format!("Failed to open channel to connection {connection_id}") - })?; + let channel = client + .open_raw_channel(connection_id) + .await + .with_context(|| { + format!("Failed to open channel to connection {connection_id}") + })?; debug!( "Spawning shell (environment = {:?}, persist = {}): {}", @@ -800,7 +898,9 @@ impl ClientSubcommand { persist, cmd.as_deref().unwrap_or(r"$SHELL") ); - Shell::new(channel).spawn(cmd, environment, persist).await?; + 
Shell::new(channel.into_client().into_channel()) + .spawn(cmd, environment, persist) + .await?; } } @@ -811,7 +911,7 @@ impl ClientSubcommand { async fn use_or_lookup_connection_id( cache: &mut Cache, connection: Option, - client: &mut DistantManagerClient, + client: &mut ManagerClient, ) -> anyhow::Result { match connection { Some(id) => { diff --git a/src/cli/commands/client/format.rs b/src/cli/commands/client/format.rs index 920ec7b..f8de8dc 100644 --- a/src/cli/commands/client/format.rs +++ b/src/cli/commands/client/format.rs @@ -4,7 +4,7 @@ use distant_core::{ ChangeKind, DistantMsg, DistantResponseData, Error, FileType, Metadata, SearchQueryContentsMatch, SearchQueryMatch, SearchQueryPathMatch, SystemInfo, }, - net::Response, + net::common::Response, }; use log::*; use std::{ diff --git a/src/cli/commands/generate.rs b/src/cli/commands/generate.rs index 5dbda26..23b5e90 100644 --- a/src/cli/commands/generate.rs +++ b/src/cli/commands/generate.rs @@ -3,7 +3,7 @@ use anyhow::Context; use clap::{CommandFactory, Subcommand}; use clap_complete::{generate as clap_generate, Shell}; use distant_core::{ - net::{Request, Response}, + net::common::{Request, Response}, DistantMsg, DistantRequestData, DistantResponseData, }; use std::{fs, io, path::PathBuf}; diff --git a/src/cli/commands/manager.rs b/src/cli/commands/manager.rs index a55dfae..160e1fb 100644 --- a/src/cli/commands/manager.rs +++ b/src/cli/commands/manager.rs @@ -6,14 +6,15 @@ use crate::{ }; use anyhow::Context; use clap::{Subcommand, ValueHint}; -use distant_core::{net::ServerRef, ConnectionId, DistantManagerConfig}; +use distant_core::net::common::ConnectionId; +use distant_core::net::manager::{Config as NetManagerConfig, ConnectHandler, LaunchHandler}; use log::*; use once_cell::sync::Lazy; use service_manager::{ ServiceInstallCtx, ServiceLabel, ServiceLevel, ServiceManager, ServiceManagerKind, ServiceStartCtx, ServiceStopCtx, ServiceUninstallCtx, }; -use std::{ffi::OsString, path::PathBuf}; +use 
std::{collections::HashMap, ffi::OsString, path::PathBuf}; use tabled::{Table, Tabled}; /// [`ServiceLabel`] for our manager in the form `rocks.distant.manager` @@ -83,12 +84,6 @@ pub enum ManagerSubcommand { network: NetworkConfig, id: ConnectionId, }, - - /// Send a shutdown request to the manager - Shutdown { - #[clap(flatten)] - network: NetworkConfig, - }, } #[derive(Debug, Subcommand)] @@ -289,8 +284,37 @@ impl ManagerSubcommand { ); let manager_ref = Manager { access, - config: DistantManagerConfig { + config: NetManagerConfig { user, + launch_handlers: { + let mut handlers: HashMap> = + HashMap::new(); + handlers.insert( + "manager".to_string(), + Box::new(handlers::ManagerLaunchHandler::new()), + ); + + #[cfg(any(feature = "libssh", feature = "ssh2"))] + handlers + .insert("ssh".to_string(), Box::new(handlers::SshLaunchHandler)); + + handlers + }, + connect_handlers: { + let mut handlers: HashMap> = + HashMap::new(); + + handlers.insert( + "distant".to_string(), + Box::new(handlers::DistantConnectHandler), + ); + + #[cfg(any(feature = "libssh", feature = "ssh2"))] + handlers + .insert("ssh".to_string(), Box::new(handlers::SshConnectHandler)); + + handlers + }, ..Default::default() }, network, @@ -299,33 +323,10 @@ impl ManagerSubcommand { .await .context("Failed to start manager")?; - // Register our handlers for different schemes - debug!("Registering handlers with manager"); - manager_ref - .register_launch_handler("manager", handlers::ManagerLaunchHandler::new()) - .await - .context("Failed to register launch handler for \"manager://\"")?; - manager_ref - .register_connect_handler("distant", handlers::DistantConnectHandler) - .await - .context("Failed to register connect handler for \"distant://\"")?; - - #[cfg(any(feature = "libssh", feature = "ssh2"))] - // Register ssh-specific handlers if either feature flag is enabled - { - manager_ref - .register_launch_handler("ssh", handlers::SshLaunchHandler) - .await - .context("Failed to register launch 
handler for \"ssh://\"")?; - manager_ref - .register_connect_handler("ssh", handlers::SshConnectHandler) - .await - .context("Failed to register connect handler for \"ssh://\"")?; - } - // Let our server run to completion manager_ref - .wait() + .as_ref() + .polling_wait() .await .context("Failed to wait on manager")?; info!("Manager is shutting down"); @@ -336,12 +337,14 @@ impl ManagerSubcommand { let network = network.merge(config.network); debug!("Getting list of capabilities"); let caps = Client::new(network) + .using_prompt_auth_handler() .connect() .await .context("Failed to connect to manager")? .capabilities() .await .context("Failed to get list of capabilities")?; + debug!("Got capabilities: {caps:?}"); #[derive(Tabled)] struct CapabilityRow { @@ -365,12 +368,14 @@ impl ManagerSubcommand { let network = network.merge(config.network); debug!("Getting info about connection {}", id); let info = Client::new(network) + .using_prompt_auth_handler() .connect() .await .context("Failed to connect to manager")? .info(id) .await .context("Failed to get info about connection")?; + debug!("Got info: {info:?}"); #[derive(Tabled)] struct InfoRow { @@ -402,12 +407,14 @@ impl ManagerSubcommand { let network = network.merge(config.network); debug!("Getting list of connections"); let list = Client::new(network) + .using_prompt_auth_handler() .connect() .await .context("Failed to connect to manager")? .list() .await .context("Failed to get list of connections")?; + debug!("Got list: {list:?}"); debug!("Looking up selected connection"); let selected = Cache::read_from_disk_or_default(cache) @@ -415,6 +422,7 @@ impl ManagerSubcommand { .context("Failed to look up selected connection")? 
.data .selected; + debug!("Using selected: {selected}"); #[derive(Tabled)] struct ListRow { @@ -444,24 +452,14 @@ impl ManagerSubcommand { let network = network.merge(config.network); debug!("Killing connection {}", id); Client::new(network) + .using_prompt_auth_handler() .connect() .await .context("Failed to connect to manager")? .kill(id) .await .with_context(|| format!("Failed to kill connection to server {id}"))?; - Ok(()) - } - Self::Shutdown { network } => { - let network = network.merge(config.network); - debug!("Shutting down manager"); - Client::new(network) - .connect() - .await - .context("Failed to connect to manager")? - .shutdown() - .await - .context("Failed to shutdown manager")?; + debug!("Connection killed"); Ok(()) } } diff --git a/src/cli/commands/manager/handlers.rs b/src/cli/commands/manager/handlers.rs index ee89f0f..2092c55 100644 --- a/src/cli/commands/manager/handlers.rs +++ b/src/cli/commands/manager/handlers.rs @@ -1,19 +1,20 @@ use crate::config::ClientLaunchConfig; use async_trait::async_trait; -use distant_core::{ - net::{ - AuthClient, AuthQuestion, FramedTransport, IntoSplit, SecretKey32, TcpTransport, - XChaCha20Poly1305Codec, - }, - BoxedDistantReader, BoxedDistantWriter, BoxedDistantWriterReader, ConnectHandler, Destination, - LaunchHandler, Map, +use distant_core::net::client::{Client, ReconnectStrategy, UntypedClient}; +use distant_core::net::common::authentication::msg::*; +use distant_core::net::common::authentication::{ + AuthHandler, Authenticator, DynAuthHandler, ProxyAuthHandler, SingleAuthHandler, + StaticKeyAuthMethodHandler, }; +use distant_core::net::common::{Destination, Map, SecretKey32}; +use distant_core::net::manager::{ConnectHandler, LaunchHandler}; use log::*; use std::{ io, net::{IpAddr, SocketAddr}, path::PathBuf, process::Stdio, + time::Duration, }; use tokio::{ io::{AsyncBufReadExt, BufReader}, @@ -50,7 +51,7 @@ impl LaunchHandler for ManagerLaunchHandler { &self, destination: &Destination, options: &Map, - 
_auth_client: &mut AuthClient, + _authenticator: &mut dyn Authenticator, ) -> io::Result { debug!("Handling launch of {destination} with options '{options}'"); let config = ClientLaunchConfig::from(options.clone()); @@ -163,14 +164,14 @@ impl LaunchHandler for SshLaunchHandler { &self, destination: &Destination, options: &Map, - auth_client: &mut AuthClient, + authenticator: &mut dyn Authenticator, ) -> io::Result { debug!("Handling launch of {destination} with options '{options}'"); let config = ClientLaunchConfig::from(options.clone()); use distant_ssh2::DistantLaunchOpts; let mut ssh = load_ssh(destination, options)?; - let handler = AuthClientSshAuthHandler::new(auth_client); + let handler = AuthClientSshAuthHandler::new(authenticator); let _ = ssh.authenticate(handler).await?; let opts = { let opts = DistantLaunchOpts::default(); @@ -196,14 +197,31 @@ impl LaunchHandler for SshLaunchHandler { pub struct DistantConnectHandler; impl DistantConnectHandler { - pub async fn try_connect(ips: Vec, port: u16) -> io::Result { + async fn try_connect( + ips: Vec, + port: u16, + mut auth_handler: impl AuthHandler, + ) -> io::Result { // Try each IP address with the same port to see if one works let mut err = None; for ip in ips { let addr = SocketAddr::new(ip, port); debug!("Attempting to connect to distant server @ {}", addr); - match TcpTransport::connect(addr).await { - Ok(transport) => return Ok(transport), + + match Client::tcp(addr) + .auth_handler(DynAuthHandler::from(&mut auth_handler)) + .reconnect_strategy(ReconnectStrategy::ExponentialBackoff { + base: Duration::from_secs(1), + factor: 2.0, + max_duration: None, + max_retries: None, + timeout: None, + }) + .timeout(Duration::from_secs(180)) + .connect_untyped() + .await + { + Ok(client) => return Ok(client), Err(x) => err = Some(x), } } @@ -219,8 +237,8 @@ impl ConnectHandler for DistantConnectHandler { &self, destination: &Destination, options: &Map, - auth_client: &mut AuthClient, - ) -> io::Result { + 
authenticator: &mut dyn Authenticator, + ) -> io::Result { debug!("Handling connect of {destination} with options '{options}'"); let host = destination.host.to_string(); let port = destination.port.ok_or_else(|| missing("port"))?; @@ -246,37 +264,24 @@ impl ConnectHandler for DistantConnectHandler { )); } - // Use provided password or options key if available, otherwise ask for it, and produce a - // codec using the key - let codec = { - let key = destination - .password - .as_deref() - .or_else(|| options.get("key").map(|s| s.as_str())); - - let key = match key { - Some(key) => key.parse::().map_err(|_| invalid("key"))?, - None => { - let answers = auth_client - .challenge(vec![AuthQuestion::new("key")], Default::default()) - .await?; - answers - .first() - .ok_or_else(|| missing("key"))? - .parse::() - .map_err(|_| invalid("key"))? - } - }; - XChaCha20Poly1305Codec::from(key) - }; - - // Establish a TCP connection, wrap it, and split it out into a writer and reader - let transport = Self::try_connect(candidate_ips, port).await?; - let transport = FramedTransport::new(transport, codec); - let (writer, reader) = transport.into_split(); - let writer: BoxedDistantWriter = Box::new(writer); - let reader: BoxedDistantReader = Box::new(reader); - Ok((writer, reader)) + // For legacy reasons, we need to support a static key being provided + // via part of the destination OR an option, and attempt to use it + // during authentication if it is provided + if let Some(key) = destination + .password + .as_deref() + .or_else(|| options.get("key").map(|s| s.as_str())) + { + let key = key.parse::().map_err(|_| invalid("key"))?; + Self::try_connect( + candidate_ips, + port, + SingleAuthHandler::new(StaticKeyAuthMethodHandler::simple(key)), + ) + .await + } else { + Self::try_connect(candidate_ips, port, ProxyAuthHandler::new(authenticator)).await + } } } @@ -291,23 +296,23 @@ impl ConnectHandler for SshConnectHandler { &self, destination: &Destination, options: &Map, - 
auth_client: &mut AuthClient, - ) -> io::Result { + authenticator: &mut dyn Authenticator, + ) -> io::Result { debug!("Handling connect of {destination} with options '{options}'"); let mut ssh = load_ssh(destination, options)?; - let handler = AuthClientSshAuthHandler::new(auth_client); + let handler = AuthClientSshAuthHandler::new(authenticator); let _ = ssh.authenticate(handler).await?; - ssh.into_distant_writer_reader().await + Ok(ssh.into_distant_client().await?.into_untyped_client()) } } #[cfg(any(feature = "libssh", feature = "ssh2"))] -struct AuthClientSshAuthHandler<'a>(Mutex<&'a mut AuthClient>); +struct AuthClientSshAuthHandler<'a>(Mutex<&'a mut dyn Authenticator>); #[cfg(any(feature = "libssh", feature = "ssh2"))] impl<'a> AuthClientSshAuthHandler<'a> { - pub fn new(auth_client: &'a mut AuthClient) -> Self { - Self(Mutex::new(auth_client)) + pub fn new(authenticator: &'a mut dyn Authenticator) -> Self { + Self(Mutex::new(authenticator)) } } @@ -322,7 +327,8 @@ impl<'a> distant_ssh2::SshAuthHandler for AuthClientSshAuthHandler<'a> { for prompt in event.prompts { let mut options = HashMap::new(); options.insert("echo".to_string(), prompt.echo.to_string()); - questions.push(AuthQuestion { + questions.push(Question { + label: "ssh-prompt".to_string(), text: prompt.prompt, options, }); @@ -331,31 +337,51 @@ impl<'a> distant_ssh2::SshAuthHandler for AuthClientSshAuthHandler<'a> { options.insert("instructions".to_string(), event.instructions); options.insert("username".to_string(), event.username); - self.0.lock().await.challenge(questions, options).await + Ok(self + .0 + .lock() + .await + .challenge(Challenge { questions, options }) + .await? + .answers) } async fn on_verify_host(&self, host: &str) -> io::Result { - use distant_core::net::AuthVerifyKind; - self.0 + Ok(self + .0 .lock() .await - .verify(AuthVerifyKind::Host, host.to_string()) - .await + .verify(Verification { + kind: VerificationKind::Host, + text: host.to_string(), + }) + .await? 
+ .valid) } async fn on_banner(&self, text: &str) { - if let Err(x) = self.0.lock().await.info(text.to_string()).await { + if let Err(x) = self + .0 + .lock() + .await + .info(Info { + text: text.to_string(), + }) + .await + { error!("ssh on_banner failed: {}", x); } } async fn on_error(&self, text: &str) { - use distant_core::net::AuthErrorKind; if let Err(x) = self .0 .lock() .await - .error(AuthErrorKind::Unknown, text.to_string()) + .error(Error { + kind: ErrorKind::Fatal, + text: text.to_string(), + }) .await { error!("ssh on_error failed: {}", x); diff --git a/src/cli/commands/server.rs b/src/cli/commands/server.rs index 420d5f9..bca69c6 100644 --- a/src/cli/commands/server.rs +++ b/src/cli/commands/server.rs @@ -4,13 +4,10 @@ use crate::{ }; use anyhow::Context; use clap::Subcommand; -use distant_core::{ - net::{ - SecretKey32, ServerConfig as NetServerConfig, ServerRef, TcpServerExt, - XChaCha20Poly1305Codec, - }, - DistantApiServer, DistantSingleKeyCredentials, Host, -}; +use distant_core::net::common::authentication::Verifier; +use distant_core::net::common::{Host, SecretKey32}; +use distant_core::net::server::{Server, ServerConfig as NetServerConfig, ServerRef}; +use distant_core::{DistantApiServerHandler, DistantSingleKeyCredentials}; use log::*; use std::io::{self, Read, Write}; @@ -52,9 +49,8 @@ impl ServerSubcommand { #[cfg(windows)] fn run_daemon(self, _config: ServerConfig) -> CliResult { use crate::cli::Spawner; - use distant_core::net::{Listener, WindowsPipeListener}; + use distant_core::net::common::{Listener, TransportExt, WindowsPipeListener}; use std::ffi::OsString; - use tokio::io::AsyncReadExt; let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?; rt.block_on(async { let name = format!("distant_{}_{}", std::process::id(), rand::random::()); @@ -69,7 +65,7 @@ impl ServerSubcommand { println!("[distant server detached, pid = {}]", pid); // Wait to receive a connection from the above process - let mut transport = 
listener.accept().await.context( + let transport = listener.accept().await.context( "Failed to receive connection from background process to send credentials", )?; @@ -163,8 +159,6 @@ impl ServerSubcommand { SecretKey32::default() }; - let codec = XChaCha20Poly1305Codec::new(key.unprotected_as_bytes()); - debug!( "Starting local API server, binding to {} {}", addr, @@ -173,21 +167,26 @@ impl ServerSubcommand { None => "using an ephemeral port".to_string(), } ); - let server = DistantApiServer::local(NetServerConfig { - shutdown: get!(shutdown).unwrap_or_default(), - }) - .context("Failed to create local distant api")? - .start(addr, get!(port).unwrap_or_else(|| 0.into()), codec) - .await - .with_context(|| { - format!( - "Failed to start server @ {} with {}", - addr, - get!(port) - .map(|p| format!("port in range {p}")) - .unwrap_or_else(|| String::from("ephemeral port")) - ) - })?; + let handler = DistantApiServerHandler::local() + .context("Failed to create local distant api")?; + let server = Server::tcp() + .config(NetServerConfig { + shutdown: get!(shutdown).unwrap_or_default(), + ..Default::default() + }) + .handler(handler) + .verifier(Verifier::static_key(key.clone())) + .start(addr, get!(port).unwrap_or_else(|| 0.into())) + .await + .with_context(|| { + format!( + "Failed to start server @ {} with {}", + addr, + get!(port) + .map(|p| format!("port in range {p}")) + .unwrap_or_else(|| String::from("ephemeral port")) + ) + })?; let credentials = DistantSingleKeyCredentials { host: Host::from(addr), @@ -214,9 +213,8 @@ impl ServerSubcommand { #[cfg(windows)] if let Some(name) = output_to_local_pipe { - use distant_core::net::WindowsPipeTransport; - use tokio::io::AsyncWriteExt; - let mut transport = WindowsPipeTransport::connect_local(&name) + use distant_core::net::common::{TransportExt, WindowsPipeTransport}; + let transport = WindowsPipeTransport::connect_local(&name) .await .with_context(|| { format!("Failed to connect to local pipe named {name:?}") diff 
--git a/src/cli/manager.rs b/src/cli/manager.rs index e894cb2..16597f5 100644 --- a/src/cli/manager.rs +++ b/src/cli/manager.rs @@ -3,22 +3,25 @@ use crate::{ paths::{global as global_paths, user as user_paths}, }; use anyhow::Context; -use distant_core::{net::PlainCodec, DistantManager, DistantManagerConfig, DistantManagerRef}; +use distant_core::net::common::authentication::Verifier; +use distant_core::net::manager::{Config as ManagerConfig, ManagerServer}; +use distant_core::net::server::ServerRef; use log::*; pub struct Manager { pub access: AccessControl, - pub config: DistantManagerConfig, + pub config: ManagerConfig, pub network: NetworkConfig, } impl Manager { /// Begin listening on the network interface specified within [`NetworkConfig`] - pub async fn listen(self) -> anyhow::Result { + pub async fn listen(self) -> anyhow::Result> { let user = self.config.user; #[cfg(unix)] { + use distant_core::net::common::UnixSocketListener; let socket_path = self.network.unix_socket.as_deref().unwrap_or({ if user { user_paths::UNIX_SOCKET_PATH.as_path() @@ -34,41 +37,34 @@ impl Manager { .with_context(|| format!("Failed to create socket directory {parent:?}"))?; } - let boxed_ref = DistantManager::start_unix_socket_with_permissions( - self.config, - socket_path, - PlainCodec, - self.access.into_mode(), - ) - .await - .with_context(|| format!("Failed to start manager at socket {socket_path:?}"))? 
- .into_inner() - .into_boxed_server_ref() - .map_err(|_| anyhow::anyhow!("Got wrong server ref"))?; + let boxed_ref = ManagerServer::new(self.config) + .verifier(Verifier::none()) + .start( + UnixSocketListener::bind_with_permissions(socket_path, self.access.into_mode()) + .await?, + ) + .with_context(|| format!("Failed to start manager at socket {socket_path:?}"))?; info!("Manager listening using unix socket @ {:?}", socket_path); - Ok(*boxed_ref) + Ok(boxed_ref) } #[cfg(windows)] { + use distant_core::net::common::WindowsPipeListener; let pipe_name = self.network.windows_pipe.as_deref().unwrap_or(if user { user_paths::WINDOWS_PIPE_NAME.as_str() } else { global_paths::WINDOWS_PIPE_NAME.as_str() }); - let boxed_ref = - DistantManager::start_local_named_pipe(self.config, pipe_name, PlainCodec) - .await - .with_context(|| { - format!("Failed to start manager with pipe named '{pipe_name}'") - })? - .into_inner() - .into_boxed_server_ref() - .map_err(|_| anyhow::anyhow!("Got wrong server ref"))?; - info!("Manager listening using local named pipe @ {:?}", pipe_name); - Ok(*boxed_ref) + let boxed_ref = ManagerServer::new(self.config) + .verifier(Verifier::none()) + .start(WindowsPipeListener::bind_local(pipe_name)?) 
+ .with_context(|| format!("Failed to start manager at pipe {pipe_name:?}"))?; + + info!("Manager listening using windows pipe @ {:?}", pipe_name); + Ok(boxed_ref) } } } diff --git a/src/config/client/connect.rs b/src/config/client/connect.rs index defd4a1..a9d3323 100644 --- a/src/config/client/connect.rs +++ b/src/config/client/connect.rs @@ -1,5 +1,5 @@ use clap::Args; -use distant_core::Map; +use distant_core::net::common::Map; use serde::{Deserialize, Serialize}; #[derive(Args, Debug, Default, Serialize, Deserialize)] diff --git a/src/config/client/launch.rs b/src/config/client/launch.rs index 42f048a..0e2287d 100644 --- a/src/config/client/launch.rs +++ b/src/config/client/launch.rs @@ -1,6 +1,6 @@ use crate::config::BindAddress; use clap::Args; -use distant_core::Map; +use distant_core::net::common::Map; use serde::{Deserialize, Serialize}; #[derive(Args, Debug, Default, Serialize, Deserialize)] diff --git a/src/config/manager.rs b/src/config/manager.rs index a9ad09b..95f64df 100644 --- a/src/config/manager.rs +++ b/src/config/manager.rs @@ -1,6 +1,6 @@ use super::{AccessControl, CommonConfig, NetworkConfig}; use clap::Args; -use distant_core::Destination; +use distant_core::net::common::Destination; use serde::{Deserialize, Serialize}; use service_manager::ServiceManagerKind; diff --git a/src/config/server/listen.rs b/src/config/server/listen.rs index 10bed5f..4d77f24 100644 --- a/src/config/server/listen.rs +++ b/src/config/server/listen.rs @@ -1,9 +1,7 @@ use anyhow::Context; use clap::Args; -use distant_core::{ - net::{PortRange, Shutdown}, - Host, HostParseError, Map, -}; +use distant_core::net::common::{Host, HostParseError, Map, PortRange}; +use distant_core::net::server::Shutdown; use serde::{Deserialize, Serialize}; use std::{ env, fmt, @@ -197,6 +195,7 @@ impl BindAddress { #[cfg(test)] mod tests { use super::*; + use test_log::test; #[test] fn to_string_should_properly_print_bind_address() { diff --git a/tests/cli/action/capabilities.rs 
b/tests/cli/action/capabilities.rs index b4f130d..db13d45 100644 --- a/tests/cli/action/capabilities.rs +++ b/tests/cli/action/capabilities.rs @@ -56,6 +56,7 @@ const EXPECTED_TABLE: &str = indoc! {" "}; #[rstest] +#[test_log::test] fn should_output_capabilities(mut action_cmd: CtxCommand) { // distant action capabilities action_cmd diff --git a/tests/cli/action/copy.rs b/tests/cli/action/copy.rs index 0aa02fc..01970d5 100644 --- a/tests/cli/action/copy.rs +++ b/tests/cli/action/copy.rs @@ -11,6 +11,7 @@ that is a file's contents "#; #[rstest] +#[test_log::test] fn should_support_copying_file(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -21,7 +22,7 @@ fn should_support_copying_file(mut action_cmd: CtxCommand) { // distant action copy {src} {dst} action_cmd - .args(&["copy", src.to_str().unwrap(), dst.to_str().unwrap()]) + .args(["copy", src.to_str().unwrap(), dst.to_str().unwrap()]) .assert() .success() .stdout("") @@ -32,6 +33,7 @@ fn should_support_copying_file(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn should_support_copying_nonempty_directory(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -46,7 +48,7 @@ fn should_support_copying_nonempty_directory(mut action_cmd: CtxCommand // distant action copy {src} {dst} action_cmd - .args(&["copy", src.to_str().unwrap(), dst.to_str().unwrap()]) + .args(["copy", src.to_str().unwrap(), dst.to_str().unwrap()]) .assert() .success() .stdout("") @@ -57,6 +59,7 @@ fn should_support_copying_nonempty_directory(mut action_cmd: CtxCommand } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -65,7 +68,7 @@ fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { // distant action copy {src} {dst} action_cmd - .args(&["copy", src.to_str().unwrap(), dst.to_str().unwrap()]) + .args(["copy", src.to_str().unwrap(), dst.to_str().unwrap()]) .assert() .code(1) 
.stdout("") diff --git a/tests/cli/action/dir_create.rs b/tests/cli/action/dir_create.rs index c6cf228..4a0d38d 100644 --- a/tests/cli/action/dir_create.rs +++ b/tests/cli/action/dir_create.rs @@ -5,13 +5,14 @@ use predicates::prelude::*; use rstest::*; #[rstest] +#[test_log::test] fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("dir"); // distant action dir-create {path} action_cmd - .args(&["dir-create", dir.to_str().unwrap()]) + .args(["dir-create", dir.to_str().unwrap()]) .assert() .success() .stdout("") @@ -22,6 +23,7 @@ fn should_report_ok_when_done(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn should_support_creating_missing_parent_directories_if_specified( mut action_cmd: CtxCommand, ) { @@ -30,7 +32,7 @@ fn should_support_creating_missing_parent_directories_if_specified( // distant action dir-create {path} action_cmd - .args(&["dir-create", "--all", dir.to_str().unwrap()]) + .args(["dir-create", "--all", dir.to_str().unwrap()]) .assert() .success() .stdout("") @@ -41,13 +43,14 @@ fn should_support_creating_missing_parent_directories_if_specified( } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("missing-dir").child("dir"); // distant action dir-create {path} action_cmd - .args(&["dir-create", dir.to_str().unwrap()]) + .args(["dir-create", dir.to_str().unwrap()]) .assert() .code(1) .stdout("") diff --git a/tests/cli/action/dir_read.rs b/tests/cli/action/dir_read.rs index 0d19a64..3906225 100644 --- a/tests/cli/action/dir_read.rs +++ b/tests/cli/action/dir_read.rs @@ -88,6 +88,7 @@ fn regex_line(ty: &str, path: &str) -> String { } #[rstest] +#[test_log::test] fn should_print_immediate_files_and_directories_by_default(mut action_cmd: CtxCommand) { let temp = make_directory(); @@ -100,7 +101,7 @@ fn 
should_print_immediate_files_and_directories_by_default(mut action_cmd: CtxCo // distant action dir-read {path} action_cmd - .args(&["dir-read", temp.to_str().unwrap()]) + .args(["dir-read", temp.to_str().unwrap()]) .assert() .success() .stdout(expected) @@ -109,6 +110,7 @@ fn should_print_immediate_files_and_directories_by_default(mut action_cmd: CtxCo // NOTE: Ignoring on windows because ssh2 doesn't properly canonicalize paths to resolve symlinks! #[rstest] +#[test_log::test] #[cfg_attr(windows, ignore)] fn should_use_absolute_paths_if_specified(mut action_cmd: CtxCommand) { let temp = make_directory(); @@ -126,7 +128,7 @@ fn should_use_absolute_paths_if_specified(mut action_cmd: CtxCommand) { // distant action dir-read --absolute {path} action_cmd - .args(&["dir-read", "--absolute", temp.to_str().unwrap()]) + .args(["dir-read", "--absolute", temp.to_str().unwrap()]) .assert() .success() .stdout(expected) @@ -135,6 +137,7 @@ fn should_use_absolute_paths_if_specified(mut action_cmd: CtxCommand) { // NOTE: Ignoring on windows because ssh2 doesn't properly canonicalize paths to resolve symlinks! #[rstest] +#[test_log::test] #[cfg_attr(windows, ignore)] fn should_print_all_files_and_directories_if_depth_is_0(mut action_cmd: CtxCommand) { let temp = make_directory(); @@ -172,7 +175,7 @@ fn should_print_all_files_and_directories_if_depth_is_0(mut action_cmd: CtxComma // distant action dir-read --depth 0 {path} action_cmd - .args(&["dir-read", "--depth", "0", temp.to_str().unwrap()]) + .args(["dir-read", "--depth", "0", temp.to_str().unwrap()]) .assert() .success() .stdout(expected) @@ -181,6 +184,7 @@ fn should_print_all_files_and_directories_if_depth_is_0(mut action_cmd: CtxComma // NOTE: Ignoring on windows because ssh2 doesn't properly canonicalize paths to resolve symlinks! 
#[rstest] +#[test_log::test] #[cfg_attr(windows, ignore)] fn should_include_root_directory_if_specified(mut action_cmd: CtxCommand) { let temp = make_directory(); @@ -199,7 +203,7 @@ fn should_include_root_directory_if_specified(mut action_cmd: CtxCommand) { let temp = make_directory(); let dir = temp.child("missing-dir"); // distant action dir-read {path} action_cmd - .args(&["dir-read", dir.to_str().unwrap()]) + .args(["dir-read", dir.to_str().unwrap()]) .assert() .code(1) .stdout("") diff --git a/tests/cli/action/exists.rs b/tests/cli/action/exists.rs index e5acaa3..d43c63d 100644 --- a/tests/cli/action/exists.rs +++ b/tests/cli/action/exists.rs @@ -4,6 +4,7 @@ use assert_fs::prelude::*; use rstest::*; #[rstest] +#[test_log::test] fn should_output_true_if_exists(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -13,7 +14,7 @@ fn should_output_true_if_exists(mut action_cmd: CtxCommand) { // distant action exists {path} action_cmd - .args(&["exists", file.to_str().unwrap()]) + .args(["exists", file.to_str().unwrap()]) .assert() .success() .stdout("true\n") @@ -21,6 +22,7 @@ fn should_output_true_if_exists(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn should_output_false_if_not_exists(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -29,7 +31,7 @@ fn should_output_false_if_not_exists(mut action_cmd: CtxCommand) { // distant action exists {path} action_cmd - .args(&["exists", file.to_str().unwrap()]) + .args(["exists", file.to_str().unwrap()]) .assert() .success() .stdout("false\n") diff --git a/tests/cli/action/file_append.rs b/tests/cli/action/file_append.rs index 3bb0a6c..8750825 100644 --- a/tests/cli/action/file_append.rs +++ b/tests/cli/action/file_append.rs @@ -15,6 +15,7 @@ file contents "#; #[rstest] +#[test_log::test] fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); @@ -22,7 +23,7 @@ fn 
should_report_ok_when_done(mut action_cmd: CtxCommand) { // distant action file-append {path} -- {contents} action_cmd - .args(&[ + .args([ "file-append", file.to_str().unwrap(), "--", @@ -41,13 +42,14 @@ fn should_report_ok_when_done(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); // distant action file-append {path} -- {contents} action_cmd - .args(&[ + .args([ "file-append", file.to_str().unwrap(), "--", diff --git a/tests/cli/action/file_append_text.rs b/tests/cli/action/file_append_text.rs index a603a97..b993be4 100644 --- a/tests/cli/action/file_append_text.rs +++ b/tests/cli/action/file_append_text.rs @@ -15,6 +15,7 @@ file contents "#; #[rstest] +#[test_log::test] fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); @@ -22,7 +23,7 @@ fn should_report_ok_when_done(mut action_cmd: CtxCommand) { // distant action file-append-text {path} -- {contents} action_cmd - .args(&[ + .args([ "file-append-text", file.to_str().unwrap(), "--", @@ -41,13 +42,14 @@ fn should_report_ok_when_done(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); // distant action file-append-text {path} -- {contents} action_cmd - .args(&[ + .args([ "file-append-text", file.to_str().unwrap(), "--", diff --git a/tests/cli/action/file_read.rs b/tests/cli/action/file_read.rs index 0d48da9..260f21e 100644 --- a/tests/cli/action/file_read.rs +++ b/tests/cli/action/file_read.rs @@ -10,6 +10,7 @@ that is a file's contents "#; #[rstest] +#[test_log::test] fn should_print_out_file_contents(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); 
let file = temp.child("test-file"); @@ -17,7 +18,7 @@ fn should_print_out_file_contents(mut action_cmd: CtxCommand) { // distant action file-read {path} action_cmd - .args(&["file-read", file.to_str().unwrap()]) + .args(["file-read", file.to_str().unwrap()]) .assert() .success() .stdout(format!("{}\n", FILE_CONTENTS)) @@ -25,13 +26,14 @@ fn should_print_out_file_contents(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-file"); // distant action file-read {path} action_cmd - .args(&["file-read", file.to_str().unwrap()]) + .args(["file-read", file.to_str().unwrap()]) .assert() .code(1) .stdout("") diff --git a/tests/cli/action/file_read_text.rs b/tests/cli/action/file_read_text.rs index 141f69b..eda0f34 100644 --- a/tests/cli/action/file_read_text.rs +++ b/tests/cli/action/file_read_text.rs @@ -10,6 +10,7 @@ that is a file's contents "#; #[rstest] +#[test_log::test] fn should_print_out_file_contents(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); @@ -17,7 +18,7 @@ fn should_print_out_file_contents(mut action_cmd: CtxCommand) { // distant action file-read-text {path} action_cmd - .args(&["file-read-text", file.to_str().unwrap()]) + .args(["file-read-text", file.to_str().unwrap()]) .assert() .success() .stdout(format!("{}\n", FILE_CONTENTS)) @@ -25,13 +26,14 @@ fn should_print_out_file_contents(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-file"); // distant action file-read-text {path} action_cmd - .args(&["file-read-text", file.to_str().unwrap()]) + .args(["file-read-text", file.to_str().unwrap()]) .assert() .code(1) .stdout("") diff --git a/tests/cli/action/file_write.rs b/tests/cli/action/file_write.rs 
index 6b005c3..34385e5 100644 --- a/tests/cli/action/file_write.rs +++ b/tests/cli/action/file_write.rs @@ -10,13 +10,14 @@ that is a file's contents "#; #[rstest] +#[test_log::test] fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); // distant action file-write {path} -- {contents} action_cmd - .args(&["file-write", file.to_str().unwrap(), "--", FILE_CONTENTS]) + .args(["file-write", file.to_str().unwrap(), "--", FILE_CONTENTS]) .assert() .success() .stdout("") @@ -30,13 +31,14 @@ fn should_report_ok_when_done(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); // distant action file-write {path} -- {contents} action_cmd - .args(&["file-write", file.to_str().unwrap(), "--", FILE_CONTENTS]) + .args(["file-write", file.to_str().unwrap(), "--", FILE_CONTENTS]) .assert() .code(1) .stdout("") diff --git a/tests/cli/action/file_write_text.rs b/tests/cli/action/file_write_text.rs index 0bc546f..53476b4 100644 --- a/tests/cli/action/file_write_text.rs +++ b/tests/cli/action/file_write_text.rs @@ -10,13 +10,14 @@ that is a file's contents "#; #[rstest] +#[test_log::test] fn should_report_ok_when_done(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); // distant action file-write-text {path} -- {contents} action_cmd - .args(&[ + .args([ "file-write-text", file.to_str().unwrap(), "--", @@ -35,13 +36,14 @@ fn should_report_ok_when_done(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); // distant action file-write {path} -- {contents} action_cmd - .args(&[ + .args([ 
"file-write-text", file.to_str().unwrap(), "--", diff --git a/tests/cli/action/metadata.rs b/tests/cli/action/metadata.rs index 77f7581..fb4d3a3 100644 --- a/tests/cli/action/metadata.rs +++ b/tests/cli/action/metadata.rs @@ -13,6 +13,7 @@ that is a file's contents "#; #[rstest] +#[test_log::test] fn should_output_metadata_for_file(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -21,7 +22,7 @@ fn should_output_metadata_for_file(mut action_cmd: CtxCommand) { // distant action metadata {path} action_cmd - .args(&["metadata", file.to_str().unwrap()]) + .args(["metadata", file.to_str().unwrap()]) .assert() .success() .stdout(regex_pred(concat!( @@ -36,6 +37,7 @@ fn should_output_metadata_for_file(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn should_output_metadata_for_directory(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -44,7 +46,7 @@ fn should_output_metadata_for_directory(mut action_cmd: CtxCommand) { // distant action metadata {path} action_cmd - .args(&["metadata", dir.to_str().unwrap()]) + .args(["metadata", dir.to_str().unwrap()]) .assert() .success() .stdout(regex_pred(concat!( @@ -60,6 +62,7 @@ fn should_output_metadata_for_directory(mut action_cmd: CtxCommand) { // NOTE: Ignoring on windows because ssh2 doesn't properly canonicalize paths to resolve symlinks! 
#[rstest] +#[test_log::test] #[cfg_attr(windows, ignore)] fn should_support_including_a_canonicalized_path(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -72,7 +75,7 @@ fn should_support_including_a_canonicalized_path(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -102,7 +106,7 @@ fn should_support_resolving_file_type_of_symlink(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -125,7 +130,7 @@ fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { // distant action metadata {path} action_cmd - .args(&["metadata", file.to_str().unwrap()]) + .args(["metadata", file.to_str().unwrap()]) .assert() .code(1) .stdout("") diff --git a/tests/cli/action/proc_spawn.rs b/tests/cli/action/proc_spawn.rs index b87e9d3..e3a27d8 100644 --- a/tests/cli/action/proc_spawn.rs +++ b/tests/cli/action/proc_spawn.rs @@ -4,6 +4,7 @@ use rstest::*; use std::process::Command as StdCommand; #[rstest] +#[test_log::test] fn should_execute_program_and_return_exit_status(mut action_cmd: CtxCommand) { // Windows prints out a message whereas unix prints nothing #[cfg(windows)] @@ -13,7 +14,7 @@ fn should_execute_program_and_return_exit_status(mut action_cmd: CtxCommand) { // distant action proc-spawn {cmd} [args] action_cmd - .args(&["proc-spawn", "--"]) + .args(["proc-spawn", "--"]) .arg(SCRIPT_RUNNER.as_str()) .arg(SCRIPT_RUNNER_ARG.as_str()) .arg(ECHO_ARGS_TO_STDOUT.to_str().unwrap()) @@ -44,10 +46,11 @@ fn should_capture_and_print_stdout(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn should_capture_and_print_stderr(mut action_cmd: CtxCommand) { // distant action proc-spawn {cmd} [args] action_cmd - .args(&["proc-spawn", "--"]) + .args(["proc-spawn", "--"]) .arg(SCRIPT_RUNNER.as_str()) .arg(SCRIPT_RUNNER_ARG.as_str()) .arg(ECHO_ARGS_TO_STDERR.to_str().unwrap()) @@ -67,12 +70,13 @@ fn should_capture_and_print_stderr(mut action_cmd: CtxCommand) { // and then the process exiting. 
This may be a bug we've introduced with the // refactor and should be revisited some day. #[rstest] +#[test_log::test] fn should_forward_stdin_to_remote_process(mut action_std_cmd: CtxCommand) { use std::io::{BufRead, BufReader, Write}; // distant action proc-spawn {cmd} [args] let mut child = action_std_cmd - .args(&["proc-spawn", "--"]) + .args(["proc-spawn", "--"]) .arg(SCRIPT_RUNNER.as_str()) .arg(SCRIPT_RUNNER_ARG.as_str()) .arg(ECHO_STDIN_TO_STDOUT.to_str().unwrap()) @@ -106,6 +110,7 @@ fn should_forward_stdin_to_remote_process(mut action_std_cmd: CtxCommand) { // Windows prints out a message whereas unix prints nothing #[cfg(windows)] @@ -115,7 +120,7 @@ fn reflect_the_exit_code_of_the_process(mut action_cmd: CtxCommand) { // distant action proc-spawn {cmd} [args] action_cmd - .args(&["proc-spawn", "--"]) + .args(["proc-spawn", "--"]) .arg(SCRIPT_RUNNER.as_str()) .arg(SCRIPT_RUNNER_ARG.as_str()) .arg(EXIT_CODE.to_str().unwrap()) @@ -127,10 +132,11 @@ fn reflect_the_exit_code_of_the_process(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { // distant action proc-spawn {cmd} [args] action_cmd - .args(&["proc-spawn", "--"]) + .args(["proc-spawn", "--"]) .arg(DOES_NOT_EXIST_BIN.to_str().unwrap()) .assert() .code(1) diff --git a/tests/cli/action/remove.rs b/tests/cli/action/remove.rs index 681498c..6b3323e 100644 --- a/tests/cli/action/remove.rs +++ b/tests/cli/action/remove.rs @@ -5,6 +5,7 @@ use predicates::prelude::*; use rstest::*; #[rstest] +#[test_log::test] fn should_support_removing_file(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -12,7 +13,7 @@ fn should_support_removing_file(mut action_cmd: CtxCommand) { // distant action remove {path} action_cmd - .args(&["remove", file.to_str().unwrap()]) + .args(["remove", file.to_str().unwrap()]) .assert() .success() .stdout("") @@ -22,6 +23,7 @@ fn should_support_removing_file(mut 
action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn should_support_removing_empty_directory(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -31,7 +33,7 @@ fn should_support_removing_empty_directory(mut action_cmd: CtxCommand) // distant action remove {path} action_cmd - .args(&["remove", dir.to_str().unwrap()]) + .args(["remove", dir.to_str().unwrap()]) .assert() .success() .stdout("") @@ -41,6 +43,7 @@ fn should_support_removing_empty_directory(mut action_cmd: CtxCommand) } #[rstest] +#[test_log::test] fn should_support_removing_nonempty_directory_if_force_specified( mut action_cmd: CtxCommand, ) { @@ -53,7 +56,7 @@ fn should_support_removing_nonempty_directory_if_force_specified( // distant action remove --force {path} action_cmd - .args(&["remove", "--force", dir.to_str().unwrap()]) + .args(["remove", "--force", dir.to_str().unwrap()]) .assert() .success() .stdout("") @@ -63,6 +66,7 @@ fn should_support_removing_nonempty_directory_if_force_specified( } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -73,7 +77,7 @@ fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { // distant action remove {path} action_cmd - .args(&["remove", dir.to_str().unwrap()]) + .args(["remove", dir.to_str().unwrap()]) .assert() .code(1) .stdout("") diff --git a/tests/cli/action/rename.rs b/tests/cli/action/rename.rs index 2f57abb..1db1cfe 100644 --- a/tests/cli/action/rename.rs +++ b/tests/cli/action/rename.rs @@ -11,6 +11,7 @@ that is a file's contents "#; #[rstest] +#[test_log::test] fn should_support_renaming_file(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -21,7 +22,7 @@ fn should_support_renaming_file(mut action_cmd: CtxCommand) { // distant action rename {src} {dst} action_cmd - .args(&["rename", src.to_str().unwrap(), dst.to_str().unwrap()]) + .args(["rename", src.to_str().unwrap(), dst.to_str().unwrap()]) 
.assert() .success() .stdout("") @@ -32,6 +33,7 @@ fn should_support_renaming_file(mut action_cmd: CtxCommand) { } #[rstest] +#[test_log::test] fn should_support_renaming_nonempty_directory(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -46,7 +48,7 @@ fn should_support_renaming_nonempty_directory(mut action_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -68,7 +71,7 @@ fn yield_an_error_when_fails(mut action_cmd: CtxCommand) { // distant action rename {src} {dst} action_cmd - .args(&["rename", src.to_str().unwrap(), dst.to_str().unwrap()]) + .args(["rename", src.to_str().unwrap(), dst.to_str().unwrap()]) .assert() .code(1) .stdout("") diff --git a/tests/cli/action/search.rs b/tests/cli/action/search.rs index 7af972d..0b9c9a4 100644 --- a/tests/cli/action/search.rs +++ b/tests/cli/action/search.rs @@ -17,6 +17,7 @@ const SEARCH_RESULTS_REGEX: &str = indoc! {r" "}; #[rstest] +#[test_log::test] fn should_search_filesystem_using_query(mut action_cmd: CtxCommand) { let root = assert_fs::TempDir::new().unwrap(); root.child("file1.txt").write_str("some file text").unwrap(); diff --git a/tests/cli/action/system_info.rs b/tests/cli/action/system_info.rs index bf7c6ed..f82e8b1 100644 --- a/tests/cli/action/system_info.rs +++ b/tests/cli/action/system_info.rs @@ -4,6 +4,7 @@ use rstest::*; use std::env; #[rstest] +#[test_log::test] fn should_output_system_info(mut action_cmd: CtxCommand) { // distant action system-info action_cmd diff --git a/tests/cli/action/watch.rs b/tests/cli/action/watch.rs index d4df696..3fd8252 100644 --- a/tests/cli/action/watch.rs +++ b/tests/cli/action/watch.rs @@ -16,6 +16,7 @@ fn wait_millis(millis: u64) { } #[rstest] +#[test_log::test] fn should_support_watching_a_single_file(mut action_std_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -23,7 +24,7 @@ fn should_support_watching_a_single_file(mut action_std_cmd: CtxCommand // distant action 
watch {path} let mut child = action_std_cmd - .args(&["watch", file.to_str().unwrap()]) + .args(["watch", file.to_str().unwrap()]) .spawn() .expect("Failed to execute"); @@ -66,6 +67,7 @@ fn should_support_watching_a_single_file(mut action_std_cmd: CtxCommand } #[rstest] +#[test_log::test] fn should_support_watching_a_directory_recursively(mut action_std_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); @@ -77,7 +79,7 @@ fn should_support_watching_a_directory_recursively(mut action_std_cmd: CtxComman // distant action watch {path} let mut child = action_std_cmd - .args(&["watch", "--recursive", temp.to_str().unwrap()]) + .args(["watch", "--recursive", temp.to_str().unwrap()]) .spawn() .expect("Failed to execute"); @@ -120,13 +122,14 @@ fn should_support_watching_a_directory_recursively(mut action_std_cmd: CtxComman } #[rstest] +#[test_log::test] fn yield_an_error_when_fails(mut action_std_cmd: CtxCommand) { let temp = assert_fs::TempDir::new().unwrap(); let invalid_path = temp.to_path_buf().join("missing"); // distant action watch {path} let child = action_std_cmd - .args(&["watch", invalid_path.to_str().unwrap()]) + .args(["watch", invalid_path.to_str().unwrap()]) .spawn() .expect("Failed to execute"); diff --git a/tests/cli/fixtures.rs b/tests/cli/fixtures.rs index 1d4b278..779a25c 100644 --- a/tests/cli/fixtures.rs +++ b/tests/cli/fixtures.rs @@ -1,13 +1,16 @@ use assert_cmd::Command; use derive_more::{Deref, DerefMut}; +use distant_core::{net::common::Host, DistantSingleKeyCredentials}; use once_cell::sync::Lazy; use rstest::*; +use serde_json::json; use std::{ - io, + io::{BufReader, Read}, + net::{Ipv4Addr, Ipv6Addr}, path::PathBuf, process::{Child, Command as StdCommand, Stdio}, thread, - time::Duration, + time::{Duration, Instant}, }; mod repl; @@ -17,9 +20,8 @@ static ROOT_LOG_DIR: Lazy = Lazy::new(|| std::env::temp_dir().join("dis static SESSION_RANDOM: Lazy = Lazy::new(rand::random); const TIMEOUT: Duration = Duration::from_secs(3); -// 
Number of times to retry launching a server before giving up -const LAUNCH_RETRY_CNT: usize = 2; -const LAUNCH_RETRY_TIMEOUT: Duration = Duration::from_millis(250); +const MAX_RETRY_ATTEMPTS: usize = 3; +const RETRY_PAUSE_DURATION: Duration = Duration::from_millis(250); #[derive(Deref, DerefMut)] pub struct CtxCommand { @@ -33,6 +35,7 @@ pub struct CtxCommand { /// Context for some listening distant server pub struct DistantManagerCtx { manager: Child, + server: Child, socket_or_pipe: String, } @@ -50,7 +53,10 @@ impl DistantManagerCtx { .arg("--log-file") .arg(random_log_file("manager")) .arg("--log-level") - .arg("trace"); + .arg("trace") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); let socket_or_pipe = if cfg!(windows) { format!("distant_test_{}", rand::random::()) @@ -78,94 +84,155 @@ impl DistantManagerCtx { panic!("Manager exited ({}): {:?}", status.success(), status.code()); } - // Spawn a server locally by launching it through the manager - let mut launch_cmd = StdCommand::new(bin_path()); - launch_cmd - .arg("client") - .arg("launch") - .arg("--log-file") - .arg(random_log_file("launch")) - .arg("--log-level") - .arg("trace") - .arg("--distant") - .arg(bin_path()) - .arg("--distant-args") - .arg(format!( - "--log-file {} --log-level trace", - random_log_file("server").to_string_lossy() - )); - - if cfg!(windows) { - launch_cmd - .arg("--windows-pipe") - .arg(socket_or_pipe.as_str()); - } else { - launch_cmd.arg("--unix-socket").arg(socket_or_pipe.as_str()); - } - - launch_cmd.arg("manager://localhost"); - - for i in 0..=LAUNCH_RETRY_CNT { - eprintln!("[{i}/{LAUNCH_RETRY_CNT}] Spawning launch cmd: {launch_cmd:?}"); - let output = launch_cmd.output().expect("Failed to launch server"); - let success = output.status.success(); - if success { - break; + let mut server = None; + 'outer: for i in 1..=MAX_RETRY_ATTEMPTS { + let mut err = String::new(); + + // Spawn a server and capture the credentials so we can connect to it + let 
mut server_cmd = StdCommand::new(bin_path()); + let server_log_file = random_log_file("server"); + server_cmd + .arg("server") + .arg("listen") + .arg("--log-file") + .arg(&server_log_file) + .arg("--log-level") + .arg("trace") + .arg("--shutdown") + .arg("lonely=60") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + eprintln!("Spawning server cmd: {server_cmd:?}"); + server = match server_cmd.spawn() { + Ok(server) => Some(server), + Err(x) => { + eprintln!("--- SERVER LOG ---"); + eprintln!( + "{}", + std::fs::read_to_string(server_log_file.as_path()) + .unwrap_or_else(|_| format!("Unable to read: {server_log_file:?}")) + ); + eprintln!("------------------"); + if i == MAX_RETRY_ATTEMPTS { + panic!("Failed to spawn server: {x}"); + } else { + continue; + } + } + }; + + // Spawn a thread to read stdout to look for credentials + let stdout = server.as_mut().unwrap().stdout.take().unwrap(); + let stdout_thread = thread::spawn(move || { + let mut reader = BufReader::new(stdout); + let mut lines = String::new(); + let mut buf = [0u8; 1024]; + while let Ok(n) = reader.read(&mut buf) { + lines.push_str(&String::from_utf8_lossy(&buf[..n])); + if let Some(credentials) = DistantSingleKeyCredentials::find(&lines) { + return credentials; + } + } + panic!("Failed to read line"); + }); + + // Wait for thread to finish (up to 500ms) + let start = Instant::now(); + while !stdout_thread.is_finished() { + if start.elapsed() > Duration::from_millis(500) { + break; + } + thread::sleep(Duration::from_millis(50)); } - if !success && i == LAUNCH_RETRY_CNT { - let _ = manager.kill(); - panic!( - "Failed to launch: {}", + let mut credentials = match stdout_thread.join() { + Ok(credentials) => credentials, + Err(x) => { + if let Err(x) = server.as_mut().unwrap().kill() { + eprintln!("Encountered error, but failed to kill server: {x}"); + } + + if i == MAX_RETRY_ATTEMPTS { + panic!("Failed to retrieve credentials: {x:?}"); + } else { + eprintln!("Failed to 
retrieve credentials: {x:?}"); + continue; + } + } + }; + + for host in vec![ + Host::Ipv4(Ipv4Addr::LOCALHOST), + Host::Ipv6(Ipv6Addr::LOCALHOST), + Host::Name("localhost".to_string()), + ] { + credentials.host = host.clone(); + // Connect manager to server + let mut connect_cmd = StdCommand::new(bin_path()); + connect_cmd + .arg("client") + .arg("connect") + .arg("--log-file") + .arg(random_log_file("connect")) + .arg("--log-level") + .arg("trace") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + if cfg!(windows) { + connect_cmd + .arg("--windows-pipe") + .arg(socket_or_pipe.as_str()); + } else { + connect_cmd + .arg("--unix-socket") + .arg(socket_or_pipe.as_str()); + } + + connect_cmd.arg(credentials.to_string()); + + eprintln!("[{i}/{MAX_RETRY_ATTEMPTS}] Host: {host} | Spawning connect cmd: {connect_cmd:?}"); + let output = connect_cmd.output().expect("Failed to connect to server"); + + if output.status.success() { + break 'outer; + } + + err = format!( + "{err}\nConnecting to host {host} failed: {}", String::from_utf8_lossy(&output.stderr) ); } - thread::sleep(LAUNCH_RETRY_TIMEOUT); + if let Err(x) = server.as_mut().unwrap().kill() { + eprintln!("Failed to connect, and failed to kill server: {x}"); + } + + if i == MAX_RETRY_ATTEMPTS { + eprintln!("--- SERVER LOG ---"); + eprintln!( + "{}", + std::fs::read_to_string(server_log_file.as_path()) + .unwrap_or_else(|_| format!("Unable to read: {server_log_file:?}")) + ); + eprintln!("------------------"); + + panic!("Connecting to server failed: {err}"); + } else { + thread::sleep(RETRY_PAUSE_DURATION); + } } + eprintln!("Connected! 
Proceeding with test..."); Self { manager, + server: server.unwrap(), socket_or_pipe, } } - pub fn shutdown(&self) -> io::Result<()> { - // Send a shutdown request to the manager - let mut shutdown_cmd = StdCommand::new(bin_path()); - shutdown_cmd - .arg("manager") - .arg("shutdown") - .arg("--log-file") - .arg(random_log_file("shutdown")) - .arg("--log-level") - .arg("trace"); - - if cfg!(windows) { - shutdown_cmd - .arg("--windows-pipe") - .arg(self.socket_or_pipe.as_str()); - } else { - shutdown_cmd - .arg("--unix-socket") - .arg(self.socket_or_pipe.as_str()); - } - - eprintln!("Spawning shutdown cmd: {shutdown_cmd:?}"); - let output = shutdown_cmd.output().expect("Failed to shutdown server"); - if !output.status.success() { - Err(io::Error::new( - io::ErrorKind::Other, - format!( - "Failed to shutdown: {}", - String::from_utf8_lossy(&output.stderr) - ), - )) - } else { - Ok(()) - } - } - /// Produces a new test command that configures some distant command /// configured with an environment that can talk to a remote distant server pub fn new_assert_cmd(&self, subcommands: impl IntoIterator) -> Command { @@ -233,11 +300,10 @@ fn random_log_file(prefix: &str) -> PathBuf { impl Drop for DistantManagerCtx { /// Kills manager upon drop fn drop(&mut self) { - // Attempt to shutdown gracefully, forcing a kill otherwise - if self.shutdown().is_err() { - let _ = self.manager.kill(); - let _ = self.manager.wait(); - } + let _ = self.manager.kill(); + let _ = self.server.kill(); + let _ = self.manager.wait(); + let _ = self.server.wait(); } } @@ -273,5 +339,38 @@ pub fn json_repl(ctx: DistantManagerCtx) -> CtxCommand { .spawn() .expect("Failed to start distant repl with json format"); let cmd = Repl::new(child, TIMEOUT); + CtxCommand { ctx, cmd } } + +pub async fn validate_authentication(repl: &mut Repl) { + // NOTE: We have to handle receiving authentication messages, as we will get + // an authentication initialization of with method "none", and then + // a finish 
authentication status before we can do anything else. + let json = repl + .read_json_from_stdout() + .await + .unwrap() + .expect("Missing authentication initialization"); + assert_eq!( + json, + json!({"type": "auth_initialization", "methods": ["none"]}) + ); + + let json = repl + .write_and_read_json(json!({ + "type": "auth_initialization_response", + "methods": ["none"] + })) + .await + .unwrap() + .expect("Missing authentication method"); + assert_eq!(json, json!({"type": "auth_start_method", "method": "none"})); + + let json = repl + .read_json_from_stdout() + .await + .unwrap() + .expect("Missing authentication finalization"); + assert_eq!(json, json!({"type": "auth_finished"})); +} diff --git a/tests/cli/manager/capabilities.rs b/tests/cli/manager/capabilities.rs index 0c31f98..c42ad2d 100644 --- a/tests/cli/manager/capabilities.rs +++ b/tests/cli/manager/capabilities.rs @@ -6,6 +6,8 @@ const EXPECTED_TABLE: &str = indoc! {" +---------------+--------------------------------------------------------------+ | kind | description | +---------------+--------------------------------------------------------------+ +| authenticate | Supports authenticating with a remote server | ++---------------+--------------------------------------------------------------+ | capabilities | Supports retrieving capabilities | +---------------+--------------------------------------------------------------+ | channel | Supports sending data through a channel with a remote server | @@ -18,17 +20,16 @@ const EXPECTED_TABLE: &str = indoc! 
{" +---------------+--------------------------------------------------------------+ | kill | Supports killing a remote connection | +---------------+--------------------------------------------------------------+ -| launch | Supports launching distant on remote servers | +| launch | Supports launching a server on remote machines | +---------------+--------------------------------------------------------------+ | list | Supports retrieving a list of managed connections | +---------------+--------------------------------------------------------------+ | open_channel | Supports opening a channel with a remote server | +---------------+--------------------------------------------------------------+ -| shutdown | Supports being shut down on demand | -+---------------+--------------------------------------------------------------+ "}; #[rstest] +#[test_log::test] fn should_output_capabilities(ctx: DistantManagerCtx) { // distant action capabilities ctx.new_assert_cmd(vec!["manager", "capabilities"]) diff --git a/tests/cli/repl/capabilities.rs b/tests/cli/repl/capabilities.rs index c1db2d7..e379aec 100644 --- a/tests/cli/repl/capabilities.rs +++ b/tests/cli/repl/capabilities.rs @@ -2,10 +2,13 @@ use crate::cli::fixtures::*; use distant_core::data::{Capabilities, Capability}; use rstest::*; use serde_json::json; +use test_log::test; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_capabilities(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let id = rand::random::().to_string(); let req = json!({ "id": id, @@ -14,8 +17,8 @@ async fn should_support_json_capabilities(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "capabilities"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "capabilities", "JSON: {res}"); let supported: Capabilities = 
res["payload"]["supported"] .as_array() diff --git a/tests/cli/repl/copy.rs b/tests/cli/repl/copy.rs index a07de97..bcd205f 100644 --- a/tests/cli/repl/copy.rs +++ b/tests/cli/repl/copy.rs @@ -3,6 +3,7 @@ use assert_fs::prelude::*; use predicates::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; const FILE_CONTENTS: &str = r#" some text @@ -11,8 +12,10 @@ that is a file's contents "#; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_copying_file(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let src = temp.child("file"); @@ -32,12 +35,13 @@ async fn should_support_json_copying_file(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); src.assert(predicate::path::exists()); @@ -45,8 +49,10 @@ async fn should_support_json_copying_file(mut json_repl: CtxCommand) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_copying_nonempty_directory(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); // Make a non-empty directory @@ -70,12 +76,13 @@ async fn should_support_json_copying_nonempty_directory(mut json_repl: CtxComman let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); src_file.assert(predicate::path::exists()); @@ -83,8 +90,10 @@ async fn should_support_json_copying_nonempty_directory(mut json_repl: CtxComman } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + 
validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let src = temp.child("dir"); @@ -102,9 +111,9 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); src.assert(predicate::path::missing()); dst.assert(predicate::path::missing()); diff --git a/tests/cli/repl/dir_create.rs b/tests/cli/repl/dir_create.rs index 4432281..6a8fc52 100644 --- a/tests/cli/repl/dir_create.rs +++ b/tests/cli/repl/dir_create.rs @@ -3,10 +3,13 @@ use assert_fs::prelude::*; use predicates::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("dir"); @@ -22,12 +25,13 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); dir.assert(predicate::path::exists()); @@ -35,10 +39,12 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_creating_missing_parent_directories_if_specified( mut json_repl: CtxCommand, ) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("dir1").child("dir2"); @@ -54,12 +60,13 @@ async 
fn should_support_json_creating_missing_parent_directories_if_specified( let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); dir.assert(predicate::path::exists()); @@ -67,8 +74,10 @@ async fn should_support_json_creating_missing_parent_directories_if_specified( } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("missing-dir").child("dir"); @@ -84,9 +93,9 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); dir.assert(predicate::path::missing()); } diff --git a/tests/cli/repl/dir_read.rs b/tests/cli/repl/dir_read.rs index bfb80ec..a718768 100644 --- a/tests/cli/repl/dir_read.rs +++ b/tests/cli/repl/dir_read.rs @@ -3,6 +3,7 @@ use assert_fs::prelude::*; use rstest::*; use serde_json::json; use std::path::PathBuf; +use test_log::test; /// Creates a directory in the form /// @@ -69,8 +70,10 @@ fn make_directory() -> assert_fs::TempDir { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = make_directory(); let id = rand::random::().to_string(); @@ -88,7 +91,7 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { let res = 
json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ @@ -100,15 +103,18 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { {"path": PathBuf::from("file2"), "file_type": "file", "depth": 1}, ], "errors": [], - }) + }), + "JSON: {res}" ); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_returning_absolute_paths_if_specified( mut json_repl: CtxCommand, ) { + validate_authentication(&mut json_repl).await; + let temp = make_directory(); // NOTE: Our root path is always canonicalized, so the absolute path @@ -130,7 +136,7 @@ async fn should_support_json_returning_absolute_paths_if_specified( let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ @@ -142,15 +148,18 @@ async fn should_support_json_returning_absolute_paths_if_specified( {"path": root_path.join("file2"), "file_type": "file", "depth": 1}, ], "errors": [], - }) + }), + "JSON: {res}" ); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_returning_all_files_and_directories_if_depth_is_0( mut json_repl: CtxCommand, ) { + validate_authentication(&mut json_repl).await; + let temp = make_directory(); let id = rand::random::().to_string(); @@ -168,7 +177,7 @@ async fn should_support_json_returning_all_files_and_directories_if_depth_is_0( let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ @@ -190,15 +199,18 @@ async fn should_support_json_returning_all_files_and_directories_if_depth_is_0( {"path": PathBuf::from("file2"), "file_type": "file", "depth": 1}, ], "errors": [], - }) + }), + "JSON: {res}" ); } #[rstest] -#[tokio::test] 
+#[test(tokio::test)] async fn should_support_json_including_root_directory_if_specified( mut json_repl: CtxCommand, ) { + validate_authentication(&mut json_repl).await; + let temp = make_directory(); // NOTE: Our root path is always canonicalized, so yielded entry @@ -220,7 +232,7 @@ async fn should_support_json_including_root_directory_if_specified( let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ @@ -233,13 +245,16 @@ async fn should_support_json_including_root_directory_if_specified( {"path": PathBuf::from("file2"), "file_type": "file", "depth": 1}, ], "errors": [], - }) + }), + "JSON: {res}" ); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = make_directory(); let dir = temp.child("missing-dir"); @@ -258,7 +273,7 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); } diff --git a/tests/cli/repl/exists.rs b/tests/cli/repl/exists.rs index 7c4a434..80da487 100644 --- a/tests/cli/repl/exists.rs +++ b/tests/cli/repl/exists.rs @@ -2,10 +2,13 @@ use crate::cli::fixtures::*; use assert_fs::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_true_if_exists(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); // Create file @@ 
-23,19 +26,22 @@ async fn should_support_json_true_if_exists(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "exists", "value": true, - }) + }), + "JSON: {res}" ); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_false_if_not_exists(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); // Don't create file @@ -52,12 +58,13 @@ async fn should_support_json_false_if_not_exists(mut json_repl: CtxCommand let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "exists", "value": false, - }) + }), + "JSON: {res}" ); } diff --git a/tests/cli/repl/file_append.rs b/tests/cli/repl/file_append.rs index e73fbd5..6c7360b 100644 --- a/tests/cli/repl/file_append.rs +++ b/tests/cli/repl/file_append.rs @@ -2,6 +2,7 @@ use crate::cli::fixtures::*; use assert_fs::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; const FILE_CONTENTS: &str = r#" some text @@ -15,8 +16,10 @@ file contents "#; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); file.write_str(FILE_CONTENTS).unwrap(); @@ -33,12 +36,13 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); // NOTE: We wait a little bit to give the OS time to fully 
write to file @@ -49,8 +53,10 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); @@ -66,9 +72,9 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); // Because we're talking to a local server, we can verify locally file.assert(predicates::path::missing()); diff --git a/tests/cli/repl/file_append_text.rs b/tests/cli/repl/file_append_text.rs index 061c989..f2d2ac3 100644 --- a/tests/cli/repl/file_append_text.rs +++ b/tests/cli/repl/file_append_text.rs @@ -2,6 +2,7 @@ use crate::cli::fixtures::*; use assert_fs::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; const FILE_CONTENTS: &str = r#" some text @@ -15,8 +16,10 @@ file contents "#; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); file.write_str(FILE_CONTENTS).unwrap(); @@ -33,12 +36,13 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); // NOTE: We wait a 
little bit to give the OS time to fully write to file @@ -49,8 +53,10 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-dir").child("missing-file"); @@ -66,9 +72,9 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); // Because we're talking to a local server, we can verify locally file.assert(predicates::path::missing()); diff --git a/tests/cli/repl/file_read.rs b/tests/cli/repl/file_read.rs index 5938998..6504e93 100644 --- a/tests/cli/repl/file_read.rs +++ b/tests/cli/repl/file_read.rs @@ -2,6 +2,7 @@ use crate::cli::fixtures::*; use assert_fs::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; const FILE_CONTENTS: &str = r#" some text @@ -10,8 +11,10 @@ that is a file's contents "#; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); file.write_str(FILE_CONTENTS).unwrap(); @@ -27,19 +30,22 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "blob", "data": 
FILE_CONTENTS.as_bytes().to_vec() - }) + }), + "JSON: {res}" ); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-file"); @@ -54,7 +60,7 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); } diff --git a/tests/cli/repl/file_read_text.rs b/tests/cli/repl/file_read_text.rs index d1b8379..78bc13f 100644 --- a/tests/cli/repl/file_read_text.rs +++ b/tests/cli/repl/file_read_text.rs @@ -2,6 +2,7 @@ use crate::cli::fixtures::*; use assert_fs::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; const FILE_CONTENTS: &str = r#" some text @@ -10,8 +11,10 @@ that is a file's contents "#; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); file.write_str(FILE_CONTENTS).unwrap(); @@ -27,19 +30,22 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "text", "data": FILE_CONTENTS.to_string() - }) + }), + "JSON: {res}" ); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + 
validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("missing-file"); @@ -54,7 +60,7 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); } diff --git a/tests/cli/repl/file_write.rs b/tests/cli/repl/file_write.rs index 21ef333..07eafd0 100644 --- a/tests/cli/repl/file_write.rs +++ b/tests/cli/repl/file_write.rs @@ -2,6 +2,7 @@ use crate::cli::fixtures::*; use assert_fs::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; const FILE_CONTENTS: &str = r#" some text @@ -10,8 +11,10 @@ that is a file's contents "#; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); @@ -27,12 +30,13 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); // NOTE: We wait a little bit to give the OS time to fully write to file @@ -43,8 +47,10 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = 
temp.child("missing-dir").child("missing-file"); @@ -60,9 +66,9 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); // Because we're talking to a local server, we can verify locally file.assert(predicates::path::missing()); diff --git a/tests/cli/repl/file_write_text.rs b/tests/cli/repl/file_write_text.rs index 00c50cc..8c75bc3 100644 --- a/tests/cli/repl/file_write_text.rs +++ b/tests/cli/repl/file_write_text.rs @@ -2,6 +2,7 @@ use crate::cli::fixtures::*; use assert_fs::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; const FILE_CONTENTS: &str = r#" some text @@ -10,8 +11,10 @@ that is a file's contents "#; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("test-file"); @@ -27,12 +30,13 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); // NOTE: We wait a little bit to give the OS time to fully write to file @@ -43,8 +47,10 @@ async fn should_support_json_output(mut json_repl: CtxCommand) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let 
file = temp.child("missing-dir").child("missing-file"); @@ -60,9 +66,9 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); // Because we're talking to a local server, we can verify locally file.assert(predicates::path::missing()); diff --git a/tests/cli/repl/metadata.rs b/tests/cli/repl/metadata.rs index 81680bc..6fcb61c 100644 --- a/tests/cli/repl/metadata.rs +++ b/tests/cli/repl/metadata.rs @@ -2,6 +2,7 @@ use crate::cli::fixtures::*; use assert_fs::prelude::*; use rstest::*; use serde_json::{json, Value}; +use test_log::test; const FILE_CONTENTS: &str = r#" some text @@ -10,8 +11,10 @@ that is a file's contents "#; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_metadata_for_file(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -30,16 +33,22 @@ async fn should_support_json_metadata_for_file(mut json_repl: CtxCommand) let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "metadata"); - assert_eq!(res["payload"]["canonicalized_path"], Value::Null); - assert_eq!(res["payload"]["file_type"], "file"); - assert_eq!(res["payload"]["readonly"], false); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "metadata", "JSON: {res}"); + assert_eq!( + res["payload"]["canonicalized_path"], + Value::Null, + "JSON: {res}" + ); + assert_eq!(res["payload"]["file_type"], "file", "JSON: {res}"); + 
assert_eq!(res["payload"]["readonly"], false, "JSON: {res}"); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_metadata_for_directory(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("dir"); @@ -58,18 +67,24 @@ async fn should_support_json_metadata_for_directory(mut json_repl: CtxCommand, ) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -91,21 +106,24 @@ async fn should_support_json_metadata_for_including_a_canonicalized_path( let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "metadata"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "metadata", "JSON: {res}"); assert_eq!( res["payload"]["canonicalized_path"], - json!(file.path().canonicalize().unwrap()) + json!(file.path().canonicalize().unwrap()), + "JSON: {res}" ); - assert_eq!(res["payload"]["file_type"], "symlink"); - assert_eq!(res["payload"]["readonly"], false); + assert_eq!(res["payload"]["file_type"], "symlink", "JSON: {res}"); + assert_eq!(res["payload"]["readonly"], false, "JSON: {res}"); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_metadata_for_resolving_file_type_of_symlink( mut json_repl: CtxCommand, ) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -127,14 +145,16 @@ async fn should_support_json_metadata_for_resolving_file_type_of_symlink( let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "metadata"); - assert_eq!(res["payload"]["file_type"], "file"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], 
"metadata", "JSON: {res}"); + assert_eq!(res["payload"]["file_type"], "file", "JSON: {res}"); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); // Don't create file @@ -153,7 +173,7 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); } diff --git a/tests/cli/repl/proc_spawn.rs b/tests/cli/repl/proc_spawn.rs index 7d7a240..4934629 100644 --- a/tests/cli/repl/proc_spawn.rs +++ b/tests/cli/repl/proc_spawn.rs @@ -1,6 +1,7 @@ use crate::cli::{fixtures::*, scripts::*}; use rstest::*; use serde_json::json; +use test_log::test; fn make_cmd(args: Vec<&str>) -> String { format!( @@ -69,10 +70,12 @@ fn check_value_as_str(value: &serde_json::Value, other: &str) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_to_execute_program_and_return_exit_status( mut json_repl: CtxCommand, ) { + validate_authentication(&mut json_repl).await; + let cmd = make_cmd(vec![ECHO_ARGS_TO_STDOUT.to_str().unwrap()]); let id = rand::random::().to_string(); @@ -88,13 +91,15 @@ async fn should_support_json_to_execute_program_and_return_exit_status( let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "proc_spawned"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "proc_spawned", "JSON: {res}"); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn 
should_support_json_to_capture_and_print_stdout(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let cmd = make_cmd(vec![ECHO_ARGS_TO_STDOUT.to_str().unwrap(), "some output"]); // Spawn the process @@ -111,27 +116,29 @@ async fn should_support_json_to_capture_and_print_stdout(mut json_repl: CtxComma let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], origin_id); - assert_eq!(res["payload"]["type"], "proc_spawned"); + assert_eq!(res["origin_id"], origin_id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "proc_spawned", "JSON: {res}"); // Wait for output to show up (for stderr) let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], origin_id); - assert_eq!(res["payload"]["type"], "proc_stdout"); + assert_eq!(res["origin_id"], origin_id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "proc_stdout", "JSON: {res}"); check_value_as_str(&res["payload"]["data"], "some output"); // Now we wait for the process to complete let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], origin_id); - assert_eq!(res["payload"]["type"], "proc_done"); - assert_eq!(res["payload"]["success"], true); + assert_eq!(res["origin_id"], origin_id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "proc_done", "JSON: {res}"); + assert_eq!(res["payload"]["success"], true, "JSON: {res}"); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_to_capture_and_print_stderr(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let cmd = make_cmd(vec![ECHO_ARGS_TO_STDERR.to_str().unwrap(), "some output"]); // Spawn the process @@ -148,27 +155,29 @@ async fn should_support_json_to_capture_and_print_stderr(mut json_repl: CtxComma let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], origin_id); - assert_eq!(res["payload"]["type"], 
"proc_spawned"); + assert_eq!(res["origin_id"], origin_id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "proc_spawned", "JSON: {res}"); // Wait for output to show up (for stderr) let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], origin_id); - assert_eq!(res["payload"]["type"], "proc_stderr"); + assert_eq!(res["origin_id"], origin_id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "proc_stderr", "JSON: {res}"); check_value_as_str(&res["payload"]["data"], "some output"); // Now we wait for the process to complete let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], origin_id); - assert_eq!(res["payload"]["type"], "proc_done"); - assert_eq!(res["payload"]["success"], true); + assert_eq!(res["origin_id"], origin_id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "proc_done", "JSON: {res}"); + assert_eq!(res["payload"]["success"], true, "JSON: {res}"); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_to_forward_stdin_to_remote_process(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let cmd = make_cmd(vec![ECHO_STDIN_TO_STDOUT.to_str().unwrap()]); // Spawn the process @@ -185,8 +194,8 @@ async fn should_support_json_to_forward_stdin_to_remote_process(mut json_repl: C let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], origin_id); - assert_eq!(res["payload"]["type"], "proc_spawned"); + assert_eq!(res["origin_id"], origin_id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "proc_spawned", "JSON: {res}"); // Write output to stdin of process to trigger getting it back as stdout let proc_id = res["payload"]["id"] @@ -204,13 +213,13 @@ async fn should_support_json_to_forward_stdin_to_remote_process(mut json_repl: C let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - 
assert_eq!(res["payload"]["type"], "ok"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "ok", "JSON: {res}"); let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], origin_id); - assert_eq!(res["payload"]["type"], "proc_stdout"); + assert_eq!(res["origin_id"], origin_id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "proc_stdout", "JSON: {res}"); check_value_as_str(&res["payload"]["data"], "some output"); // Now kill the process and wait for it to complete @@ -236,15 +245,15 @@ async fn should_support_json_to_forward_stdin_to_remote_process(mut json_repl: C res_1["payload"]["type"] == "proc_done" || res_2["payload"]["type"] == "proc_done"; if res_1["payload"]["type"] == "ok" { - assert_eq!(res_1["origin_id"], id); + assert_eq!(res_1["origin_id"], id, "JSON: {res_1}"); } else if res_1["payload"]["type"] == "proc_done" { - assert_eq!(res_1["origin_id"], origin_id); + assert_eq!(res_1["origin_id"], origin_id, "JSON: {res_1}"); } if res_2["payload"]["type"] == "ok" { - assert_eq!(res_2["origin_id"], id); + assert_eq!(res_2["origin_id"], id, "JSON: {res_2}"); } else if res_2["payload"]["type"] == "proc_done" { - assert_eq!(res_2["origin_id"], origin_id); + assert_eq!(res_2["origin_id"], origin_id, "JSON: {res_2}"); } assert!(got_ok, "Did not receive ok from proc_kill"); @@ -252,8 +261,10 @@ async fn should_support_json_to_forward_stdin_to_remote_process(mut json_repl: C } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let id = rand::random::().to_string(); let req = json!({ "id": id, @@ -267,7 +278,7 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - 
assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); } diff --git a/tests/cli/repl/remove.rs b/tests/cli/repl/remove.rs index 9989622..ddfa17f 100644 --- a/tests/cli/repl/remove.rs +++ b/tests/cli/repl/remove.rs @@ -3,10 +3,13 @@ use assert_fs::prelude::*; use predicates::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_removing_file(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -24,20 +27,23 @@ async fn should_support_json_removing_file(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); file.assert(predicate::path::missing()); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_removing_empty_directory(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); // Make an empty directory @@ -56,22 +62,25 @@ async fn should_support_json_removing_empty_directory(mut json_repl: CtxCommand< let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); dir.assert(predicate::path::missing()); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_removing_nonempty_directory_if_force_specified( mut json_repl: CtxCommand, ) { + validate_authentication(&mut json_repl).await; + let temp = 
assert_fs::TempDir::new().unwrap(); // Make an empty directory @@ -90,20 +99,23 @@ async fn should_support_json_removing_nonempty_directory_if_force_specified( let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); dir.assert(predicate::path::missing()); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); // Make a non-empty directory so we fail to remove it @@ -123,11 +135,11 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); assert!( res["payload"]["kind"] == "other" || res["payload"]["kind"] == "unknown", - "error kind was neither other or unknown" + "error kind was neither other or unknown; JSON: {res}" ); dir.assert(predicate::path::exists()); diff --git a/tests/cli/repl/rename.rs b/tests/cli/repl/rename.rs index e1c50f9..994bf94 100644 --- a/tests/cli/repl/rename.rs +++ b/tests/cli/repl/rename.rs @@ -3,6 +3,7 @@ use assert_fs::prelude::*; use predicates::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; const FILE_CONTENTS: &str = r#" some text @@ -11,8 +12,10 @@ that is a file's contents "#; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_renaming_file(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let src = temp.child("file"); @@ -32,12 +35,13 @@ async fn should_support_json_renaming_file(mut 
json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); src.assert(predicate::path::missing()); @@ -45,8 +49,10 @@ async fn should_support_json_renaming_file(mut json_repl: CtxCommand) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_renaming_nonempty_directory(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); // Make a non-empty directory @@ -70,12 +76,13 @@ async fn should_support_json_renaming_nonempty_directory(mut json_repl: CtxComma let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); src.assert(predicate::path::missing()); @@ -86,8 +93,10 @@ async fn should_support_json_renaming_nonempty_directory(mut json_repl: CtxComma } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let src = temp.child("dir"); @@ -105,9 +114,9 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); src.assert(predicate::path::missing()); dst.assert(predicate::path::missing()); diff --git a/tests/cli/repl/search.rs 
b/tests/cli/repl/search.rs index 3a6f2bb..47adc3d 100644 --- a/tests/cli/repl/search.rs +++ b/tests/cli/repl/search.rs @@ -2,10 +2,13 @@ use crate::cli::fixtures::*; use assert_fs::prelude::*; use rstest::*; use serde_json::json; +use test_log::test; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_search_filesystem_using_query(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let root = assert_fs::TempDir::new().unwrap(); root.child("file1.txt").write_str("some file text").unwrap(); root.child("file2.txt") @@ -30,15 +33,15 @@ async fn should_support_json_search_filesystem_using_query(mut json_repl: CtxCom let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); // Get id from started confirmation - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "search_started"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "search_started", "JSON: {res}"); let search_id = res["payload"]["id"] .as_u64() .expect("id missing or not number"); // Get search results back let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ @@ -66,17 +69,19 @@ async fn should_support_json_search_filesystem_using_query(mut json_repl: CtxCom ], }, ] - }) + }), + "JSON: {res}" ); // Get search completion confirmation let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "search_done", "id": search_id, - }) + }), + "JSON: {res}" ); } diff --git a/tests/cli/repl/system_info.rs b/tests/cli/repl/system_info.rs index 2dbf417..4e135b8 100644 --- a/tests/cli/repl/system_info.rs +++ b/tests/cli/repl/system_info.rs @@ -2,10 +2,13 @@ use crate::cli::fixtures::*; use rstest::*; use serde_json::json; 
use std::env; +use test_log::test; #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_system_info(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let id = rand::random::().to_string(); let req = json!({ "id": id, @@ -14,7 +17,7 @@ async fn should_support_json_system_info(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ @@ -30,6 +33,7 @@ async fn should_support_json_system_info(mut json_repl: CtxCommand) { } else { std::env::var("SHELL").unwrap_or_else(|_| String::from("/bin/sh")) } - }) + }), + "JSON: {res}" ); } diff --git a/tests/cli/repl/watch.rs b/tests/cli/repl/watch.rs index ce4e7df..4048d72 100644 --- a/tests/cli/repl/watch.rs +++ b/tests/cli/repl/watch.rs @@ -3,6 +3,7 @@ use assert_fs::prelude::*; use rstest::*; use serde_json::json; use std::time::Duration; +use test_log::test; async fn wait_a_bit() { wait_millis(250).await; @@ -17,8 +18,10 @@ async fn wait_millis(millis: u64) { } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_watching_single_file(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file = temp.child("file"); @@ -36,12 +39,13 @@ async fn should_support_json_watching_single_file(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let dir = temp.child("dir"); @@ -86,12 +93,13 @@ async fn should_support_json_watching_directory_recursively(mut json_repl: CtxCo let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); + assert_eq!(res["origin_id"], id, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); // Make a change to some file @@ -106,11 
+114,12 @@ async fn should_support_json_watching_directory_recursively(mut json_repl: CtxCo // NOTE: Don't bother checking the kind as it can vary by platform let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "changed"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "changed", "JSON: {res}"); assert_eq!( res["payload"]["paths"], - json!([dir.to_path_buf().canonicalize().unwrap()]) + json!([dir.to_path_buf().canonicalize().unwrap()]), + "JSON: {res}" ); } @@ -121,19 +130,22 @@ async fn should_support_json_watching_directory_recursively(mut json_repl: CtxCo // NOTE: Don't bother checking the kind as it can vary by platform let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "changed"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "changed", "JSON: {res}"); assert_eq!( res["payload"]["paths"], - json!([file.to_path_buf().canonicalize().unwrap()]) + json!([file.to_path_buf().canonicalize().unwrap()]), + "JSON: {res}" ); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_reporting_changes_using_correct_request_id( mut json_repl: CtxCommand, ) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let file1 = temp.child("file1"); @@ -154,12 +166,13 @@ async fn should_support_json_reporting_changes_using_correct_request_id( let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id_1); + assert_eq!(res["origin_id"], id_1, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); // Watch file2 for changes @@ -174,12 +187,13 @@ async fn should_support_json_reporting_changes_using_correct_request_id( let res = 
json_repl.write_and_read_json(req).await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id_2); + assert_eq!(res["origin_id"], id_2, "JSON: {res}"); assert_eq!( res["payload"], json!({ "type": "ok" - }) + }), + "JSON: {res}" ); // Make a change to file1 @@ -192,11 +206,12 @@ async fn should_support_json_reporting_changes_using_correct_request_id( // NOTE: Don't bother checking the kind as it can vary by platform let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id_1); - assert_eq!(res["payload"]["type"], "changed"); + assert_eq!(res["origin_id"], id_1, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "changed", "JSON: {res}"); assert_eq!( res["payload"]["paths"], - json!([file1.to_path_buf().canonicalize().unwrap()]) + json!([file1.to_path_buf().canonicalize().unwrap()]), + "JSON: {res}" ); // Process any extra messages (we might get create, content, and more) @@ -223,17 +238,20 @@ async fn should_support_json_reporting_changes_using_correct_request_id( // NOTE: Don't bother checking the kind as it can vary by platform let res = json_repl.read_json_from_stdout().await.unwrap().unwrap(); - assert_eq!(res["origin_id"], id_2); - assert_eq!(res["payload"]["type"], "changed"); + assert_eq!(res["origin_id"], id_2, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "changed", "JSON: {res}"); assert_eq!( res["payload"]["paths"], - json!([file2.to_path_buf().canonicalize().unwrap()]) + json!([file2.to_path_buf().canonicalize().unwrap()]), + "JSON: {res}" ); } #[rstest] -#[tokio::test] +#[test(tokio::test)] async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { + validate_authentication(&mut json_repl).await; + let temp = assert_fs::TempDir::new().unwrap(); let path = temp.to_path_buf().join("missing"); @@ -250,7 +268,7 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand) { let res = json_repl.write_and_read_json(req).await.unwrap().unwrap(); // Ensure we got an 
acknowledgement of watching that failed - assert_eq!(res["origin_id"], id); - assert_eq!(res["payload"]["type"], "error"); - assert_eq!(res["payload"]["kind"], "not_found"); + assert_eq!(res["origin_id"], id, "JSON: {res}"); + assert_eq!(res["payload"]["type"], "error", "JSON: {res}"); + assert_eq!(res["payload"]["kind"], "not_found", "JSON: {res}"); }