mirror of https://github.com/chipsenkbeil/distant
Compare commits
126 Commits
Author | SHA1 | Date |
---|---|---|
Chip Senkbeil | 3fe1fba339 | 1 year ago |
Chip Senkbeil | 48f7eb74ec | 1 year ago |
Chip Senkbeil | 96abcefdc5 | 1 year ago |
Chip Senkbeil | 22f3c2dd76 | 1 year ago |
Chip Senkbeil | 0320e7fe24 | 1 year ago |
Chip Senkbeil | 9e48300e83 | 1 year ago |
Chip Senkbeil | e304e6a689 | 1 year ago |
Chip Senkbeil | 8972013716 | 1 year ago |
Chip Senkbeil | 0efb5aee4c | 1 year ago |
Chip Senkbeil | 56b3b8f4f1 | 1 year ago |
Chip Senkbeil | eb23b4e1ad | 1 year ago |
Chip Senkbeil | dc7e9b5309 | 1 year ago |
Chip Senkbeil | e0b8769087 | 1 year ago |
Chip Senkbeil | 9bc50886bb | 1 year ago |
Chip Senkbeil | bd3b068651 | 1 year ago |
Chip Senkbeil | c61393750a | 1 year ago |
Chip Senkbeil | 2abaf0b814 | 1 year ago |
Chip Senkbeil | 0e03fc3011 | 1 year ago |
Chip Senkbeil | cb8ea0507f | 1 year ago |
Chip Senkbeil | 8a34fec1f7 | 1 year ago |
Chip Senkbeil | 6feeb2d012 | 1 year ago |
Chip Senkbeil | fefbe19a3c | 1 year ago |
Chip Senkbeil | be7a15caa0 | 1 year ago |
Chip Senkbeil | 84ea28402d | 1 year ago |
Chip Senkbeil | b74cba28df | 1 year ago |
Chip Senkbeil | f4180f6245 | 1 year ago |
Chip Senkbeil | c250acdfb4 | 1 year ago |
Chip Senkbeil | 1836f20a2a | 1 year ago |
Chip Senkbeil | 9096a7d81b | 1 year ago |
Chip Senkbeil | 7c08495904 | 1 year ago |
Chip Senkbeil | da75801639 | 1 year ago |
Nagy Botond | 8009cc9361 | 1 year ago |
Chip Senkbeil | 4fb9045152 | 1 year ago |
Chip Senkbeil | efad345a0d | 1 year ago |
Chip Senkbeil | 6ba3ded188 | 1 year ago |
Chip Senkbeil | c4c46f80a9 | 1 year ago |
Chip Senkbeil | 791a41c29e | 1 year ago |
Chip Senkbeil | a36263e7e1 | 1 year ago |
Chip Senkbeil | 6f98e44723 | 1 year ago |
Chip Senkbeil | 72cc998595 | 1 year ago |
Chip Senkbeil | 4eaae55d53 | 1 year ago |
Chip Senkbeil | 9da7679081 | 1 year ago |
Chip Senkbeil | 009996b554 | 1 year ago |
Chip Senkbeil | b163094d49 | 1 year ago |
Chip Senkbeil | 3225471e28 | 1 year ago |
Chip Senkbeil | 9f345eb31b | 1 year ago |
Chip Senkbeil | e99329d9a9 | 1 year ago |
Chip Senkbeil | 40c265e35b | 1 year ago |
Chip Senkbeil | af903013f6 | 1 year ago |
Chip Senkbeil | 76dc7cf1fa | 1 year ago |
Chip Senkbeil | 95c0d0c0d1 | 1 year ago |
Chip Senkbeil | 528dea0917 | 1 year ago |
Chip Senkbeil | 8cf7f11269 | 1 year ago |
Chip Senkbeil | 2042684c97 | 1 year ago |
Chip Senkbeil | 31aff1e282 | 1 year ago |
Chip Senkbeil | ea0424e2f4 | 1 year ago |
Chip Senkbeil | 137b4dc289 | 1 year ago |
Chip Senkbeil | 3208fdcaa2 | 1 year ago |
Chip Senkbeil | 8768106c67 | 1 year ago |
Chip Senkbeil | b3e0f651d5 | 1 year ago |
Chip Senkbeil | f2bd2f15f5 | 1 year ago |
Chip Senkbeil | 398aff2f12 | 1 year ago |
Chip Senkbeil | 7fceb63aa3 | 1 year ago |
Chip Senkbeil | 5740c2cc4d | 1 year ago |
Chip Senkbeil | b8fecaacc0 | 1 year ago |
Chip Senkbeil | 5b19870b98 | 1 year ago |
Chip Senkbeil | bbf74f1e71 | 1 year ago |
Chip Senkbeil | c989a851ce | 1 year ago |
Chip Senkbeil | 09e8442892 | 1 year ago |
Chip Senkbeil | 4b983b0229 | 1 year ago |
Chip Senkbeil | 093b4d2ec4 | 2 years ago |
Chip Senkbeil | cfee78c2da | 2 years ago |
Chip Senkbeil | d44df53e83 | 2 years ago |
Chip Senkbeil | 90305607e9 | 2 years ago |
Chip Senkbeil | 2b6bf3c0a8 | 2 years ago |
Chip Senkbeil | 656a8007d6 | 2 years ago |
Chip Senkbeil | 8853d1072a | 2 years ago |
Chip Senkbeil | 2ab41c4976 | 2 years ago |
Chip Senkbeil | 5940b21339 | 2 years ago |
Chip Senkbeil | 78b0ee628e | 2 years ago |
Chip Senkbeil | 40bd20e4ac | 2 years ago |
Chip Senkbeil | 55036478a0 | 2 years ago |
Chip Senkbeil | 27dc5775f9 | 2 years ago |
Chip Senkbeil | 9b2f0de0c5 | 2 years ago |
Chip Senkbeil | a023b8f22d | 2 years ago |
Chip Senkbeil | ee50eaf9b3 | 2 years ago |
Chip Senkbeil | ee595551ae | 2 years ago |
Chip Senkbeil | a544587bab | 2 years ago |
Chip Senkbeil | 1c393ef723 | 2 years ago |
Chip Senkbeil | a41ef5996e | 2 years ago |
Chip Senkbeil | 10141f2090 | 2 years ago |
Chip Senkbeil | e13ec37603 | 2 years ago |
Chip Senkbeil | 8f3b204474 | 2 years ago |
Chip Senkbeil | 3a4b98cdde | 2 years ago |
Chip Senkbeil | bc3f6eef04 | 2 years ago |
Chip Senkbeil | 65fdbe8650 | 2 years ago |
Chip Senkbeil | 9ef32fe811 | 2 years ago |
Chip Senkbeil | 4798b67dfe | 2 years ago |
Chip Senkbeil | 7d1b3ba6f0 | 2 years ago |
Chip Senkbeil | 4cf869ecb7 | 2 years ago |
Chip Senkbeil | a8107aed3a | 2 years ago |
Chip Senkbeil | 193bb6d237 | 2 years ago |
Chip Senkbeil | dac318eb1e | 2 years ago |
Chip Senkbeil | cae6c5e244 | 2 years ago |
Chip Senkbeil | 01610a3ac7 | 2 years ago |
Chip Senkbeil | 5130ee3b5f | 2 years ago |
Chip Senkbeil | 53fd8d0c4f | 2 years ago |
Chip Senkbeil | c19df9f538 | 2 years ago |
Chip Senkbeil | 1fa3a8acea | 2 years ago |
Chip Senkbeil | b9c00153a0 | 2 years ago |
Chip Senkbeil | 22b2a351de | 2 years ago |
Chip Senkbeil | 591cd6ff41 | 2 years ago |
Chip Senkbeil | 6d0bbd56fc | 2 years ago |
Chip Senkbeil | 56a030e6dd | 2 years ago |
Chip Senkbeil | 486e5399ff | 2 years ago |
Chip Senkbeil | 4011671a77 | 2 years ago |
Chip Senkbeil | 04b20d1348 | 2 years ago |
Chip Senkbeil | 6c4318baa0 | 2 years ago |
Chip Senkbeil | ec95f573b9 | 2 years ago |
Chip Senkbeil | 30548cdbfb | 2 years ago |
Chip Senkbeil | fd325e4523 | 2 years ago |
Chip Senkbeil | 2cdfb89751 | 2 years ago |
Chip Senkbeil | 74a37209eb | 2 years ago |
Chip Senkbeil | 8e8eb8c574 | 2 years ago |
Chip Senkbeil | 1ff3ef2db1 | 2 years ago |
Chip Senkbeil | a0c7c492bd | 2 years ago |
@ -0,0 +1,5 @@
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
linker = "aarch64-linux-gnu-gcc"
|
||||
|
||||
[target.armv7-unknown-linux-gnueabihf]
|
||||
linker = "arm-linux-gnueabihf-gcc"
|
@ -1,4 +1,6 @@
|
||||
[profile.ci]
|
||||
fail-fast = false
|
||||
retries = 2
|
||||
retries = 4
|
||||
slow-timeout = { period = "60s", terminate-after = 3 }
|
||||
status-level = "fail"
|
||||
final-status-level = "fail"
|
||||
|
@ -0,0 +1,24 @@
|
||||
name: 'Tag latest'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
action:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Tag latest and push
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
|
||||
run: |
|
||||
git config user.name "${GITHUB_ACTOR}"
|
||||
git config user.email "${GITHUB_ACTOR}@users.noreply.github.com"
|
||||
|
||||
origin_url="$(git config --get remote.origin.url)"
|
||||
origin_url="${origin_url/#https:\/\//https:\/\/$GITHUB_TOKEN@}" # add token to URL
|
||||
|
||||
git tag latest --force
|
||||
git push "$origin_url" --tags --force
|
@ -0,0 +1,28 @@
|
||||
name: 'Lock Threads'
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 3 * * *'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
concurrency:
|
||||
group: lock
|
||||
|
||||
jobs:
|
||||
action:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: dessant/lock-threads@v4
|
||||
with:
|
||||
issue-inactive-days: '30'
|
||||
issue-comment: >
|
||||
I'm going to lock this issue because it has been closed for _30 days_ ⏳.
|
||||
This helps our maintainers find and focus on the active issues.
|
||||
If you have found a problem that seems similar to this, please open a new
|
||||
issue and complete the issue template so we can capture all the details
|
||||
necessary to investigate further.
|
||||
process-only: 'issues'
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,44 @@
|
||||
[tasks.format]
|
||||
clear = true
|
||||
install_crate = "rustfmt-nightly"
|
||||
command = "cargo"
|
||||
args = ["+nightly", "fmt", "--all"]
|
||||
|
||||
[tasks.test]
|
||||
clear = true
|
||||
command = "cargo"
|
||||
args = ["test", "--release", "--all-features", "--workspace"]
|
||||
|
||||
[tasks.ci-test]
|
||||
clear = true
|
||||
command = "cargo"
|
||||
args = ["nextest", "run", "--profile", "ci", "--release", "--all-features", "--workspace"]
|
||||
|
||||
[tasks.post-ci-test]
|
||||
clear = true
|
||||
command = "cargo"
|
||||
args = ["test", "--release", "--all-features", "--workspace", "--doc"]
|
||||
|
||||
[tasks.publish]
|
||||
clear = true
|
||||
script = '''
|
||||
cargo publish --all-features -p distant-auth
|
||||
cargo publish --all-features -p distant-protocol
|
||||
cargo publish --all-features -p distant-net
|
||||
cargo publish --all-features -p distant-core
|
||||
cargo publish --all-features -p distant-local
|
||||
cargo publish --all-features -p distant-ssh2
|
||||
cargo publish --all-features
|
||||
'''
|
||||
|
||||
[tasks.dry-run-publish]
|
||||
clear = true
|
||||
script = '''
|
||||
cargo publish --all-features --dry-run -p distant-auth
|
||||
cargo publish --all-features --dry-run -p distant-protocol
|
||||
cargo publish --all-features --dry-run -p distant-net
|
||||
cargo publish --all-features --dry-run -p distant-core
|
||||
cargo publish --all-features --dry-run -p distant-local
|
||||
cargo publish --all-features --dry-run -p distant-ssh2
|
||||
cargo publish --all-features --dry-run
|
||||
'''
|
@ -0,0 +1,58 @@
|
||||
# Publish
|
||||
|
||||
Guide to publishing the binary and associated crates.
|
||||
|
||||
## 1. Update Changelog
|
||||
|
||||
Ensure that the changelog is updated for a new release. The CI build requires
|
||||
that the release version is specified in the format: `[VERSION] - DATE`.
|
||||
|
||||
1. Update the changelog by changing `[Unreleased]` to the latest version and
|
||||
date.
|
||||
2. Re-add a new `[Unreleased]` header at the top.
|
||||
3. At the bottom, add a new link for the current version.
|
||||
4. Update the `[Unreleased]` link with the latest tag.
|
||||
|
||||
## 2. Update READMEs
|
||||
|
||||
Each crate README has a reference to installing a specific version and needs to
|
||||
be updated.
|
||||
|
||||
e.g. Open `distant-core/README.md` and replace `0.17` with `0.18` if applicable
|
||||
|
||||
## 3. Update Crate Versions
|
||||
|
||||
Run a command to update the crate versions. An easy way is to use `sed`.
|
||||
|
||||
On Mac, this would be `sed -i '' "s~0.17.4~0.17.5~g" **/*.toml` where the old
|
||||
and new versions would be specified.
|
||||
|
||||
*Make sure to review the changed files! Sometimes a version overlaps with
|
||||
another crate and then we've bumped something wrong!*
|
||||
|
||||
## 4. Build to get Cargo.lock update
|
||||
|
||||
Run `cargo build` to get a new `Cargo.lock` refresh and commit it.
|
||||
|
||||
## 5. Tag Commit
|
||||
|
||||
Tag the release commit with the form `vMAJOR.MINOR.PATCH` by using
|
||||
`git tag vMAJOR.MINOR.PATCH` and publish the tag via `git push --tags`.
|
||||
|
||||
Once the tag is pushed, a new job will start to build and publish the artifacts
|
||||
on Github.
|
||||
|
||||
## 6. Publish Crates
|
||||
|
||||
Now, `cd` into each sub-crate and publish. Sometimes, it takes a little while
|
||||
for a crate to be indexed after getting published. This can lead to the publish
|
||||
of a downstream crate to fail. If so, try again in a couple of seconds.
|
||||
|
||||
1. **distant-net:** `(cd distant-net && cargo publish)`
|
||||
2. **distant-core:** `(cd distant-core && cargo publish)`
|
||||
3. **distant-ssh2:** `(cd distant-ssh2 && cargo publish)`
|
||||
4. **distant:** `cargo publish`
|
||||
|
||||
## 7. Celebrate
|
||||
|
||||
Another release done!
|
@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "distant-auth"
|
||||
description = "Authentication library for distant, providing various implementations"
|
||||
categories = ["authentication"]
|
||||
keywords = ["auth", "authentication", "async"]
|
||||
version = "0.20.0"
|
||||
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
|
||||
edition = "2021"
|
||||
homepage = "https://github.com/chipsenkbeil/distant"
|
||||
repository = "https://github.com/chipsenkbeil/distant"
|
||||
readme = "README.md"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
tests = []
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.68"
|
||||
derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error"] }
|
||||
log = "0.4.18"
|
||||
serde = { version = "1.0.163", features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger = "0.10.0"
|
||||
test-log = "0.2.11"
|
||||
tokio = { version = "1.28.2", features = ["full"] }
|
@ -0,0 +1,35 @@
|
||||
# distant auth
|
||||
|
||||
[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.70.0][distant_rustc_img]][distant_rustc_lnk]
|
||||
|
||||
[distant_crates_img]: https://img.shields.io/crates/v/distant-auth.svg
|
||||
[distant_crates_lnk]: https://crates.io/crates/distant-auth
|
||||
[distant_doc_img]: https://docs.rs/distant-auth/badge.svg
|
||||
[distant_doc_lnk]: https://docs.rs/distant-auth
|
||||
[distant_rustc_img]: https://img.shields.io/badge/distant_auth-rustc_1.70+-lightgray.svg
|
||||
[distant_rustc_lnk]: https://blog.rust-lang.org/2023/06/01/Rust-1.70.0.html
|
||||
|
||||
## Details
|
||||
|
||||
The `distant-auth` library supplies the authentication functionality for the
|
||||
distant interfaces and distant cli.
|
||||
|
||||
## Installation
|
||||
|
||||
You can import the dependency by adding the following to your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
distant-auth = "0.20"
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under either of
|
||||
|
||||
Apache License, Version 2.0, (LICENSE-APACHE or
|
||||
[apache-license][apache-license]) MIT license (LICENSE-MIT or
|
||||
[mit-license][mit-license]) at your option.
|
||||
|
||||
[apache-license]: http://www.apache.org/licenses/LICENSE-2.0
|
||||
[mit-license]: http://opensource.org/licenses/MIT
|
@ -0,0 +1,110 @@
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::handler::AuthHandler;
|
||||
use crate::msg::*;
|
||||
|
||||
/// Represents an interface for authenticating with a server.
|
||||
#[async_trait]
|
||||
pub trait Authenticate {
|
||||
/// Performs authentication by leveraging the `handler` for any received challenge.
|
||||
async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()>;
|
||||
}
|
||||
|
||||
/// Represents an interface for submitting challenges for authentication.
|
||||
#[async_trait]
|
||||
pub trait Authenticator: Send {
|
||||
/// Issues an initialization notice and returns the response indicating which authentication
|
||||
/// methods to pursue
|
||||
async fn initialize(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse>;
|
||||
|
||||
/// Issues a challenge and returns the answers to the `questions` asked.
|
||||
async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse>;
|
||||
|
||||
/// Requests verification of some `kind` and `text`, returning true if passed verification.
|
||||
async fn verify(&mut self, verification: Verification) -> io::Result<VerificationResponse>;
|
||||
|
||||
/// Reports information with no response expected.
|
||||
async fn info(&mut self, info: Info) -> io::Result<()>;
|
||||
|
||||
/// Reports an error occurred during authentication, consuming the authenticator since no more
|
||||
/// challenges should be issued.
|
||||
async fn error(&mut self, error: Error) -> io::Result<()>;
|
||||
|
||||
/// Reports that the authentication has started for a specific method.
|
||||
async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()>;
|
||||
|
||||
/// Reports that the authentication has finished successfully, consuming the authenticator
|
||||
/// since no more challenges should be issued.
|
||||
async fn finished(&mut self) -> io::Result<()>;
|
||||
}
|
||||
|
||||
/// Represents an implementator of [`Authenticator`] used purely for testing purposes.
|
||||
#[cfg(any(test, feature = "tests"))]
|
||||
pub struct TestAuthenticator {
|
||||
pub initialize: Box<dyn FnMut(Initialization) -> io::Result<InitializationResponse> + Send>,
|
||||
pub challenge: Box<dyn FnMut(Challenge) -> io::Result<ChallengeResponse> + Send>,
|
||||
pub verify: Box<dyn FnMut(Verification) -> io::Result<VerificationResponse> + Send>,
|
||||
pub info: Box<dyn FnMut(Info) -> io::Result<()> + Send>,
|
||||
pub error: Box<dyn FnMut(Error) -> io::Result<()> + Send>,
|
||||
pub start_method: Box<dyn FnMut(StartMethod) -> io::Result<()> + Send>,
|
||||
pub finished: Box<dyn FnMut() -> io::Result<()> + Send>,
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "tests"))]
|
||||
impl Default for TestAuthenticator {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
initialize: Box::new(|x| Ok(InitializationResponse { methods: x.methods })),
|
||||
challenge: Box::new(|x| {
|
||||
Ok(ChallengeResponse {
|
||||
answers: x.questions.into_iter().map(|x| x.text).collect(),
|
||||
})
|
||||
}),
|
||||
verify: Box::new(|_| Ok(VerificationResponse { valid: true })),
|
||||
info: Box::new(|_| Ok(())),
|
||||
error: Box::new(|_| Ok(())),
|
||||
start_method: Box::new(|_| Ok(())),
|
||||
finished: Box::new(|| Ok(())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "tests"))]
|
||||
#[async_trait]
|
||||
impl Authenticator for TestAuthenticator {
|
||||
async fn initialize(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
(self.initialize)(initialization)
|
||||
}
|
||||
|
||||
async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
(self.challenge)(challenge)
|
||||
}
|
||||
|
||||
async fn verify(&mut self, verification: Verification) -> io::Result<VerificationResponse> {
|
||||
(self.verify)(verification)
|
||||
}
|
||||
|
||||
async fn info(&mut self, info: Info) -> io::Result<()> {
|
||||
(self.info)(info)
|
||||
}
|
||||
|
||||
async fn error(&mut self, error: Error) -> io::Result<()> {
|
||||
(self.error)(error)
|
||||
}
|
||||
|
||||
async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
(self.start_method)(start_method)
|
||||
}
|
||||
|
||||
async fn finished(&mut self) -> io::Result<()> {
|
||||
(self.finished)()
|
||||
}
|
||||
}
|
@ -0,0 +1,422 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::authenticator::Authenticator;
|
||||
use crate::msg::*;
|
||||
|
||||
mod methods;
|
||||
pub use methods::*;
|
||||
|
||||
/// Interface for a handler of authentication requests for all methods.
|
||||
#[async_trait]
|
||||
pub trait AuthHandler: AuthMethodHandler + Send {
|
||||
/// Callback when authentication is beginning, providing available authentication methods and
|
||||
/// returning selected authentication methods to pursue.
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
Ok(InitializationResponse {
|
||||
methods: initialization.methods,
|
||||
})
|
||||
}
|
||||
|
||||
/// Callback when authentication starts for a specific method.
|
||||
#[allow(unused_variables)]
|
||||
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Callback when authentication is finished and no more requests will be received.
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Dummy implementation of [`AuthHandler`] where any challenge or verification request will
|
||||
/// instantly fail.
|
||||
pub struct DummyAuthHandler;
|
||||
|
||||
#[async_trait]
|
||||
impl AuthHandler for DummyAuthHandler {}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthMethodHandler for DummyAuthHandler {
|
||||
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_verification(&mut self, _: Verification) -> io::Result<VerificationResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, _: Info) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, _: Error) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`AuthHandler`] that uses the same [`AuthMethodHandler`] for all methods.
|
||||
pub struct SingleAuthHandler(Box<dyn AuthMethodHandler>);
|
||||
|
||||
impl SingleAuthHandler {
|
||||
pub fn new<T: AuthMethodHandler + 'static>(method_handler: T) -> Self {
|
||||
Self(Box::new(method_handler))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthHandler for SingleAuthHandler {}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthMethodHandler for SingleAuthHandler {
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
self.0.on_challenge(challenge).await
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
self.0.on_verification(verification).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
self.0.on_info(info).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
self.0.on_error(error).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`AuthHandler`] that maintains a map of [`AuthMethodHandler`] implementations
|
||||
/// for specific methods, invoking [`on_challenge`], [`on_verification`], [`on_info`], and
|
||||
/// [`on_error`] for a specific handler based on an associated id.
|
||||
///
|
||||
/// [`on_challenge`]: AuthMethodHandler::on_challenge
|
||||
/// [`on_verification`]: AuthMethodHandler::on_verification
|
||||
/// [`on_info`]: AuthMethodHandler::on_info
|
||||
/// [`on_error`]: AuthMethodHandler::on_error
|
||||
pub struct AuthHandlerMap {
|
||||
active: String,
|
||||
map: HashMap<&'static str, Box<dyn AuthMethodHandler>>,
|
||||
}
|
||||
|
||||
impl AuthHandlerMap {
|
||||
/// Creates a new, empty map of auth method handlers.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
active: String::new(),
|
||||
map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `id` of the active [`AuthMethodHandler`].
|
||||
pub fn active_id(&self) -> &str {
|
||||
&self.active
|
||||
}
|
||||
|
||||
/// Sets the active [`AuthMethodHandler`] by its `id`.
|
||||
pub fn set_active_id(&mut self, id: impl Into<String>) {
|
||||
self.active = id.into();
|
||||
}
|
||||
|
||||
/// Inserts the specified `handler` into the map, associating it with `id` for determining the
|
||||
/// method that would trigger this handler.
|
||||
pub fn insert_method_handler<T: AuthMethodHandler + 'static>(
|
||||
&mut self,
|
||||
id: &'static str,
|
||||
handler: T,
|
||||
) -> Option<Box<dyn AuthMethodHandler>> {
|
||||
self.map.insert(id, Box::new(handler))
|
||||
}
|
||||
|
||||
/// Removes a handler with the associated `id`.
|
||||
pub fn remove_method_handler(
|
||||
&mut self,
|
||||
id: &'static str,
|
||||
) -> Option<Box<dyn AuthMethodHandler>> {
|
||||
self.map.remove(id)
|
||||
}
|
||||
|
||||
/// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`,
|
||||
/// returning an error if no handler for the active id is found.
|
||||
pub fn get_mut_active_method_handler_or_error(
|
||||
&mut self,
|
||||
) -> io::Result<&mut (dyn AuthMethodHandler + 'static)> {
|
||||
let id = self.active.clone();
|
||||
self.get_mut_active_method_handler().ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::Other, format!("No active handler for {id}"))
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`.
|
||||
pub fn get_mut_active_method_handler(
|
||||
&mut self,
|
||||
) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
|
||||
// TODO: Optimize this
|
||||
self.get_mut_method_handler(&self.active.clone())
|
||||
}
|
||||
|
||||
/// Retrieves a mutable reference to the [`AuthMethodHandler`] with the specified `id`.
|
||||
pub fn get_mut_method_handler(
|
||||
&mut self,
|
||||
id: &str,
|
||||
) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
|
||||
self.map.get_mut(id).map(|h| h.as_mut())
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthHandlerMap {
|
||||
/// Consumes the map, returning a new map that supports the `static_key` method.
|
||||
pub fn with_static_key<K>(mut self, key: K) -> Self
|
||||
where
|
||||
K: Display + Send + 'static,
|
||||
{
|
||||
self.insert_method_handler("static_key", StaticKeyAuthMethodHandler::simple(key));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AuthHandlerMap {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthHandler for AuthHandlerMap {
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
let methods = initialization
|
||||
.methods
|
||||
.into_iter()
|
||||
.filter(|method| self.map.contains_key(method.as_str()))
|
||||
.collect();
|
||||
|
||||
Ok(InitializationResponse { methods })
|
||||
}
|
||||
|
||||
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
self.set_active_id(start_method.method);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthMethodHandler for AuthHandlerMap {
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
let handler = self.get_mut_active_method_handler_or_error()?;
|
||||
handler.on_challenge(challenge).await
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
let handler = self.get_mut_active_method_handler_or_error()?;
|
||||
handler.on_verification(verification).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
let handler = self.get_mut_active_method_handler_or_error()?;
|
||||
handler.on_info(info).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
let handler = self.get_mut_active_method_handler_or_error()?;
|
||||
handler.on_error(error).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`AuthHandler`] that redirects all requests to an [`Authenticator`].
|
||||
pub struct ProxyAuthHandler<'a>(&'a mut dyn Authenticator);
|
||||
|
||||
impl<'a> ProxyAuthHandler<'a> {
|
||||
pub fn new(authenticator: &'a mut dyn Authenticator) -> Self {
|
||||
Self(authenticator)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> AuthHandler for ProxyAuthHandler<'a> {
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
Authenticator::initialize(self.0, initialization).await
|
||||
}
|
||||
|
||||
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
Authenticator::start_method(self.0, start_method).await
|
||||
}
|
||||
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
Authenticator::finished(self.0).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> AuthMethodHandler for ProxyAuthHandler<'a> {
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
Authenticator::challenge(self.0, challenge).await
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
Authenticator::verify(self.0, verification).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
Authenticator::info(self.0, info).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
Authenticator::error(self.0, error).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`AuthHandler`] that holds a mutable reference to another [`AuthHandler`]
|
||||
/// trait object to use underneath.
|
||||
pub struct DynAuthHandler<'a>(&'a mut dyn AuthHandler);
|
||||
|
||||
impl<'a> DynAuthHandler<'a> {
|
||||
pub fn new(handler: &'a mut dyn AuthHandler) -> Self {
|
||||
Self(handler)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: AuthHandler> From<&'a mut T> for DynAuthHandler<'a> {
|
||||
fn from(handler: &'a mut T) -> Self {
|
||||
Self::new(handler as &mut dyn AuthHandler)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> AuthHandler for DynAuthHandler<'a> {
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
self.0.on_initialization(initialization).await
|
||||
}
|
||||
|
||||
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
self.0.on_start_method(start_method).await
|
||||
}
|
||||
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
self.0.on_finished().await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> AuthMethodHandler for DynAuthHandler<'a> {
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
self.0.on_challenge(challenge).await
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
self.0.on_verification(verification).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
self.0.on_info(info).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
self.0.on_error(error).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents an implementator of [`AuthHandler`] used purely for testing purposes.
|
||||
#[cfg(any(test, feature = "tests"))]
|
||||
pub struct TestAuthHandler {
|
||||
pub on_initialization:
|
||||
Box<dyn FnMut(Initialization) -> io::Result<InitializationResponse> + Send>,
|
||||
pub on_challenge: Box<dyn FnMut(Challenge) -> io::Result<ChallengeResponse> + Send>,
|
||||
pub on_verification: Box<dyn FnMut(Verification) -> io::Result<VerificationResponse> + Send>,
|
||||
pub on_info: Box<dyn FnMut(Info) -> io::Result<()> + Send>,
|
||||
pub on_error: Box<dyn FnMut(Error) -> io::Result<()> + Send>,
|
||||
pub on_start_method: Box<dyn FnMut(StartMethod) -> io::Result<()> + Send>,
|
||||
pub on_finished: Box<dyn FnMut() -> io::Result<()> + Send>,
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "tests"))]
|
||||
impl Default for TestAuthHandler {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
on_initialization: Box::new(|x| Ok(InitializationResponse { methods: x.methods })),
|
||||
on_challenge: Box::new(|x| {
|
||||
Ok(ChallengeResponse {
|
||||
answers: x.questions.into_iter().map(|x| x.text).collect(),
|
||||
})
|
||||
}),
|
||||
on_verification: Box::new(|_| Ok(VerificationResponse { valid: true })),
|
||||
on_info: Box::new(|_| Ok(())),
|
||||
on_error: Box::new(|_| Ok(())),
|
||||
on_start_method: Box::new(|_| Ok(())),
|
||||
on_finished: Box::new(|| Ok(())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "tests"))]
#[async_trait]
impl AuthHandler for TestAuthHandler {
    // Each trait method simply delegates to the matching stored closure.
    async fn on_initialization(
        &mut self,
        initialization: Initialization,
    ) -> io::Result<InitializationResponse> {
        (self.on_initialization)(initialization)
    }

    async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
        (self.on_start_method)(start_method)
    }

    async fn on_finished(&mut self) -> io::Result<()> {
        (self.on_finished)()
    }
}
|
||||
|
||||
#[cfg(any(test, feature = "tests"))]
#[async_trait]
impl AuthMethodHandler for TestAuthHandler {
    // Method-level requests also delegate to the stored closures, so the same
    // test handler can stand in for both AuthHandler and AuthMethodHandler.
    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
        (self.on_challenge)(challenge)
    }

    async fn on_verification(
        &mut self,
        verification: Verification,
    ) -> io::Result<VerificationResponse> {
        (self.on_verification)(verification)
    }

    async fn on_info(&mut self, info: Info) -> io::Result<()> {
        (self.on_info)(info)
    }

    async fn on_error(&mut self, error: Error) -> io::Result<()> {
        (self.on_error)(error)
    }
}
|
@ -0,0 +1,33 @@
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::msg::{Challenge, ChallengeResponse, Error, Info, Verification, VerificationResponse};
|
||||
|
||||
/// Interface for a handler of authentication requests for a specific authentication method.
#[async_trait]
pub trait AuthMethodHandler: Send {
    /// Callback when a challenge is received, returning answers to the given questions.
    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse>;

    /// Callback when a verification request is received, returning true if approved or false if
    /// unapproved.
    async fn on_verification(
        &mut self,
        verification: Verification,
    ) -> io::Result<VerificationResponse>;

    /// Callback when information is received. To fail, return an error from this function.
    async fn on_info(&mut self, info: Info) -> io::Result<()>;

    /// Callback when an error is received. Regardless of the result returned, this will terminate
    /// the authenticator. In the situation where a custom error would be preferred, have this
    /// callback return an error.
    async fn on_error(&mut self, error: Error) -> io::Result<()>;
}
|
||||
|
||||
mod prompt;
|
||||
pub use prompt::*;
|
||||
|
||||
mod static_key;
|
||||
pub use static_key::*;
|
@ -0,0 +1,90 @@
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use log::*;
|
||||
|
||||
use crate::handler::AuthMethodHandler;
|
||||
use crate::msg::{
|
||||
Challenge, ChallengeResponse, Error, Info, Verification, VerificationKind, VerificationResponse,
|
||||
};
|
||||
|
||||
/// Blocking implementation of [`AuthMethodHandler`] that uses prompts to communicate challenge &
/// verification requests, receiving responses to relay back.
pub struct PromptAuthMethodHandler<T, U> {
    // Prompt used for plain-text input (e.g. the y/N verification answer)
    text_prompt: T,
    // Prompt used for secret input (e.g. challenge answers)
    password_prompt: U,
}
|
||||
|
||||
impl<T, U> PromptAuthMethodHandler<T, U> {
|
||||
pub fn new(text_prompt: T, password_prompt: U) -> Self {
|
||||
Self {
|
||||
text_prompt,
|
||||
password_prompt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl<T, U> AuthMethodHandler for PromptAuthMethodHandler<T, U>
where
    T: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
    U: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
{
    /// Presents each challenge question to the user and collects answers via the
    /// password prompt; a failed prompt yields an empty answer rather than an error.
    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
        trace!("on_challenge({challenge:?})");
        let mut answers = Vec::new();
        for question in challenge.questions.iter() {
            // Contains all prompt lines including same line
            let mut lines = question.text.split('\n').collect::<Vec<_>>();

            // Line that is prompt on same line as answer.
            // NOTE: split always yields at least one element, so pop cannot fail.
            let line = lines.pop().unwrap();

            // Go ahead and display all other lines
            for line in lines.into_iter() {
                eprintln!("{line}");
            }

            // Get an answer from user input, or use a blank string as an answer
            // if we fail to get input from the user
            let answer = (self.password_prompt)(line).unwrap_or_default();

            answers.push(answer);
        }
        Ok(ChallengeResponse { answers })
    }

    /// For host verification, displays the text and asks for a y/N confirmation;
    /// any other verification kind is reported invalid.
    async fn on_verification(
        &mut self,
        verification: Verification,
    ) -> io::Result<VerificationResponse> {
        trace!("on_verify({verification:?})");
        match verification.kind {
            VerificationKind::Host => {
                eprintln!("{}", verification.text);

                let answer = (self.text_prompt)("Enter [y/N]> ")?;
                trace!("Verify? Answer = '{answer}'");
                Ok(VerificationResponse {
                    valid: matches!(answer.trim(), "y" | "Y" | "yes" | "YES"),
                })
            }
            x => {
                error!("Unsupported verify kind: {x}");
                Ok(VerificationResponse { valid: false })
            }
        }
    }

    /// Prints informational text to stdout.
    async fn on_info(&mut self, info: Info) -> io::Result<()> {
        trace!("on_info({info:?})");
        println!("{}", info.text);
        Ok(())
    }

    /// Prints the error to stderr; always returns Ok so the caller decides how to proceed.
    async fn on_error(&mut self, error: Error) -> io::Result<()> {
        trace!("on_error({error:?})");
        eprintln!("{}: {}", error.kind, error.text);
        Ok(())
    }
}
|
@ -0,0 +1,175 @@
|
||||
use std::fmt::Display;
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use log::*;
|
||||
|
||||
use crate::handler::AuthMethodHandler;
|
||||
use crate::msg::{Challenge, ChallengeResponse, Error, Info, Verification, VerificationResponse};
|
||||
|
||||
/// Implementation of [`AuthMethodHandler`] that answers challenge requests using a static
/// key of any type that implements [`std::fmt::Display`]. All other portions of method
/// authentication are handled by another [`AuthMethodHandler`].
pub struct StaticKeyAuthMethodHandler<K> {
    // Key whose stringified form answers every "key"-labeled challenge question
    key: K,
    // Fallback handler for verification/info/error requests
    handler: Box<dyn AuthMethodHandler>,
}
|
||||
|
||||
impl<K> StaticKeyAuthMethodHandler<K> {
    /// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static
    /// `key`. All other requests are passed to the `handler`.
    pub fn new<T: AuthMethodHandler + 'static>(key: K, handler: T) -> Self {
        Self {
            key,
            handler: Box::new(handler),
        }
    }

    /// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static
    /// `key`. All other requests are handled automatically, meaning that verification is always
    /// approved and info/errors are ignored.
    pub fn simple(key: K) -> Self {
        Self::new(key, {
            // Anonymous fallback handler: approves all verification and ignores
            // info/errors. Challenges never reach it because the outer handler
            // answers them itself, hence the unreachable! guard.
            struct __AuthMethodHandler;

            #[async_trait]
            impl AuthMethodHandler for __AuthMethodHandler {
                async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
                    unreachable!("on_challenge should be handled by StaticKeyAuthMethodHandler");
                }

                async fn on_verification(
                    &mut self,
                    _: Verification,
                ) -> io::Result<VerificationResponse> {
                    Ok(VerificationResponse { valid: true })
                }

                async fn on_info(&mut self, _: Info) -> io::Result<()> {
                    Ok(())
                }

                async fn on_error(&mut self, _: Error) -> io::Result<()> {
                    Ok(())
                }
            }

            __AuthMethodHandler
        })
    }
}
|
||||
|
||||
#[async_trait]
impl<K> AuthMethodHandler for StaticKeyAuthMethodHandler<K>
where
    K: Display + Send,
{
    /// Answers every "key"-labeled question with the stringified key; any other
    /// label fails the entire challenge with `InvalidInput`.
    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
        trace!("on_challenge({challenge:?})");
        let mut answers = Vec::new();
        for question in challenge.questions.iter() {
            // Only challenges with a "key" label are allowed, all else will fail
            if question.label != "key" {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Only 'key' challenges are supported",
                ));
            }
            answers.push(self.key.to_string());
        }
        Ok(ChallengeResponse { answers })
    }

    /// Delegates to the wrapped fallback handler.
    async fn on_verification(
        &mut self,
        verification: Verification,
    ) -> io::Result<VerificationResponse> {
        trace!("on_verify({verification:?})");
        self.handler.on_verification(verification).await
    }

    /// Delegates to the wrapped fallback handler.
    async fn on_info(&mut self, info: Info) -> io::Result<()> {
        trace!("on_info({info:?})");
        self.handler.on_info(info).await
    }

    /// Delegates to the wrapped fallback handler.
    async fn on_error(&mut self, error: Error) -> io::Result<()> {
        trace!("on_error({error:?})");
        self.handler.on_error(error).await
    }
}
|
||||
|
||||
// Tests covering StaticKeyAuthMethodHandler: key challenges answered locally,
// everything else delegated to the simple() fallback handler.
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;
    use crate::msg::{ErrorKind, Question, VerificationKind};

    #[test(tokio::test)]
    async fn on_challenge_should_fail_if_non_key_question_received() {
        let mut handler = StaticKeyAuthMethodHandler::simple(String::from("secret-key"));

        handler
            .on_challenge(Challenge {
                questions: vec![Question::new("test")],
                options: Default::default(),
            })
            .await
            .unwrap_err();
    }

    #[test(tokio::test)]
    async fn on_challenge_should_answer_with_stringified_key_for_key_questions() {
        let mut handler = StaticKeyAuthMethodHandler::simple(String::from("secret-key"));

        let response = handler
            .on_challenge(Challenge {
                questions: vec![Question::new("key")],
                options: Default::default(),
            })
            .await
            .unwrap();
        assert_eq!(response.answers.len(), 1, "Wrong answer set received");
        assert!(!response.answers[0].is_empty(), "Empty answer being sent");
    }

    #[test(tokio::test)]
    async fn on_verification_should_leverage_fallback_handler() {
        let mut handler = StaticKeyAuthMethodHandler::simple(String::from("secret-key"));

        let response = handler
            .on_verification(Verification {
                kind: VerificationKind::Host,
                text: "host".to_string(),
            })
            .await
            .unwrap();
        assert!(response.valid, "Unexpected result from fallback handler");
    }

    #[test(tokio::test)]
    async fn on_info_should_leverage_fallback_handler() {
        let mut handler = StaticKeyAuthMethodHandler::simple(String::from("secret-key"));

        handler
            .on_info(Info {
                text: "info".to_string(),
            })
            .await
            .unwrap();
    }

    #[test(tokio::test)]
    async fn on_error_should_leverage_fallback_handler() {
        let mut handler = StaticKeyAuthMethodHandler::simple(String::from("secret-key"));

        handler
            .on_error(Error {
                kind: ErrorKind::Error,
                text: "text".to_string(),
            })
            .await
            .unwrap();
    }
}
|
@ -0,0 +1,19 @@
|
||||
#![doc = include_str!("../README.md")]

// Compiles the README's code samples as doctests so they stay accurate
#[doc = include_str!("../README.md")]
#[cfg(doctest)]
pub struct ReadmeDoctests;

mod authenticator;
mod handler;
mod methods;
pub mod msg;

pub use authenticator::*;
pub use handler::*;
pub use methods::*;

/// Test utilities, exposed when building tests or with the `tests` feature enabled
#[cfg(any(test, feature = "tests"))]
pub mod tests {
    pub use crate::{TestAuthHandler, TestAuthenticator};
}
|
@ -0,0 +1,365 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::str::FromStr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use log::*;
|
||||
|
||||
use crate::authenticator::Authenticator;
|
||||
use crate::msg::*;
|
||||
|
||||
mod none;
|
||||
mod static_key;
|
||||
|
||||
pub use none::*;
|
||||
pub use static_key::*;
|
||||
|
||||
/// Supports authenticating using a variety of methods
pub struct Verifier {
    // Registered authentication methods keyed by their unique id
    methods: HashMap<&'static str, Box<dyn AuthenticationMethod>>,
}
|
||||
|
||||
impl Verifier {
    /// Creates a verifier from the given methods, keyed by each method's id.
    /// If two methods share an id, the later one replaces the earlier.
    pub fn new<I>(methods: I) -> Self
    where
        I: IntoIterator<Item = Box<dyn AuthenticationMethod>>,
    {
        let mut m = HashMap::new();

        for method in methods {
            m.insert(method.id(), method);
        }

        Self { methods: m }
    }

    /// Creates a verifier with no methods.
    pub fn empty() -> Self {
        Self {
            methods: HashMap::new(),
        }
    }

    /// Creates a verifier that uses the [`NoneAuthenticationMethod`] exclusively.
    pub fn none() -> Self {
        Self::new(vec![
            Box::new(NoneAuthenticationMethod::new()) as Box<dyn AuthenticationMethod>
        ])
    }

    /// Creates a verifier that uses the [`StaticKeyAuthenticationMethod`] exclusively.
    pub fn static_key<K>(key: K) -> Self
    where
        K: FromStr + PartialEq + Send + Sync + 'static,
    {
        Self::new(vec![
            Box::new(StaticKeyAuthenticationMethod::new(key)) as Box<dyn AuthenticationMethod>
        ])
    }

    /// Returns an iterator over the ids of the methods supported by the verifier
    pub fn methods(&self) -> impl Iterator<Item = &'static str> + '_ {
        self.methods.keys().copied()
    }

    /// Attempts to verify by submitting challenges using the `authenticator` provided. Returns the
    /// id of the authentication method that succeeded. Fails if no authentication method succeeds.
    pub async fn verify(&self, authenticator: &mut dyn Authenticator) -> io::Result<&'static str> {
        // Initiate the process to get methods to use
        let response = authenticator
            .initialize(Initialization {
                methods: self.methods.keys().map(ToString::to_string).collect(),
            })
            .await?;

        // Try each method the other side selected, in the order it returned them
        for method in response.methods {
            match self.methods.get(method.as_str()) {
                Some(method) => {
                    // Report the authentication method
                    authenticator
                        .start_method(StartMethod {
                            method: method.id().to_string(),
                        })
                        .await?;

                    // Perform the actual authentication; a method failure is not fatal,
                    // the next requested method is tried instead
                    if method.authenticate(authenticator).await.is_ok() {
                        authenticator.finished().await?;
                        return Ok(method.id());
                    }
                }
                None => {
                    trace!("Skipping authentication {method} as it is not available or supported");
                }
            }
        }

        Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            "No authentication method succeeded",
        ))
    }
}
|
||||
|
||||
impl From<Vec<Box<dyn AuthenticationMethod>>> for Verifier {
|
||||
fn from(methods: Vec<Box<dyn AuthenticationMethod>>) -> Self {
|
||||
Self::new(methods)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents an interface to authenticate using some method
#[async_trait]
pub trait AuthenticationMethod: Send + Sync {
    /// Returns a unique id to distinguish the method from other methods
    fn id(&self) -> &'static str;

    /// Performs authentication using the `authenticator` to submit challenges and other
    /// information based on the authentication method. An `Err` marks the method as
    /// failed; callers such as [`Verifier`] may then fall back to other methods.
    async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()>;
}
|
||||
|
||||
// Tests covering Verifier's negotiation flow: initialization, method ordering,
// start_method/finished signaling, and failure fallbacks.
#[cfg(test)]
mod tests {
    use std::sync::mpsc;

    use test_log::test;

    use super::*;
    use crate::authenticator::TestAuthenticator;

    // Method that always succeeds, used to exercise the happy path
    struct SuccessAuthenticationMethod;

    #[async_trait]
    impl AuthenticationMethod for SuccessAuthenticationMethod {
        fn id(&self) -> &'static str {
            "success"
        }

        async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
            Ok(())
        }
    }

    // Method that always fails, used to exercise fallback behavior
    struct FailAuthenticationMethod;

    #[async_trait]
    impl AuthenticationMethod for FailAuthenticationMethod {
        fn id(&self) -> &'static str {
            "fail"
        }

        async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
            Err(io::Error::from(io::ErrorKind::Other))
        }
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_initialization_fails() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| Err(io::Error::from(io::ErrorKind::Other))),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(SuccessAuthenticationMethod)];
        let verifier = Verifier::from(methods);
        verifier.verify(&mut authenticator).await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_fails_to_send_finished_indicator_after_success() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(InitializationResponse {
                    methods: vec![SuccessAuthenticationMethod.id().to_string()]
                        .into_iter()
                        .collect(),
                })
            }),
            finished: Box::new(|| Err(io::Error::new(io::ErrorKind::Other, "test error"))),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(SuccessAuthenticationMethod)];
        let verifier = Verifier::from(methods);

        let err = verifier.verify(&mut authenticator).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert_eq!(err.to_string(), "test error");
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_has_no_authentication_methods() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(InitializationResponse {
                    methods: vec![SuccessAuthenticationMethod.id().to_string()]
                        .into_iter()
                        .collect(),
                })
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![];
        let verifier = Verifier::from(methods);
        verifier.verify(&mut authenticator).await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_initialization_yields_no_valid_authentication_methods(
    ) {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(InitializationResponse {
                    methods: vec!["other".to_string()].into_iter().collect(),
                })
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(SuccessAuthenticationMethod)];
        let verifier = Verifier::from(methods);
        verifier.verify(&mut authenticator).await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_no_authentication_method_succeeds() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(InitializationResponse {
                    methods: vec![FailAuthenticationMethod.id().to_string()]
                        .into_iter()
                        .collect(),
                })
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![Box::new(FailAuthenticationMethod)];
        let verifier = Verifier::from(methods);
        verifier.verify(&mut authenticator).await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn verifier_should_return_id_of_authentication_method_upon_success() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(InitializationResponse {
                    methods: vec![SuccessAuthenticationMethod.id().to_string()]
                        .into_iter()
                        .collect(),
                })
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(SuccessAuthenticationMethod)];
        let verifier = Verifier::from(methods);
        assert_eq!(
            verifier.verify(&mut authenticator).await.unwrap(),
            SuccessAuthenticationMethod.id()
        );
    }

    #[test(tokio::test)]
    async fn verifier_should_try_authentication_methods_in_order_until_one_succeeds() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(InitializationResponse {
                    methods: vec![
                        FailAuthenticationMethod.id().to_string(),
                        SuccessAuthenticationMethod.id().to_string(),
                    ]
                    .into_iter()
                    .collect(),
                })
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
            Box::new(FailAuthenticationMethod),
            Box::new(SuccessAuthenticationMethod),
        ];
        let verifier = Verifier::from(methods);
        assert_eq!(
            verifier.verify(&mut authenticator).await.unwrap(),
            SuccessAuthenticationMethod.id()
        );
    }

    #[test(tokio::test)]
    async fn verifier_should_send_start_method_before_attempting_each_method() {
        let (tx, rx) = mpsc::channel();

        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(InitializationResponse {
                    methods: vec![
                        FailAuthenticationMethod.id().to_string(),
                        SuccessAuthenticationMethod.id().to_string(),
                    ]
                    .into_iter()
                    .collect(),
                })
            }),
            start_method: Box::new(move |method| {
                tx.send(method.method).unwrap();
                Ok(())
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
            Box::new(FailAuthenticationMethod),
            Box::new(SuccessAuthenticationMethod),
        ];
        Verifier::from(methods)
            .verify(&mut authenticator)
            .await
            .unwrap();

        assert_eq!(rx.try_recv().unwrap(), FailAuthenticationMethod.id());
        assert_eq!(rx.try_recv().unwrap(), SuccessAuthenticationMethod.id());
        assert_eq!(rx.try_recv().unwrap_err(), mpsc::TryRecvError::Empty);
    }

    #[test(tokio::test)]
    async fn verifier_should_send_finished_when_a_method_succeeds() {
        let (tx, rx) = mpsc::channel();

        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(InitializationResponse {
                    methods: vec![
                        FailAuthenticationMethod.id().to_string(),
                        SuccessAuthenticationMethod.id().to_string(),
                    ]
                    .into_iter()
                    .collect(),
                })
            }),
            finished: Box::new(move || {
                tx.send(()).unwrap();
                Ok(())
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
            Box::new(FailAuthenticationMethod),
            Box::new(SuccessAuthenticationMethod),
        ];
        Verifier::from(methods)
            .verify(&mut authenticator)
            .await
            .unwrap();

        rx.try_recv().unwrap();
        assert_eq!(rx.try_recv().unwrap_err(), mpsc::TryRecvError::Empty);
    }
}
|
@ -0,0 +1,37 @@
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::authenticator::Authenticator;
|
||||
use crate::methods::AuthenticationMethod;
|
||||
|
||||
/// Authentication method that skips authentication and approves anything.
#[derive(Clone, Debug)]
pub struct NoneAuthenticationMethod;

impl NoneAuthenticationMethod {
    /// Unique id distinguishing this method from others.
    // Explicit 'static lifetime: elided lifetimes in associated constants are
    // deprecated and rejected by newer Rust toolchains.
    pub const ID: &'static str = "none";

    /// Creates a new instance of the method (a zero-sized unit struct).
    #[inline]
    pub fn new() -> Self {
        Self
    }
}
|
||||
|
||||
impl Default for NoneAuthenticationMethod {
|
||||
#[inline]
|
||||
fn default() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl AuthenticationMethod for NoneAuthenticationMethod {
    fn id(&self) -> &'static str {
        Self::ID
    }

    /// Always succeeds without submitting any challenges.
    async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
        Ok(())
    }
}
|
@ -0,0 +1,133 @@
|
||||
use std::io;
|
||||
use std::str::FromStr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::authenticator::Authenticator;
|
||||
use crate::methods::AuthenticationMethod;
|
||||
use crate::msg::{Challenge, Error, Question};
|
||||
|
||||
/// Authentication method for a static secret key
#[derive(Clone, Debug)]
pub struct StaticKeyAuthenticationMethod<T> {
    // Expected key that challenge answers must parse into and equal
    key: T,
}

impl<T> StaticKeyAuthenticationMethod<T> {
    /// Unique id distinguishing this method from others.
    // Explicit 'static lifetime: elided lifetimes in associated constants are
    // deprecated and rejected by newer Rust toolchains.
    pub const ID: &'static str = "static_key";

    /// Creates a new method that validates challenge answers against `key`.
    #[inline]
    pub fn new(key: T) -> Self {
        Self { key }
    }
}
|
||||
|
||||
#[async_trait]
impl<T> AuthenticationMethod for StaticKeyAuthenticationMethod<T>
where
    T: FromStr + PartialEq + Send + Sync,
{
    fn id(&self) -> &'static str {
        Self::ID
    }

    /// Issues a single "key"-labeled challenge and succeeds only if the first
    /// answer parses into `T` and equals the stored key.
    async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()> {
        let response = authenticator
            .challenge(Challenge {
                questions: vec![Question {
                    label: "key".to_string(),
                    text: "Provide a key: ".to_string(),
                    options: Default::default(),
                }],
                options: Default::default(),
            })
            .await?;

        // No answers at all is treated as a permission failure
        if response.answers.is_empty() {
            return Err(Error::non_fatal("missing answer").into_io_permission_denied());
        }

        // Only the first answer is considered; it must parse and match exactly.
        // NOTE: unwrap is safe, the emptiness check above guarantees an element.
        match response.answers.into_iter().next().unwrap().parse::<T>() {
            Ok(key) if key == self.key => Ok(()),
            _ => Err(Error::non_fatal("answer does not match key").into_io_permission_denied()),
        }
    }
}
|
||||
|
||||
// Tests covering StaticKeyAuthenticationMethod: challenge propagation, missing
// answers, mismatched answers, and the successful match path.
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;
    use crate::authenticator::TestAuthenticator;
    use crate::msg::*;

    #[test(tokio::test)]
    async fn authenticate_should_fail_if_key_challenge_fails() {
        let method = StaticKeyAuthenticationMethod::new(String::new());

        let mut authenticator = TestAuthenticator {
            challenge: Box::new(|_| Err(io::Error::new(io::ErrorKind::InvalidData, "test error"))),
            ..Default::default()
        };

        let err = method.authenticate(&mut authenticator).await.unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "test error");
    }

    #[test(tokio::test)]
    async fn authenticate_should_fail_if_no_answer_included_in_challenge_response() {
        let method = StaticKeyAuthenticationMethod::new(String::new());

        let mut authenticator = TestAuthenticator {
            challenge: Box::new(|_| {
                Ok(ChallengeResponse {
                    answers: Vec::new(),
                })
            }),
            ..Default::default()
        };

        let err = method.authenticate(&mut authenticator).await.unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::PermissionDenied);
        assert_eq!(err.to_string(), "Error: missing answer");
    }

    #[test(tokio::test)]
    async fn authenticate_should_fail_if_answer_does_not_match_key() {
        let method = StaticKeyAuthenticationMethod::new(String::from("answer"));

        let mut authenticator = TestAuthenticator {
            challenge: Box::new(|_| {
                Ok(ChallengeResponse {
                    answers: vec![String::from("other")],
                })
            }),
            ..Default::default()
        };

        let err = method.authenticate(&mut authenticator).await.unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::PermissionDenied);
        assert_eq!(err.to_string(), "Error: answer does not match key");
    }

    #[test(tokio::test)]
    async fn authenticate_should_succeed_if_answer_matches_key() {
        let method = StaticKeyAuthenticationMethod::new(String::from("answer"));

        let mut authenticator = TestAuthenticator {
            challenge: Box::new(|_| {
                Ok(ChallengeResponse {
                    answers: vec![String::from("answer")],
                })
            }),
            ..Default::default()
        };

        method.authenticate(&mut authenticator).await.unwrap();
    }
}
|
@ -0,0 +1,217 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use derive_more::{Display, Error, From};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Represents messages from an authenticator that act as initiators such as providing
/// a challenge, verifying information, presenting information, or highlighting an error
#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum Authentication {
    /// Indicates the beginning of authentication, providing available methods
    #[serde(rename = "auth_initialization")]
    Initialization(Initialization),

    /// Indicates that authentication is starting for the specific `method`
    #[serde(rename = "auth_start_method")]
    StartMethod(StartMethod),

    /// Issues a challenge to be answered
    #[serde(rename = "auth_challenge")]
    Challenge(Challenge),

    /// Requests verification of some text
    #[serde(rename = "auth_verification")]
    Verification(Verification),

    /// Reports some information associated with authentication
    #[serde(rename = "auth_info")]
    Info(Info),

    /// Reports an error occurred during authentication
    #[serde(rename = "auth_error")]
    Error(Error),

    /// Indicates that the authentication of all methods is finished
    #[serde(rename = "auth_finished")]
    Finished,
}
|
||||
|
||||
/// Represents the beginning of the authentication procedure
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Initialization {
    /// Available methods to use for authentication
    pub methods: Vec<String>,
}
|
||||
|
||||
/// Represents the start of authentication for some method
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct StartMethod {
    /// Id of the authentication method now being attempted
    pub method: String,
}
|
||||
|
||||
/// Represents a challenge comprising a series of questions to be presented
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Challenge {
    /// Questions to be answered, in order
    pub questions: Vec<Question>,

    /// Additional options specific to a particular auth domain
    pub options: HashMap<String, String>,
}
|
||||
|
||||
/// Represents an ask to verify some information
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Verification {
    /// Type of verification being requested
    pub kind: VerificationKind,

    /// Text to be verified (used for display purposes)
    pub text: String,
}
|
||||
|
||||
/// Represents some information to be presented related to authentication
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Info {
    /// Text to be presented
    pub text: String,
}
|
||||
|
||||
/// Represents authentication messages that are responses to authenticator requests such
/// as answers to challenges or verifying information
#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthenticationResponse {
    /// Contains response to initialization, providing details about which methods to use
    #[serde(rename = "auth_initialization_response")]
    Initialization(InitializationResponse),

    /// Contains answers to challenge request
    #[serde(rename = "auth_challenge_response")]
    Challenge(ChallengeResponse),

    /// Contains response to a verification request
    #[serde(rename = "auth_verification_response")]
    Verification(VerificationResponse),
}
|
||||
|
||||
/// Represents a response to initialization to specify which authentication methods to pursue
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct InitializationResponse {
    /// Methods to use (in order as provided)
    pub methods: Vec<String>,
}
|
||||
|
||||
/// Represents the answers to a previously-asked challenge associated with authentication
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChallengeResponse {
    /// Answers to challenge questions (in order relative to questions)
    pub answers: Vec<String>,
}
|
||||
|
||||
/// Represents the answer to a previously-asked verification associated with authentication
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct VerificationResponse {
    /// Whether or not the verification was deemed valid
    pub valid: bool,
}
|
||||
|
||||
/// Represents the type of verification being requested
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum VerificationKind {
    /// An ask to verify the host such as with SSH
    #[display(fmt = "host")]
    Host,

    /// When the verification is unknown (happens when other side is unaware of the kind)
    ///
    /// `#[serde(other)]` makes this the catch-all for any unrecognized value
    /// during deserialization
    #[display(fmt = "unknown")]
    #[serde(other)]
    Unknown,
}
|
||||
|
||||
impl VerificationKind {
|
||||
/// Returns all variants except "unknown"
|
||||
pub const fn known_variants() -> &'static [Self] {
|
||||
&[Self::Host]
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a single question in a challenge associated with authentication
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Question {
    /// Label associated with the question for more programmatic usage
    pub label: String,

    /// The text of the question (used for display purposes)
    pub text: String,

    /// Any options information specific to a particular auth domain
    /// such as including a username and instructions for SSH authentication
    pub options: HashMap<String, String>,
}
|
||||
|
||||
impl Question {
|
||||
/// Creates a new question without any options data using `text` for both label and text
|
||||
pub fn new(text: impl Into<String>) -> Self {
|
||||
let text = text.into();
|
||||
|
||||
Self {
|
||||
label: text.clone(),
|
||||
text,
|
||||
options: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents some error that occurred during authentication
///
/// Displays as `{kind}: {text}` per the derive attribute below
#[derive(Clone, Debug, Display, Error, PartialEq, Eq, Serialize, Deserialize)]
#[display(fmt = "{kind}: {text}")]
pub struct Error {
    /// Represents the kind of error
    pub kind: ErrorKind,

    /// Description of the error
    pub text: String,
}
|
||||
|
||||
impl Error {
|
||||
/// Creates a fatal error
|
||||
pub fn fatal(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Fatal,
|
||||
text: text.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a non-fatal error
|
||||
pub fn non_fatal(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Error,
|
||||
text: text.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if error represents a fatal error, meaning that there is no recovery possible
|
||||
/// from this error
|
||||
pub fn is_fatal(&self) -> bool {
|
||||
self.kind.is_fatal()
|
||||
}
|
||||
|
||||
/// Converts the error into a [`std::io::Error`] representing permission denied
|
||||
pub fn into_io_permission_denied(self) -> std::io::Error {
|
||||
std::io::Error::new(std::io::ErrorKind::PermissionDenied, self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the type of error encountered during authentication
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorKind {
    /// Error is unrecoverable
    Fatal,

    /// Error is recoverable
    Error,
}
|
||||
|
||||
impl ErrorKind {
|
||||
/// Returns true if error kind represents a fatal error, meaning that there is no recovery
|
||||
/// possible from this error
|
||||
pub fn is_fatal(self) -> bool {
|
||||
matches!(self, Self::Fatal)
|
||||
}
|
||||
}
|
@ -1,68 +0,0 @@
|
||||
use crate::{data::ProcessId, ConnectionId};
|
||||
use std::{io, path::PathBuf};
|
||||
|
||||
mod process;
|
||||
pub use process::*;
|
||||
|
||||
mod watcher;
|
||||
pub use watcher::*;
|
||||
|
||||
/// Holds global state managed by the server
pub struct GlobalState {
    /// State that holds information about processes running on the server
    pub process: ProcessState,

    /// Watcher used for filesystem events
    pub watcher: WatcherState,
}
|
||||
|
||||
impl GlobalState {
|
||||
pub fn initialize() -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
process: ProcessState::new(),
|
||||
watcher: WatcherState::initialize()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds connection-specific state managed by the server
///
/// On drop, any processes and watched paths owned by the connection are
/// cleaned up asynchronously (see the `Drop` impl)
#[derive(Default)]
pub struct ConnectionState {
    /// Unique id associated with connection
    id: ConnectionId,

    /// Channel connected to global process state
    pub(crate) process_channel: ProcessChannel,

    /// Channel connected to global watcher state
    pub(crate) watcher_channel: WatcherChannel,

    /// Contains ids of processes that will be terminated when the connection is closed
    processes: Vec<ProcessId>,

    /// Contains paths being watched that will be unwatched when the connection is closed
    paths: Vec<PathBuf>,
}
|
||||
|
||||
impl Drop for ConnectionState {
    /// Schedules asynchronous cleanup of everything owned by this connection:
    /// kills its processes and unwatches its paths via the global channels
    fn drop(&mut self) {
        let id = self.id;
        // Move the owned lists out of self so the spawned task can consume them
        let processes: Vec<ProcessId> = self.processes.drain(..).collect();
        let paths: Vec<PathBuf> = self.paths.drain(..).collect();

        let process_channel = self.process_channel.clone();
        let watcher_channel = self.watcher_channel.clone();

        // NOTE: We cannot (and should not) block during drop to perform cleanup,
        //       instead spawning a task that will do the cleanup async
        tokio::spawn(async move {
            // Best-effort kill of each process; failures are ignored
            for id in processes {
                let _ = process_channel.kill(id).await;
            }

            // Best-effort unwatch of each path; failures are ignored
            for path in paths {
                let _ = watcher_channel.unwatch(id, path).await;
            }
        });
    }
}
|
@ -1,311 +0,0 @@
|
||||
use crate::{constants::SERVER_WATCHER_CAPACITY, data::ChangeKind, ConnectionId};
|
||||
use log::*;
|
||||
use notify::{
|
||||
Config as WatcherConfig, Error as WatcherError, ErrorKind as WatcherErrorKind,
|
||||
Event as WatcherEvent, PollWatcher, RecursiveMode, Watcher,
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
ops::Deref,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{self, error::TrySendError},
|
||||
oneshot,
|
||||
},
|
||||
task::JoinHandle,
|
||||
};
|
||||
|
||||
mod path;
|
||||
pub use path::*;
|
||||
|
||||
/// Holds information related to watched paths on the server
pub struct WatcherState {
    /// Channel used to communicate with the background watcher task
    channel: WatcherChannel,
    /// Handle to the background watcher task; aborted on drop
    task: JoinHandle<()>,
}
|
||||
|
||||
impl Drop for WatcherState {
|
||||
/// Aborts the task that handles watcher path operations and management
|
||||
fn drop(&mut self) {
|
||||
self.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl WatcherState {
    /// Will create a watcher and initialize watched paths to be empty
    ///
    /// Prefers the platform's recommended watcher; falls back to a polling
    /// watcher when the recommended one is unsupported (see the os-error 38
    /// check below)
    pub fn initialize() -> io::Result<Self> {
        // NOTE: Cannot be something small like 1 as this seems to cause a deadlock sometimes
        //       with a large volume of watch requests
        let (tx, rx) = mpsc::channel(SERVER_WATCHER_CAPACITY);

        // Applies best-effort configuration to the given watcher, then spawns
        // the task that owns it and returns the assembled state
        macro_rules! configure_and_spawn {
            ($watcher:ident) => {{
                // Attempt to configure watcher, but don't fail if these configurations fail
                match $watcher.configure(WatcherConfig::PreciseEvents(true)) {
                    Ok(true) => debug!("Watcher configured for precise events"),
                    Ok(false) => debug!("Watcher not configured for precise events",),
                    Err(x) => error!("Watcher configuration for precise events failed: {}", x),
                }

                // Attempt to configure watcher, but don't fail if these configurations fail
                match $watcher.configure(WatcherConfig::NoticeEvents(true)) {
                    Ok(true) => debug!("Watcher configured for notice events"),
                    Ok(false) => debug!("Watcher not configured for notice events",),
                    Err(x) => error!("Watcher configuration for notice events failed: {}", x),
                }

                Ok(Self {
                    channel: WatcherChannel { tx },
                    task: tokio::spawn(watcher_task($watcher, rx)),
                })
            }};
        }

        // Builds the callback handed to the watcher: forwards events/errors into
        // the internal channel without blocking, logging when the channel is
        // full or closed instead of failing
        macro_rules! event_handler {
            ($tx:ident) => {
                move |res| match $tx.try_send(match res {
                    Ok(x) => InnerWatcherMsg::Event { ev: x },
                    Err(x) => InnerWatcherMsg::Error { err: x },
                }) {
                    Ok(_) => (),
                    Err(TrySendError::Full(_)) => {
                        warn!(
                            "Reached watcher capacity of {}! Dropping watcher event!",
                            SERVER_WATCHER_CAPACITY,
                        );
                    }
                    Err(TrySendError::Closed(_)) => {
                        warn!("Skipping watch event because watcher channel closed");
                    }
                }
            };
        }

        let tx = tx.clone();
        let result = {
            let tx = tx.clone();
            notify::recommended_watcher(event_handler!(tx))
        };

        match result {
            Ok(mut watcher) => configure_and_spawn!(watcher),
            Err(x) => match x.kind {
                // notify-rs has a bug on Mac M1 with Docker and Linux, so we detect that error
                // and fall back to the poll watcher if this occurs
                //
                // https://github.com/notify-rs/notify/issues/423
                WatcherErrorKind::Io(x) if x.raw_os_error() == Some(38) => {
                    warn!("Recommended watcher is unsupported! Falling back to polling watcher!");
                    let mut watcher = PollWatcher::new(event_handler!(tx))
                        .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
                    configure_and_spawn!(watcher)
                }
                _ => Err(io::Error::new(io::ErrorKind::Other, x)),
            },
        }
    }

    /// Returns a clone of the channel used to send requests to the watcher task
    pub fn clone_channel(&self) -> WatcherChannel {
        self.channel.clone()
    }

    /// Aborts the watcher task
    pub fn abort(&self) {
        self.task.abort();
    }
}
|
||||
|
||||
impl Deref for WatcherState {
|
||||
type Target = WatcherChannel;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.channel
|
||||
}
|
||||
}
|
||||
|
||||
/// Cloneable handle for sending requests to the background watcher task
#[derive(Clone)]
pub struct WatcherChannel {
    /// Sender half feeding the internal watcher task
    tx: mpsc::Sender<InnerWatcherMsg>,
}
|
||||
|
||||
impl Default for WatcherChannel {
|
||||
/// Creates a new channel that is closed by default
|
||||
fn default() -> Self {
|
||||
let (tx, _) = mpsc::channel(1);
|
||||
Self { tx }
|
||||
}
|
||||
}
|
||||
|
||||
impl WatcherChannel {
|
||||
/// Watch a path for a specific connection denoted by the id within the registered path
|
||||
pub async fn watch(&self, registered_path: RegisteredPath) -> io::Result<()> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(InnerWatcherMsg::Watch {
|
||||
registered_path,
|
||||
cb,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal watcher task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to watch dropped"))?
|
||||
}
|
||||
|
||||
/// Unwatch a path for a specific connection denoted by the id
|
||||
pub async fn unwatch(&self, id: ConnectionId, path: impl AsRef<Path>) -> io::Result<()> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
let path = tokio::fs::canonicalize(path.as_ref())
|
||||
.await
|
||||
.unwrap_or_else(|_| path.as_ref().to_path_buf());
|
||||
self.tx
|
||||
.send(InnerWatcherMsg::Unwatch { id, path, cb })
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal watcher task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to unwatch dropped"))?
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal message to pass to our task below to perform some action
enum InnerWatcherMsg {
    /// Register a path to watch; outcome reported through `cb`
    Watch {
        registered_path: RegisteredPath,
        cb: oneshot::Sender<io::Result<()>>,
    },
    /// Remove a watched path for the given connection; outcome reported through `cb`
    Unwatch {
        id: ConnectionId,
        path: PathBuf,
        cb: oneshot::Sender<io::Result<()>>,
    },
    /// A filesystem event produced by the underlying watcher
    Event {
        ev: WatcherEvent,
    },
    /// An error produced by the underlying watcher
    Error {
        err: WatcherError,
    },
}
|
||||
|
||||
async fn watcher_task(mut watcher: impl Watcher, mut rx: mpsc::Receiver<InnerWatcherMsg>) {
|
||||
// TODO: Optimize this in some way to be more performant than
|
||||
// checking every path whenever an event comes in
|
||||
let mut registered_paths: Vec<RegisteredPath> = Vec::new();
|
||||
let mut path_cnt: HashMap<PathBuf, usize> = HashMap::new();
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
match msg {
|
||||
InnerWatcherMsg::Watch {
|
||||
registered_path,
|
||||
cb,
|
||||
} => {
|
||||
// Check if we are tracking the path across any connection
|
||||
if let Some(cnt) = path_cnt.get_mut(registered_path.path()) {
|
||||
// Increment the count of times we are watching that path
|
||||
*cnt += 1;
|
||||
|
||||
// Store the registered path in our collection without worry
|
||||
// since we are already watching a path that impacts this one
|
||||
registered_paths.push(registered_path);
|
||||
|
||||
// Send an okay because we always succeed in this case
|
||||
let _ = cb.send(Ok(()));
|
||||
} else {
|
||||
let res = watcher
|
||||
.watch(
|
||||
registered_path.path(),
|
||||
if registered_path.is_recursive() {
|
||||
RecursiveMode::Recursive
|
||||
} else {
|
||||
RecursiveMode::NonRecursive
|
||||
},
|
||||
)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x));
|
||||
|
||||
// If we succeeded, store our registered path and set the tracking cnt to 1
|
||||
if res.is_ok() {
|
||||
path_cnt.insert(registered_path.path().to_path_buf(), 1);
|
||||
registered_paths.push(registered_path);
|
||||
}
|
||||
|
||||
// Send the result of the watch, but don't worry if the channel was closed
|
||||
let _ = cb.send(res);
|
||||
}
|
||||
}
|
||||
InnerWatcherMsg::Unwatch { id, path, cb } => {
|
||||
// Check if we are tracking the path across any connection
|
||||
if let Some(cnt) = path_cnt.get(path.as_path()) {
|
||||
// Cycle through and remove all paths that match the given id and path,
|
||||
// capturing how many paths we removed
|
||||
let removed_cnt = {
|
||||
let old_len = registered_paths.len();
|
||||
registered_paths
|
||||
.retain(|p| p.id() != id || (p.path() != path && p.raw_path() != path));
|
||||
let new_len = registered_paths.len();
|
||||
old_len - new_len
|
||||
};
|
||||
|
||||
// 1. If we are now at zero cnt for our path, we want to actually unwatch the
|
||||
// path with our watcher
|
||||
// 2. If we removed nothing from our path list, we want to return an error
|
||||
// 3. Otherwise, we return okay because we succeeded
|
||||
if *cnt <= removed_cnt {
|
||||
let _ = cb.send(
|
||||
watcher
|
||||
.unwatch(&path)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x)),
|
||||
);
|
||||
} else if removed_cnt == 0 {
|
||||
// Send a failure as there was nothing to unwatch for this connection
|
||||
let _ = cb.send(Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("{:?} is not being watched", path),
|
||||
)));
|
||||
} else {
|
||||
// Send a success as we removed some paths
|
||||
let _ = cb.send(Ok(()));
|
||||
}
|
||||
} else {
|
||||
// Send a failure as there was nothing to unwatch
|
||||
let _ = cb.send(Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("{:?} is not being watched", path),
|
||||
)));
|
||||
}
|
||||
}
|
||||
InnerWatcherMsg::Event { ev } => {
|
||||
let kind = ChangeKind::from(ev.kind);
|
||||
|
||||
for registered_path in registered_paths.iter() {
|
||||
match registered_path.filter_and_send(kind, &ev.paths).await {
|
||||
Ok(_) => (),
|
||||
Err(x) => error!(
|
||||
"[Conn {}] Failed to forward changes to paths: {}",
|
||||
registered_path.id(),
|
||||
x
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
InnerWatcherMsg::Error { err } => {
|
||||
let msg = err.to_string();
|
||||
error!("Watcher encountered an error {} for {:?}", msg, err.paths);
|
||||
|
||||
for registered_path in registered_paths.iter() {
|
||||
match registered_path
|
||||
.filter_and_send_error(&msg, &err.paths, !err.paths.is_empty())
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(x) => error!(
|
||||
"[Conn {}] Failed to forward changes to paths: {}",
|
||||
registered_path.id(),
|
||||
x
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,18 +1,24 @@
|
||||
use crate::{DistantMsg, DistantRequestData, DistantResponseData};
|
||||
use distant_net::{Channel, Client};
|
||||
use distant_net::client::Channel;
|
||||
use distant_net::Client;
|
||||
|
||||
use crate::protocol;
|
||||
|
||||
mod ext;
|
||||
mod lsp;
|
||||
mod process;
|
||||
mod searcher;
|
||||
mod watcher;
|
||||
|
||||
/// Represents a [`Client`] that communicates using the distant protocol
|
||||
pub type DistantClient = Client<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>;
|
||||
pub type DistantClient =
|
||||
Client<protocol::Msg<protocol::Request>, protocol::Msg<protocol::Response>>;
|
||||
|
||||
/// Represents a [`Channel`] that communicates using the distant protocol
|
||||
pub type DistantChannel = Channel<DistantMsg<DistantRequestData>, DistantMsg<DistantResponseData>>;
|
||||
pub type DistantChannel =
|
||||
Channel<protocol::Msg<protocol::Request>, protocol::Msg<protocol::Response>>;
|
||||
|
||||
pub use ext::*;
|
||||
pub use lsp::*;
|
||||
pub use process::*;
|
||||
pub use searcher::*;
|
||||
pub use watcher::*;
|
||||
|
@ -0,0 +1,624 @@
|
||||
use std::{fmt, io};
|
||||
|
||||
use distant_net::common::Request;
|
||||
use log::*;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::client::{DistantChannel, DistantChannelExt};
|
||||
use crate::constants::CLIENT_SEARCHER_CAPACITY;
|
||||
use crate::protocol::{self, SearchId, SearchQuery, SearchQueryMatch};
|
||||
|
||||
/// Represents a searcher for files, directories, and symlinks on the filesystem
pub struct Searcher {
    /// Channel to the remote side, used for cancellation
    channel: DistantChannel,
    /// Id assigned to this search by the `SearchStarted` response
    id: SearchId,
    /// The query this searcher was started with
    query: SearchQuery,
    /// Background task that forwards matches until the search concludes
    task: JoinHandle<()>,
    /// Receiver of matches forwarded by the background task
    rx: mpsc::Receiver<SearchQueryMatch>,
}
|
||||
|
||||
impl fmt::Debug for Searcher {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Searcher")
|
||||
.field("id", &self.id)
|
||||
.field("query", &self.query)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Searcher {
    /// Creates a searcher for some query
    ///
    /// Sends the search request, waits for the server's `SearchStarted`
    /// acknowledgement (queueing any matches that arrive before it), then
    /// spawns a background task that forwards subsequent matches until the
    /// search reports done
    pub async fn search(mut channel: DistantChannel, query: SearchQuery) -> io::Result<Self> {
        trace!("Searching using {query:?}",);

        // Submit our run request and get back a mailbox for responses
        let mut mailbox = channel
            .mail(Request::new(protocol::Msg::Single(
                protocol::Request::Search {
                    query: query.clone(),
                },
            )))
            .await?;

        let (tx, rx) = mpsc::channel(CLIENT_SEARCHER_CAPACITY);

        // Wait to get the confirmation of watch as either ok or error
        let mut queue: Vec<SearchQueryMatch> = Vec::new();
        let mut search_id = None;
        while let Some(res) = mailbox.next().await {
            for data in res.payload.into_vec() {
                match data {
                    // If we get results before the started indicator, queue them up
                    protocol::Response::SearchResults { matches, .. } => {
                        queue.extend(matches);
                    }

                    // Once we get the started indicator, mark as ready to go
                    protocol::Response::SearchStarted { id } => {
                        trace!("[Query {id}] Searcher has started");
                        search_id = Some(id);
                    }

                    // If we get an explicit error, convert and return it
                    protocol::Response::Error(x) => return Err(io::Error::from(x)),

                    // Otherwise, we got something unexpected, and report as such
                    x => {
                        return Err(io::Error::new(
                            io::ErrorKind::Other,
                            format!("Unexpected response: {x:?}"),
                        ))
                    }
                }
            }

            // Exit if we got the confirmation
            // NOTE: Doing this later because we want to make sure the entire payload is processed
            //       first before exiting the loop
            if search_id.is_some() {
                break;
            }
        }

        let search_id = match search_id {
            // Send out any of our queued changes that we got prior to the acknowledgement
            Some(id) => {
                trace!("[Query {id}] Forwarding {} queued matches", queue.len());
                for r#match in queue.drain(..) {
                    if tx.send(r#match).await.is_err() {
                        return Err(io::Error::new(
                            io::ErrorKind::Other,
                            format!("[Query {id}] Queue search match dropped"),
                        ));
                    }
                }
                id
            }

            // If we never received an acknowledgement of search before the mailbox closed,
            // fail with a missing confirmation error
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "Search query missing started confirmation",
                ))
            }
        };

        // Spawn a task that continues to look for search result events and the conclusion of the
        // search, discarding anything else that it gets
        let task = tokio::spawn({
            async move {
                while let Some(res) = mailbox.next().await {
                    let mut done = false;

                    for data in res.payload.into_vec() {
                        match data {
                            protocol::Response::SearchResults { matches, .. } => {
                                // If we can't queue up a match anymore, we've
                                // been closed and therefore want to quit
                                if tx.is_closed() {
                                    break;
                                }

                                // Otherwise, send over the matches
                                for r#match in matches {
                                    if let Err(x) = tx.send(r#match).await {
                                        error!(
                                            "[Query {search_id}] Searcher failed to send match {:?}",
                                            x.0
                                        );
                                        break;
                                    }
                                }
                            }

                            // Received completion indicator, so close out
                            protocol::Response::SearchDone { .. } => {
                                trace!("[Query {search_id}] Searcher has finished");
                                done = true;
                                break;
                            }

                            _ => continue,
                        }
                    }

                    if done {
                        break;
                    }
                }
            }
        });

        Ok(Self {
            id: search_id,
            query,
            channel,
            task,
            rx,
        })
    }

    /// Returns a reference to the query this searcher is running
    pub fn query(&self) -> &SearchQuery {
        &self.query
    }

    /// Returns true if the searcher is still actively searching
    pub fn is_active(&self) -> bool {
        !self.task.is_finished()
    }

    /// Returns the next match detected by the searcher, or none if the searcher has concluded
    pub async fn next(&mut self) -> Option<SearchQueryMatch> {
        self.rx.recv().await
    }

    /// Cancels the search being performed by the searcher
    pub async fn cancel(&mut self) -> io::Result<()> {
        trace!("[Query {}] Cancelling search", self.id);
        self.channel.cancel_search(self.id).await?;

        // Kill our task that processes inbound matches if we have successfully stopped searching
        self.task.abort();

        Ok(())
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use distant_net::common::{FramedTransport, InmemoryTransport, Response};
|
||||
use distant_net::Client;
|
||||
use test_log::test;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::*;
|
||||
use crate::protocol::{
|
||||
SearchQueryCondition, SearchQueryMatchData, SearchQueryOptions, SearchQueryPathMatch,
|
||||
SearchQuerySubmatch, SearchQueryTarget,
|
||||
};
|
||||
use crate::DistantClient;
|
||||
|
||||
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
|
||||
let (t1, t2) = FramedTransport::pair(100);
|
||||
(t1, Client::spawn_inmemory(t2, Default::default()))
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn searcher_should_have_query_reflect_ongoing_query() {
|
||||
let (mut transport, session) = make_session();
|
||||
let test_query = SearchQuery {
|
||||
paths: vec![PathBuf::from("/some/test/path")],
|
||||
target: SearchQueryTarget::Path,
|
||||
condition: SearchQueryCondition::Regex {
|
||||
value: String::from("."),
|
||||
},
|
||||
options: SearchQueryOptions::default(),
|
||||
};
|
||||
|
||||
// Create a task for searcher as we need to handle the request and a response
|
||||
// in a separate async block
|
||||
let search_task = {
|
||||
let test_query = test_query.clone();
|
||||
tokio::spawn(async move { Searcher::search(session.clone_channel(), test_query).await })
|
||||
};
|
||||
|
||||
// Wait until we get the request from the session
|
||||
let req: Request<protocol::Request> = transport.read_frame_as().await.unwrap().unwrap();
|
||||
|
||||
// Send back an acknowledgement that a search was started
|
||||
transport
|
||||
.write_frame_for(&Response::new(
|
||||
req.id,
|
||||
protocol::Response::SearchStarted { id: rand::random() },
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get the searcher and verify the query
|
||||
let searcher = search_task.await.unwrap().unwrap();
|
||||
assert_eq!(searcher.query(), &test_query);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn searcher_should_support_getting_next_match() {
|
||||
let (mut transport, session) = make_session();
|
||||
let test_query = SearchQuery {
|
||||
paths: vec![PathBuf::from("/some/test/path")],
|
||||
target: SearchQueryTarget::Path,
|
||||
condition: SearchQueryCondition::Regex {
|
||||
value: String::from("."),
|
||||
},
|
||||
options: SearchQueryOptions::default(),
|
||||
};
|
||||
|
||||
// Create a task for searcher as we need to handle the request and a response
|
||||
// in a separate async block
|
||||
let search_task =
|
||||
tokio::spawn(
|
||||
async move { Searcher::search(session.clone_channel(), test_query).await },
|
||||
);
|
||||
|
||||
// Wait until we get the request from the session
|
||||
let req: Request<protocol::Request> = transport.read_frame_as().await.unwrap().unwrap();
|
||||
|
||||
// Send back an acknowledgement that a searcher was created
|
||||
let id = rand::random::<SearchId>();
|
||||
transport
|
||||
.write_frame_for(&Response::new(
|
||||
req.id.clone(),
|
||||
protocol::Response::SearchStarted { id },
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get the searcher
|
||||
let mut searcher = search_task.await.unwrap().unwrap();
|
||||
|
||||
// Send some matches related to the file
|
||||
transport
|
||||
.write_frame_for(&Response::new(
|
||||
req.id,
|
||||
vec![
|
||||
protocol::Response::SearchResults {
|
||||
id,
|
||||
matches: vec![
|
||||
SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/1"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match".to_string()),
|
||||
start: 3,
|
||||
end: 7,
|
||||
}],
|
||||
}),
|
||||
SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/2"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match 2".to_string()),
|
||||
start: 88,
|
||||
end: 99,
|
||||
}],
|
||||
}),
|
||||
],
|
||||
},
|
||||
protocol::Response::SearchResults {
|
||||
id,
|
||||
matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/3"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match 3".to_string()),
|
||||
start: 5,
|
||||
end: 9,
|
||||
}],
|
||||
})],
|
||||
},
|
||||
],
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the searcher gets the matches, one at a time
|
||||
let m = searcher.next().await.expect("Searcher closed unexpectedly");
|
||||
assert_eq!(
|
||||
m,
|
||||
SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/1"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match".to_string()),
|
||||
start: 3,
|
||||
end: 7,
|
||||
}],
|
||||
})
|
||||
);
|
||||
|
||||
let m = searcher.next().await.expect("Searcher closed unexpectedly");
|
||||
assert_eq!(
|
||||
m,
|
||||
SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/2"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match 2".to_string()),
|
||||
start: 88,
|
||||
end: 99,
|
||||
}],
|
||||
}),
|
||||
);
|
||||
|
||||
let m = searcher.next().await.expect("Searcher closed unexpectedly");
|
||||
assert_eq!(
|
||||
m,
|
||||
SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/3"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match 3".to_string()),
|
||||
start: 5,
|
||||
end: 9,
|
||||
}],
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn searcher_should_distinguish_match_events_and_only_receive_matches_for_itself() {
|
||||
let (mut transport, session) = make_session();
|
||||
|
||||
let test_query = SearchQuery {
|
||||
paths: vec![PathBuf::from("/some/test/path")],
|
||||
target: SearchQueryTarget::Path,
|
||||
condition: SearchQueryCondition::Regex {
|
||||
value: String::from("."),
|
||||
},
|
||||
options: SearchQueryOptions::default(),
|
||||
};
|
||||
|
||||
// Create a task for searcher as we need to handle the request and a response
|
||||
// in a separate async block
|
||||
let search_task =
|
||||
tokio::spawn(
|
||||
async move { Searcher::search(session.clone_channel(), test_query).await },
|
||||
);
|
||||
|
||||
// Wait until we get the request from the session
|
||||
let req: Request<protocol::Request> = transport.read_frame_as().await.unwrap().unwrap();
|
||||
|
||||
// Send back an acknowledgement that a searcher was created
|
||||
let id = rand::random();
|
||||
transport
|
||||
.write_frame_for(&Response::new(
|
||||
req.id.clone(),
|
||||
protocol::Response::SearchStarted { id },
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get the searcher
|
||||
let mut searcher = search_task.await.unwrap().unwrap();
|
||||
|
||||
// Send a match from the appropriate origin
|
||||
transport
|
||||
.write_frame_for(&Response::new(
|
||||
req.id.clone(),
|
||||
protocol::Response::SearchResults {
|
||||
id,
|
||||
matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/1"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match".to_string()),
|
||||
start: 3,
|
||||
end: 7,
|
||||
}],
|
||||
})],
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send a chanmatchge from a different origin
|
||||
transport
|
||||
.write_frame_for(&Response::new(
|
||||
req.id.clone() + "1",
|
||||
protocol::Response::SearchResults {
|
||||
id,
|
||||
matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/2"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match 2".to_string()),
|
||||
start: 88,
|
||||
end: 99,
|
||||
}],
|
||||
})],
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send a match from the appropriate origin
|
||||
transport
|
||||
.write_frame_for(&Response::new(
|
||||
req.id,
|
||||
protocol::Response::SearchResults {
|
||||
id,
|
||||
matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/3"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match 3".to_string()),
|
||||
start: 5,
|
||||
end: 9,
|
||||
}],
|
||||
})],
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the searcher gets the matches, one at a time
|
||||
let m = searcher.next().await.expect("Searcher closed unexpectedly");
|
||||
assert_eq!(
|
||||
m,
|
||||
SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/1"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match".to_string()),
|
||||
start: 3,
|
||||
end: 7,
|
||||
}],
|
||||
})
|
||||
);
|
||||
|
||||
let m = searcher.next().await.expect("Watcher closed unexpectedly");
|
||||
assert_eq!(
|
||||
m,
|
||||
SearchQueryMatch::Path(SearchQueryPathMatch {
|
||||
path: PathBuf::from("/some/path/3"),
|
||||
submatches: vec![SearchQuerySubmatch {
|
||||
r#match: SearchQueryMatchData::Text("test match 3".to_string()),
|
||||
start: 5,
|
||||
end: 9,
|
||||
}],
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
async fn searcher_should_stop_receiving_events_if_cancelled() {
    // Fake transport/session pair: this test plays the server side by hand
    // over `transport` while the searcher under test talks through `session`.
    let (mut transport, session) = make_session();

    let test_query = SearchQuery {
        paths: vec![PathBuf::from("/some/test/path")],
        target: SearchQueryTarget::Path,
        condition: SearchQueryCondition::Regex {
            value: String::from("."),
        },
        options: SearchQueryOptions::default(),
    };

    // Create a task for searcher as we need to handle the request and a response
    // in a separate async block
    let search_task =
        tokio::spawn(
            async move { Searcher::search(session.clone_channel(), test_query).await },
        );

    // Wait until we get the request from the session
    let req: Request<protocol::Request> = transport.read_frame_as().await.unwrap().unwrap();

    // Send back an acknowledgement that a searcher was created
    let id = rand::random::<SearchId>();
    transport
        .write_frame_for(&Response::new(
            req.id.clone(),
            protocol::Response::SearchStarted { id },
        ))
        .await
        .unwrap();

    // Send some matches from the appropriate origin (same request id and search id)
    transport
        .write_frame_for(&Response::new(
            req.id,
            protocol::Response::SearchResults {
                id,
                matches: vec![
                    SearchQueryMatch::Path(SearchQueryPathMatch {
                        path: PathBuf::from("/some/path/1"),
                        submatches: vec![SearchQuerySubmatch {
                            r#match: SearchQueryMatchData::Text("test match".to_string()),
                            start: 3,
                            end: 7,
                        }],
                    }),
                    SearchQueryMatch::Path(SearchQueryPathMatch {
                        path: PathBuf::from("/some/path/2"),
                        submatches: vec![SearchQuerySubmatch {
                            r#match: SearchQueryMatchData::Text("test match 2".to_string()),
                            start: 88,
                            end: 99,
                        }],
                    }),
                ],
            },
        ))
        .await
        .unwrap();

    // Wait a little bit for all matches to be queued
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;

    // Share the searcher behind Arc<Mutex<..>> so both this test body and the
    // cancel task spawned below can drive the same instance
    let searcher = Arc::new(Mutex::new(search_task.await.unwrap().unwrap()));

    // Verify that the searcher gets the first match
    let m = searcher
        .lock()
        .await
        .next()
        .await
        .expect("Searcher closed unexpectedly");
    assert_eq!(
        m,
        SearchQueryMatch::Path(SearchQueryPathMatch {
            path: PathBuf::from("/some/path/1"),
            submatches: vec![SearchQuerySubmatch {
                r#match: SearchQueryMatchData::Text("test match".to_string()),
                start: 3,
                end: 7,
            }],
        }),
    );

    // Cancel the search, verify the request is sent out, and respond with ok
    let searcher_2 = Arc::clone(&searcher);
    let cancel_task = tokio::spawn(async move { searcher_2.lock().await.cancel().await });

    let req: Request<protocol::Request> = transport.read_frame_as().await.unwrap().unwrap();

    transport
        .write_frame_for(&Response::new(req.id.clone(), protocol::Response::Ok))
        .await
        .unwrap();

    // Wait for the cancel to complete
    cancel_task.await.unwrap().unwrap();

    // Send a match that will get ignored (arrives after the cancel completed)
    transport
        .write_frame_for(&Response::new(
            req.id,
            protocol::Response::SearchResults {
                id,
                matches: vec![SearchQueryMatch::Path(SearchQueryPathMatch {
                    path: PathBuf::from("/some/path/3"),
                    submatches: vec![SearchQuerySubmatch {
                        r#match: SearchQueryMatchData::Text("test match 3".to_string()),
                        start: 5,
                        end: 9,
                    }],
                })],
            },
        ))
        .await
        .unwrap();

    // Verify that we get any remaining matches that were received before cancel,
    // but nothing new after that
    assert_eq!(
        searcher.lock().await.next().await,
        Some(SearchQueryMatch::Path(SearchQueryPathMatch {
            path: PathBuf::from("/some/path/2"),
            submatches: vec![SearchQuerySubmatch {
                r#match: SearchQueryMatchData::Text("test match 2".to_string()),
                start: 88,
                end: 99,
            }],
        }))
    );
    assert_eq!(searcher.lock().await.next().await, None);
}
|
||||
}
|
@ -1,515 +0,0 @@
|
||||
use derive_more::{From, IsVariant};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{io, path::PathBuf};
|
||||
use strum::AsRefStr;
|
||||
|
||||
#[cfg(feature = "clap")]
|
||||
use strum::VariantNames;
|
||||
|
||||
mod change;
|
||||
pub use change::*;
|
||||
|
||||
mod cmd;
|
||||
pub use cmd::*;
|
||||
|
||||
#[cfg(feature = "clap")]
|
||||
mod clap_impl;
|
||||
|
||||
mod error;
|
||||
pub use error::*;
|
||||
|
||||
mod filesystem;
|
||||
pub use filesystem::*;
|
||||
|
||||
mod map;
|
||||
pub use map::Map;
|
||||
|
||||
mod metadata;
|
||||
pub use metadata::*;
|
||||
|
||||
mod pty;
|
||||
pub use pty::*;
|
||||
|
||||
mod system;
|
||||
pub use system::*;
|
||||
|
||||
mod utils;
|
||||
pub(crate) use utils::*;
|
||||
|
||||
/// Id for a remote process
pub type ProcessId = u32;

/// Mapping of environment variables (name -> value)
pub type Environment = Map;

/// Type alias for a vec of bytes
///
/// NOTE: This only exists to support properly parsing a Vec<u8> from an entire string
/// with clap rather than trying to parse a string as a singular u8
pub type ByteVec = Vec<u8>;
|
||||
|
||||
/// Converts an entire argument string into the raw bytes it is composed of;
/// used by clap so a `ByteVec` argument is parsed as one string rather than
/// as individual `u8` values.
#[cfg(feature = "clap")]
fn parse_byte_vec(src: &str) -> ByteVec {
    Vec::from(src.as_bytes())
}
|
||||
|
||||
/// Represents a wrapper around a distant message, supporting single and batch requests
#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(untagged)]
pub enum DistantMsg<T> {
    /// A single message payload
    Single(T),

    /// A collection of message payloads to be handled together
    Batch(Vec<T>),
}
|
||||
|
||||
impl<T> DistantMsg<T> {
|
||||
/// Returns true if msg has a single payload
|
||||
pub fn is_single(&self) -> bool {
|
||||
matches!(self, Self::Single(_))
|
||||
}
|
||||
|
||||
/// Returns reference to single value if msg is single variant
|
||||
pub fn as_single(&self) -> Option<&T> {
|
||||
match self {
|
||||
Self::Single(x) => Some(x),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns mutable reference to single value if msg is single variant
|
||||
pub fn as_mut_single(&mut self) -> Option<&T> {
|
||||
match self {
|
||||
Self::Single(x) => Some(x),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the single value if msg is single variant
|
||||
pub fn into_single(self) -> Option<T> {
|
||||
match self {
|
||||
Self::Single(x) => Some(x),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if msg has a batch of payloads
|
||||
pub fn is_batch(&self) -> bool {
|
||||
matches!(self, Self::Batch(_))
|
||||
}
|
||||
|
||||
/// Returns reference to batch value if msg is batch variant
|
||||
pub fn as_batch(&self) -> Option<&[T]> {
|
||||
match self {
|
||||
Self::Batch(x) => Some(x),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns mutable reference to batch value if msg is batch variant
|
||||
pub fn as_mut_batch(&mut self) -> Option<&mut [T]> {
|
||||
match self {
|
||||
Self::Batch(x) => Some(x),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the batch value if msg is batch variant
|
||||
pub fn into_batch(self) -> Option<Vec<T>> {
|
||||
match self {
|
||||
Self::Batch(x) => Some(x),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert into a collection of payload data
|
||||
pub fn into_vec(self) -> Vec<T> {
|
||||
match self {
|
||||
Self::Single(x) => vec![x],
|
||||
Self::Batch(x) => x,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
impl<T: schemars::JsonSchema> DistantMsg<T> {
    /// Produces the root JSON schema describing a `DistantMsg<T>`
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(DistantMsg<T>)
    }
}
|
||||
|
||||
/// Represents the payload of a request to be performed on the remote machine
#[derive(Clone, Debug, PartialEq, Eq, IsVariant, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "clap", derive(clap::Subcommand))]
#[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")]
#[cfg_attr(feature = "clap", clap(rename_all = "kebab-case"))]
pub enum DistantRequestData {
    /// Reads a file from the specified path on the remote machine
    #[cfg_attr(feature = "clap", clap(visible_aliases = &["cat"]))]
    FileRead {
        /// The path to the file on the remote machine
        path: PathBuf,
    },

    /// Reads a file from the specified path on the remote machine
    /// and treats the contents as text
    FileReadText {
        /// The path to the file on the remote machine
        path: PathBuf,
    },

    /// Writes a file, creating it if it does not exist, and overwriting any existing content
    /// on the remote machine
    FileWrite {
        /// The path to the file on the remote machine
        path: PathBuf,

        /// Data for server-side writing of content
        #[cfg_attr(feature = "clap", clap(parse(from_str = parse_byte_vec)))]
        data: ByteVec,
    },

    /// Writes a file using text instead of bytes, creating it if it does not exist,
    /// and overwriting any existing content on the remote machine
    FileWriteText {
        /// The path to the file on the remote machine
        path: PathBuf,

        /// Data for server-side writing of content
        text: String,
    },

    /// Appends to a file, creating it if it does not exist, on the remote machine
    FileAppend {
        /// The path to the file on the remote machine
        path: PathBuf,

        /// Data for server-side writing of content
        #[cfg_attr(feature = "clap", clap(parse(from_str = parse_byte_vec)))]
        data: ByteVec,
    },

    /// Appends text to a file, creating it if it does not exist, on the remote machine
    FileAppendText {
        /// The path to the file on the remote machine
        path: PathBuf,

        /// Data for server-side writing of content
        text: String,
    },

    /// Reads a directory from the specified path on the remote machine
    #[cfg_attr(feature = "clap", clap(visible_aliases = &["ls"]))]
    DirRead {
        /// The path to the directory on the remote machine
        path: PathBuf,

        /// Maximum depth to traverse with 0 indicating there is no maximum
        /// depth and 1 indicating the most immediate children within the
        /// directory
        #[serde(default = "one")]
        #[cfg_attr(feature = "clap", clap(long, default_value = "1"))]
        depth: usize,

        /// Whether or not to return absolute or relative paths
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        absolute: bool,

        /// Whether or not to canonicalize the resulting paths, meaning
        /// returning the canonical, absolute form of a path with all
        /// intermediate components normalized and symbolic links resolved
        ///
        /// Note that the flag absolute must be true to have absolute paths
        /// returned, even if canonicalize is flagged as true
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        canonicalize: bool,

        /// Whether or not to include the root directory in the retrieved
        /// entries
        ///
        /// If included, the root directory will also be a canonicalized,
        /// absolute path and will not follow any of the other flags
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        include_root: bool,
    },

    /// Creates a directory on the remote machine
    #[cfg_attr(feature = "clap", clap(visible_aliases = &["mkdir"]))]
    DirCreate {
        /// The path to the directory on the remote machine
        path: PathBuf,

        /// Whether or not to create all parent directories
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        all: bool,
    },

    /// Removes a file or directory on the remote machine
    #[cfg_attr(feature = "clap", clap(visible_aliases = &["rm"]))]
    Remove {
        /// The path to the file or directory on the remote machine
        path: PathBuf,

        /// Whether or not to remove all contents within directory if is a directory.
        /// Does nothing different for files
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        force: bool,
    },

    /// Copies a file or directory on the remote machine
    #[cfg_attr(feature = "clap", clap(visible_aliases = &["cp"]))]
    Copy {
        /// The path to the file or directory on the remote machine
        src: PathBuf,

        /// New location on the remote machine for copy of file or directory
        dst: PathBuf,
    },

    /// Moves/renames a file or directory on the remote machine
    #[cfg_attr(feature = "clap", clap(visible_aliases = &["mv"]))]
    Rename {
        /// The path to the file or directory on the remote machine
        src: PathBuf,

        /// New location on the remote machine for the file or directory
        dst: PathBuf,
    },

    /// Watches a path for changes
    Watch {
        /// The path to the file, directory, or symlink on the remote machine
        path: PathBuf,

        /// If true, will recursively watch for changes within directories, otherwise
        /// will only watch for changes immediately within directories
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        recursive: bool,

        /// Filter to only report back specified changes
        #[serde(default)]
        #[cfg_attr(
            feature = "clap",
            clap(long, possible_values = ChangeKind::VARIANTS)
        )]
        only: Vec<ChangeKind>,

        /// Filter to report back changes except these specified changes
        #[serde(default)]
        #[cfg_attr(
            feature = "clap",
            clap(long, possible_values = ChangeKind::VARIANTS)
        )]
        except: Vec<ChangeKind>,
    },

    /// Unwatches a path for changes, meaning no additional changes will be reported
    Unwatch {
        /// The path to the file, directory, or symlink on the remote machine
        path: PathBuf,
    },

    /// Checks whether the given path exists
    Exists {
        /// The path to the file or directory on the remote machine
        path: PathBuf,
    },

    /// Retrieves filesystem metadata for the specified path on the remote machine
    Metadata {
        /// The path to the file, directory, or symlink on the remote machine
        path: PathBuf,

        /// Whether or not to include a canonicalized version of the path, meaning
        /// returning the canonical, absolute form of a path with all
        /// intermediate components normalized and symbolic links resolved
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        canonicalize: bool,

        /// Whether or not to follow symlinks to determine absolute file type (dir/file)
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        resolve_file_type: bool,
    },

    /// Spawns a new process on the remote machine
    #[cfg_attr(feature = "clap", clap(visible_aliases = &["spawn", "run"]))]
    ProcSpawn {
        /// The full command to run including arguments
        #[cfg_attr(feature = "clap", clap(flatten))]
        cmd: Cmd,

        /// Environment to provide to the remote process
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long, default_value_t = Environment::default()))]
        environment: Environment,

        /// Alternative current directory for the remote process
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        current_dir: Option<PathBuf>,

        /// Whether or not the process should be persistent, meaning that the process will not be
        /// killed when the associated client disconnects
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        persist: bool,

        /// If provided, will spawn process in a pty, otherwise spawns directly
        #[serde(default)]
        #[cfg_attr(feature = "clap", clap(long))]
        pty: Option<PtySize>,
    },

    /// Kills a process running on the remote machine
    #[cfg_attr(feature = "clap", clap(visible_aliases = &["kill"]))]
    ProcKill {
        /// Id of the actively-running process
        id: ProcessId,
    },

    /// Sends additional data to stdin of running process
    ProcStdin {
        /// Id of the actively-running process to send stdin data
        id: ProcessId,

        /// Data to send to a process's stdin pipe
        #[serde(with = "serde_bytes")]
        #[cfg_attr(feature = "schemars", schemars(with = "Vec<u8>"))]
        data: Vec<u8>,
    },

    /// Resize pty of remote process
    ProcResizePty {
        /// Id of the actively-running process whose pty to resize
        id: ProcessId,

        /// The new pty dimensions
        size: PtySize,
    },

    /// Retrieve information about the server and the system it is on
    SystemInfo {},
}
|
||||
|
||||
#[cfg(feature = "schemars")]
impl DistantRequestData {
    /// Produces the root JSON schema describing a `DistantRequestData`
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(DistantRequestData)
    }
}
|
||||
|
||||
/// Represents the payload of a successful response
#[derive(Clone, Debug, PartialEq, Eq, AsRefStr, IsVariant, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")]
#[strum(serialize_all = "snake_case")]
pub enum DistantResponseData {
    /// General okay with no extra data, returned in cases like
    /// creating or removing a directory, copying a file, or renaming
    /// a file
    Ok,

    /// General-purpose failure that occurred from some request
    Error(Error),

    /// Response containing some arbitrary, binary data
    Blob {
        /// Binary data associated with the response
        #[serde(with = "serde_bytes")]
        #[cfg_attr(feature = "schemars", schemars(with = "Vec<u8>"))]
        data: Vec<u8>,
    },

    /// Response containing some arbitrary, text data
    Text {
        /// Text data associated with the response
        data: String,
    },

    /// Response to reading a directory
    DirEntries {
        /// Entries contained within the requested directory
        entries: Vec<DirEntry>,

        /// Errors encountered while scanning for entries
        errors: Vec<Error>,
    },

    /// Response to a filesystem change for some watched file, directory, or symlink
    Changed(Change),

    /// Response to checking if a path exists
    Exists { value: bool },

    /// Represents metadata about some filesystem object (file, directory, symlink) on remote machine
    Metadata(Metadata),

    /// Response to starting a new process
    ProcSpawned {
        /// Arbitrary id associated with running process
        id: ProcessId,
    },

    /// Actively-transmitted stdout as part of running process
    ProcStdout {
        /// Arbitrary id associated with running process
        id: ProcessId,

        /// Data read from a process' stdout pipe
        #[serde(with = "serde_bytes")]
        #[cfg_attr(feature = "schemars", schemars(with = "Vec<u8>"))]
        data: Vec<u8>,
    },

    /// Actively-transmitted stderr as part of running process
    ProcStderr {
        /// Arbitrary id associated with running process
        id: ProcessId,

        /// Data read from a process' stderr pipe
        #[serde(with = "serde_bytes")]
        #[cfg_attr(feature = "schemars", schemars(with = "Vec<u8>"))]
        data: Vec<u8>,
    },

    /// Response to a process finishing
    ProcDone {
        /// Arbitrary id associated with running process
        id: ProcessId,

        /// Whether or not termination was successful
        success: bool,

        /// Exit code associated with termination, will be missing if terminated by signal
        code: Option<i32>,
    },

    /// Response to retrieving information about the server and the system it is on
    SystemInfo(SystemInfo),
}
|
||||
|
||||
#[cfg(feature = "schemars")]
impl DistantResponseData {
    /// Produces the root JSON schema describing a `DistantResponseData`
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(DistantResponseData)
    }
}
|
||||
|
||||
impl From<io::Error> for DistantResponseData {
|
||||
fn from(x: io::Error) -> Self {
|
||||
Self::Error(Error::from(x))
|
||||
}
|
||||
}
|
||||
|
||||
/// Used to provide a default serde value of 1 (e.g. the default `depth`
/// for `DirRead` when the field is absent during deserialization)
const fn one() -> usize {
    1usize
}
|
@ -1,506 +0,0 @@
|
||||
use derive_more::{Deref, DerefMut, IntoIterator};
|
||||
use notify::{event::Event as NotifyEvent, EventKind as NotifyEventKind};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
fmt,
|
||||
hash::{Hash, Hasher},
|
||||
iter::FromIterator,
|
||||
ops::{BitOr, Sub},
|
||||
path::PathBuf,
|
||||
str::FromStr,
|
||||
};
|
||||
use strum::{EnumString, EnumVariantNames};
|
||||
|
||||
/// Change to one or more paths on the filesystem
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Change {
    /// Label describing the kind of change
    pub kind: ChangeKind,

    /// Paths that were changed
    ///
    /// NOTE(review): for `RenameBoth` the variant docs state the paths are
    /// (from, to) in that order — confirm callers rely on this ordering
    pub paths: Vec<PathBuf>,
}
|
||||
|
||||
#[cfg(feature = "schemars")]
impl Change {
    /// Produces the root JSON schema describing a [`Change`]
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(Change)
    }
}
|
||||
|
||||
impl From<NotifyEvent> for Change {
|
||||
fn from(x: NotifyEvent) -> Self {
|
||||
Self {
|
||||
kind: x.kind.into(),
|
||||
paths: x.paths,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
    Copy,
    Clone,
    Debug,
    strum::Display,
    EnumString,
    EnumVariantNames,
    Hash,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Serialize,
    Deserialize,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
#[strum(serialize_all = "snake_case")]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[cfg_attr(feature = "clap", clap(rename_all = "snake_case"))]
pub enum ChangeKind {
    /// Something about a file or directory was accessed, but
    /// no specific details were known
    Access,

    /// A file was closed for executing
    AccessCloseExecute,

    /// A file was closed for reading
    AccessCloseRead,

    /// A file was closed for writing
    AccessCloseWrite,

    /// A file was opened for executing
    AccessOpenExecute,

    /// A file was opened for reading
    AccessOpenRead,

    /// A file was opened for writing
    AccessOpenWrite,

    /// A file or directory was read
    AccessRead,

    /// The access time of a file or directory was changed
    AccessTime,

    /// A file, directory, or something else was created
    Create,

    /// The content of a file or directory changed
    Content,

    /// The data of a file or directory was modified, but
    /// no specific details were known
    Data,

    /// The metadata of a file or directory was modified, but
    /// no specific details were known
    Metadata,

    /// Something about a file or directory was modified, but
    /// no specific details were known
    Modify,

    /// A file, directory, or something else was removed
    Remove,

    /// A file or directory was renamed, but no specific details were known
    Rename,

    /// A file or directory was renamed, and the provided paths
    /// are the source and target in that order (from, to)
    RenameBoth,

    /// A file or directory was renamed, and the provided path
    /// is the origin of the rename (before being renamed)
    RenameFrom,

    /// A file or directory was renamed, and the provided path
    /// is the result of the rename
    RenameTo,

    /// A file's size changed
    Size,

    /// The ownership of a file or directory was changed
    Ownership,

    /// The permissions of a file or directory was changed
    Permissions,

    /// The write or modify time of a file or directory was changed
    WriteTime,

    /// Catch-all in case we have no insight as to the type of change
    Unknown,
}
|
||||
|
||||
impl ChangeKind {
|
||||
/// Returns true if the change is a kind of access
|
||||
pub fn is_access_kind(&self) -> bool {
|
||||
self.is_open_access_kind()
|
||||
|| self.is_close_access_kind()
|
||||
|| matches!(self, Self::Access | Self::AccessRead)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of open access
|
||||
pub fn is_open_access_kind(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::AccessOpenExecute | Self::AccessOpenRead | Self::AccessOpenWrite
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of close access
|
||||
pub fn is_close_access_kind(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::AccessCloseExecute | Self::AccessCloseRead | Self::AccessCloseWrite
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of creation
|
||||
pub fn is_create_kind(&self) -> bool {
|
||||
matches!(self, Self::Create)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of modification
|
||||
pub fn is_modify_kind(&self) -> bool {
|
||||
self.is_data_modify_kind() || self.is_metadata_modify_kind() || matches!(self, Self::Modify)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of data modification
|
||||
pub fn is_data_modify_kind(&self) -> bool {
|
||||
matches!(self, Self::Content | Self::Data | Self::Size)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of metadata modification
|
||||
pub fn is_metadata_modify_kind(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::AccessTime
|
||||
| Self::Metadata
|
||||
| Self::Ownership
|
||||
| Self::Permissions
|
||||
| Self::WriteTime
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of rename
|
||||
pub fn is_rename_kind(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::Rename | Self::RenameBoth | Self::RenameFrom | Self::RenameTo
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns true if the change is a kind of removal
|
||||
pub fn is_remove_kind(&self) -> bool {
|
||||
matches!(self, Self::Remove)
|
||||
}
|
||||
|
||||
/// Returns true if the change kind is unknown
|
||||
pub fn is_unknown_kind(&self) -> bool {
|
||||
matches!(self, Self::Unknown)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
impl ChangeKind {
    /// Produces the root JSON schema describing a [`ChangeKind`]
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(ChangeKind)
    }
}
|
||||
|
||||
impl BitOr for ChangeKind {
|
||||
type Output = ChangeKindSet;
|
||||
|
||||
fn bitor(self, rhs: Self) -> Self::Output {
|
||||
let mut set = ChangeKindSet::empty();
|
||||
set.insert(self);
|
||||
set.insert(rhs);
|
||||
set
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NotifyEventKind> for ChangeKind {
    /// Maps a `notify` event kind to our flattened [`ChangeKind`], preferring
    /// the most specific variant and falling back to a broader category
    /// (e.g. `Access`, `Modify`, `Unknown`) when details are unavailable
    fn from(x: NotifyEventKind) -> Self {
        use notify::event::{
            AccessKind, AccessMode, DataChange, MetadataKind, ModifyKind, RenameMode,
        };
        match x {
            // File/directory access events
            NotifyEventKind::Access(AccessKind::Read) => Self::AccessRead,
            NotifyEventKind::Access(AccessKind::Open(AccessMode::Execute)) => {
                Self::AccessOpenExecute
            }
            NotifyEventKind::Access(AccessKind::Open(AccessMode::Read)) => Self::AccessOpenRead,
            NotifyEventKind::Access(AccessKind::Open(AccessMode::Write)) => Self::AccessOpenWrite,
            NotifyEventKind::Access(AccessKind::Close(AccessMode::Execute)) => {
                Self::AccessCloseExecute
            }
            NotifyEventKind::Access(AccessKind::Close(AccessMode::Read)) => Self::AccessCloseRead,
            NotifyEventKind::Access(AccessKind::Close(AccessMode::Write)) => Self::AccessCloseWrite,
            NotifyEventKind::Access(_) => Self::Access,

            // File/directory creation events
            NotifyEventKind::Create(_) => Self::Create,

            // Rename-oriented events
            NotifyEventKind::Modify(ModifyKind::Name(RenameMode::Both)) => Self::RenameBoth,
            NotifyEventKind::Modify(ModifyKind::Name(RenameMode::From)) => Self::RenameFrom,
            NotifyEventKind::Modify(ModifyKind::Name(RenameMode::To)) => Self::RenameTo,
            NotifyEventKind::Modify(ModifyKind::Name(_)) => Self::Rename,

            // Data-modification events
            NotifyEventKind::Modify(ModifyKind::Data(DataChange::Content)) => Self::Content,
            NotifyEventKind::Modify(ModifyKind::Data(DataChange::Size)) => Self::Size,
            NotifyEventKind::Modify(ModifyKind::Data(_)) => Self::Data,

            // Metadata-modification events
            NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::AccessTime)) => {
                Self::AccessTime
            }
            NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::WriteTime)) => {
                Self::WriteTime
            }
            NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Permissions)) => {
                Self::Permissions
            }
            NotifyEventKind::Modify(ModifyKind::Metadata(MetadataKind::Ownership)) => {
                Self::Ownership
            }
            NotifyEventKind::Modify(ModifyKind::Metadata(_)) => Self::Metadata,

            // General modification events
            NotifyEventKind::Modify(_) => Self::Modify,

            // File/directory removal events
            NotifyEventKind::Remove(_) => Self::Remove,

            // Catch-all for other events
            NotifyEventKind::Any | NotifyEventKind::Other => Self::Unknown,
        }
    }
}
|
||||
|
||||
/// Represents a distinct set of different change kinds
///
/// Backed by a [`HashSet`], so membership is deduplicated and iteration
/// order is unspecified; the [`fmt::Display`] impl sorts kinds to produce
/// stable output.
#[derive(Clone, Debug, Deref, DerefMut, IntoIterator, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ChangeKindSet(HashSet<ChangeKind>);
|
||||
|
||||
impl ChangeKindSet {
    /// Produces an empty set of [`ChangeKind`]
    pub fn empty() -> Self {
        Self(HashSet::new())
    }

    /// Produces a set of all [`ChangeKind`]
    pub fn all() -> Self {
        vec![
            ChangeKind::Access,
            ChangeKind::AccessCloseExecute,
            ChangeKind::AccessCloseRead,
            ChangeKind::AccessCloseWrite,
            ChangeKind::AccessOpenExecute,
            ChangeKind::AccessOpenRead,
            ChangeKind::AccessOpenWrite,
            ChangeKind::AccessRead,
            ChangeKind::AccessTime,
            ChangeKind::Create,
            ChangeKind::Content,
            ChangeKind::Data,
            ChangeKind::Metadata,
            ChangeKind::Modify,
            ChangeKind::Remove,
            ChangeKind::Rename,
            ChangeKind::RenameBoth,
            ChangeKind::RenameFrom,
            ChangeKind::RenameTo,
            ChangeKind::Size,
            ChangeKind::Ownership,
            ChangeKind::Permissions,
            ChangeKind::WriteTime,
            ChangeKind::Unknown,
        ]
        .into_iter()
        .collect()
    }

    /// Produces a changeset containing all of the access kinds
    pub fn access_set() -> Self {
        Self::access_open_set()
            | Self::access_close_set()
            | ChangeKind::AccessRead
            | ChangeKind::Access
    }

    /// Produces a changeset containing all of the open access kinds
    pub fn access_open_set() -> Self {
        ChangeKind::AccessOpenExecute | ChangeKind::AccessOpenRead | ChangeKind::AccessOpenWrite
    }

    /// Produces a changeset containing all of the close access kinds
    pub fn access_close_set() -> Self {
        ChangeKind::AccessCloseExecute | ChangeKind::AccessCloseRead | ChangeKind::AccessCloseWrite
    }

    /// Produces a changeset containing all of the modification kinds
    // NOTE: was a plain `//` comment before; promoted to a doc comment so it
    // shows up in rustdoc like the sibling methods
    pub fn modify_set() -> Self {
        Self::modify_data_set() | Self::modify_metadata_set() | ChangeKind::Modify
    }

    /// Produces a changeset containing all of the data modification kinds
    pub fn modify_data_set() -> Self {
        ChangeKind::Content | ChangeKind::Data | ChangeKind::Size
    }

    /// Produces a changeset containing all of the metadata modification kinds
    pub fn modify_metadata_set() -> Self {
        ChangeKind::AccessTime
            | ChangeKind::Metadata
            | ChangeKind::Ownership
            | ChangeKind::Permissions
            | ChangeKind::WriteTime
    }

    /// Produces a changeset containing all of the rename kinds
    pub fn rename_set() -> Self {
        ChangeKind::Rename | ChangeKind::RenameBoth | ChangeKind::RenameFrom | ChangeKind::RenameTo
    }

    /// Consumes set and returns a vec of the kinds of changes
    pub fn into_vec(self) -> Vec<ChangeKind> {
        self.0.into_iter().collect()
    }
}
|
||||
|
||||
#[cfg(feature = "schemars")]
impl ChangeKindSet {
    /// Produces the JSON schema describing this type (schemars feature only).
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(ChangeKindSet)
    }
}
|
||||
|
||||
impl fmt::Display for ChangeKindSet {
|
||||
/// Outputs a comma-separated series of [`ChangeKind`] as string that are sorted
|
||||
/// such that this will always be consistent output
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let mut kinds = self
|
||||
.0
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<String>>();
|
||||
kinds.sort_unstable();
|
||||
write!(f, "{}", kinds.join(","))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for ChangeKindSet {
    /// Compares via the sorted [`fmt::Display`] form so that equality does
    /// not depend on the hash set's unspecified iteration order.
    fn eq(&self, other: &Self) -> bool {
        self.to_string() == other.to_string()
    }
}

impl Eq for ChangeKindSet {}

impl Hash for ChangeKindSet {
    /// Hashes based on the output of [`fmt::Display`]
    ///
    /// Kept deliberately in sync with [`PartialEq`], which compares the same
    /// string form, so `a == b` implies equal hashes as `Hash` requires.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.to_string().hash(state);
    }
}
|
||||
|
||||
impl BitOr<ChangeKindSet> for ChangeKindSet {
    type Output = Self;

    /// Unions two sets, reusing the left-hand set's storage.
    fn bitor(self, rhs: ChangeKindSet) -> Self::Output {
        let mut union = self;
        union.0.extend(rhs.0);
        union
    }
}
|
||||
|
||||
impl BitOr<ChangeKind> for ChangeKindSet {
    type Output = Self;

    /// Adds a single kind to the set (no-op if already present).
    fn bitor(self, rhs: ChangeKind) -> Self::Output {
        let mut set = self;
        set.0.insert(rhs);
        set
    }
}
|
||||
|
||||
impl BitOr<ChangeKindSet> for ChangeKind {
    type Output = ChangeKindSet;

    /// Unions a single kind into a set; delegates to `set | kind`.
    fn bitor(self, rhs: ChangeKindSet) -> Self::Output {
        rhs | self
    }
}
|
||||
|
||||
impl Sub<ChangeKindSet> for ChangeKindSet {
    type Output = Self;

    /// Produces the set difference: kinds in `self` that are not in `other`.
    fn sub(self, other: Self) -> Self::Output {
        ChangeKindSet(&self.0 - &other.0)
    }
}

impl Sub<&'_ ChangeKindSet> for &ChangeKindSet {
    type Output = ChangeKindSet;

    /// Borrowed variant of the set difference; allocates a new set.
    fn sub(self, other: &ChangeKindSet) -> Self::Output {
        ChangeKindSet(&self.0 - &other.0)
    }
}
|
||||
|
||||
impl FromStr for ChangeKindSet {
|
||||
type Err = strum::ParseError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut change_set = HashSet::new();
|
||||
|
||||
for word in s.split(',') {
|
||||
change_set.insert(ChangeKind::from_str(word.trim())?);
|
||||
}
|
||||
|
||||
Ok(ChangeKindSet(change_set))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<ChangeKind> for ChangeKindSet {
|
||||
fn from_iter<I: IntoIterator<Item = ChangeKind>>(iter: I) -> Self {
|
||||
let mut change_set = HashSet::new();
|
||||
|
||||
for i in iter {
|
||||
change_set.insert(i);
|
||||
}
|
||||
|
||||
ChangeKindSet(change_set)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ChangeKind> for ChangeKindSet {
    /// Builds a single-element set containing just `change_kind`.
    fn from(change_kind: ChangeKind) -> Self {
        std::iter::once(change_kind).collect()
    }
}
|
||||
|
||||
impl From<Vec<ChangeKind>> for ChangeKindSet {
    /// Converts a vec of kinds into a set, dropping duplicates.
    fn from(changes: Vec<ChangeKind>) -> Self {
        changes.into_iter().collect()
    }
}

impl Default for ChangeKindSet {
    /// Defaults to the empty set.
    fn default() -> Self {
        Self::empty()
    }
}
|
@ -1,106 +0,0 @@
|
||||
use crate::{data::Cmd, DistantMsg, DistantRequestData};
|
||||
use clap::{
|
||||
error::{Error, ErrorKind},
|
||||
Arg, ArgAction, ArgMatches, Args, Command, FromArgMatches, Subcommand,
|
||||
};
|
||||
|
||||
impl FromArgMatches for Cmd {
|
||||
fn from_arg_matches(matches: &ArgMatches) -> Result<Self, Error> {
|
||||
let mut matches = matches.clone();
|
||||
Self::from_arg_matches_mut(&mut matches)
|
||||
}
|
||||
fn from_arg_matches_mut(matches: &mut ArgMatches) -> Result<Self, Error> {
|
||||
let cmd = matches.get_one::<String>("cmd").ok_or_else(|| {
|
||||
Error::raw(
|
||||
ErrorKind::MissingRequiredArgument,
|
||||
"program must be specified",
|
||||
)
|
||||
})?;
|
||||
let args: Vec<String> = matches
|
||||
.get_many::<String>("arg")
|
||||
.unwrap_or_default()
|
||||
.map(ToString::to_string)
|
||||
.collect();
|
||||
Ok(Self::new(format!("{cmd} {}", args.join(" "))))
|
||||
}
|
||||
fn update_from_arg_matches(&mut self, matches: &ArgMatches) -> Result<(), Error> {
|
||||
let mut matches = matches.clone();
|
||||
self.update_from_arg_matches_mut(&mut matches)
|
||||
}
|
||||
fn update_from_arg_matches_mut(&mut self, _matches: &mut ArgMatches) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Args for Cmd {
    /// Defines the CLI shape: a required `CMD` value followed by any number
    /// of trailing `ARGS` captured verbatim (`trailing_var_arg` stops option
    /// parsing after the first positional).
    fn augment_args(cmd: Command<'_>) -> Command<'_> {
        cmd.arg(
            Arg::new("cmd")
                .required(true)
                .value_name("CMD")
                .action(ArgAction::Set),
        )
        .trailing_var_arg(true)
        .arg(
            Arg::new("arg")
                .value_name("ARGS")
                .multiple_values(true)
                .action(ArgAction::Append),
        )
    }

    /// Updates add no arguments; the command is returned unchanged.
    fn augment_args_for_update(cmd: Command<'_>) -> Command<'_> {
        cmd
    }
}
|
||||
|
||||
impl FromArgMatches for DistantMsg<DistantRequestData> {
    /// Builds a message from the `single` subcommand; any other subcommand —
    /// or none at all — is rejected with an error.
    fn from_arg_matches(matches: &ArgMatches) -> Result<Self, Error> {
        match matches.subcommand() {
            Some(("single", args)) => Ok(Self::Single(DistantRequestData::from_arg_matches(args)?)),
            Some((_, _)) => Err(Error::raw(
                ErrorKind::UnrecognizedSubcommand,
                "Valid subcommand is `single`",
            )),
            None => Err(Error::raw(
                ErrorKind::MissingSubcommand,
                "Valid subcommand is `single`",
            )),
        }
    }

    /// Replaces `self` when the `single` subcommand is present.
    ///
    /// Unlike [`Self::from_arg_matches`], a missing subcommand here is a
    /// deliberate no-op (the existing value is kept), not an error.
    fn update_from_arg_matches(&mut self, matches: &ArgMatches) -> Result<(), Error> {
        match matches.subcommand() {
            Some(("single", args)) => {
                *self = Self::Single(DistantRequestData::from_arg_matches(args)?)
            }
            Some((_, _)) => {
                return Err(Error::raw(
                    ErrorKind::UnrecognizedSubcommand,
                    "Valid subcommand is `single`",
                ))
            }
            None => (),
        };
        Ok(())
    }
}
|
||||
|
||||
impl Subcommand for DistantMsg<DistantRequestData> {
|
||||
fn augment_subcommands(cmd: Command<'_>) -> Command<'_> {
|
||||
cmd.subcommand(DistantRequestData::augment_subcommands(Command::new(
|
||||
"single",
|
||||
)))
|
||||
.subcommand_required(true)
|
||||
}
|
||||
|
||||
fn augment_subcommands_for_update(cmd: Command<'_>) -> Command<'_> {
|
||||
cmd.subcommand(DistantRequestData::augment_subcommands(Command::new(
|
||||
"single",
|
||||
)))
|
||||
.subcommand_required(true)
|
||||
}
|
||||
|
||||
fn has_subcommand(name: &str) -> bool {
|
||||
matches!(name, "single")
|
||||
}
|
||||
}
|
@ -1,52 +0,0 @@
|
||||
use derive_more::{Display, From, Into};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
/// Represents some command with arguments to execute
///
/// Stored as a single space-separated string; [`Cmd::program`] and
/// [`Cmd::arguments`] split it on the first space.
#[derive(Clone, Debug, Display, From, Into, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct Cmd(String);
|
||||
|
||||
impl Cmd {
|
||||
/// Creates a new command from the given `cmd`
|
||||
pub fn new(cmd: impl Into<String>) -> Self {
|
||||
Self(cmd.into())
|
||||
}
|
||||
|
||||
/// Returns reference to the program portion of the command
|
||||
pub fn program(&self) -> &str {
|
||||
match self.0.split_once(' ') {
|
||||
Some((program, _)) => program.trim(),
|
||||
None => self.0.trim(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to the arguments portion of the command
|
||||
pub fn arguments(&self) -> &str {
|
||||
match self.0.split_once(' ') {
|
||||
Some((_, arguments)) => arguments.trim(),
|
||||
None => "",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
impl Cmd {
    /// Produces the JSON schema describing this type (schemars feature only).
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(Cmd)
    }
}
|
||||
|
||||
impl Deref for Cmd {
    type Target = String;

    /// Exposes the underlying command string, letting a [`Cmd`] be used
    /// wherever `&String`/`&str` methods are needed.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Cmd {
    /// Grants mutable access to the underlying command string.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
|
@ -1,45 +0,0 @@
|
||||
use derive_more::IsVariant;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use strum::AsRefStr;
|
||||
|
||||
/// Represents information about a single entry within a directory
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct DirEntry {
    /// Represents the full path to the entry
    pub path: PathBuf,

    /// Represents the type of the entry as a file/dir/symlink
    pub file_type: FileType,

    /// Depth at which this entry was created relative to the root (0 being immediately within
    /// root)
    pub depth: usize,
}

#[cfg(feature = "schemars")]
impl DirEntry {
    /// Produces the JSON schema describing this type (schemars feature only).
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(DirEntry)
    }
}
|
||||
|
||||
/// Represents the type associated with a dir entry
///
/// Serialized (via serde and strum) in snake_case: `dir`, `file`, `symlink`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, AsRefStr, IsVariant, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
#[strum(serialize_all = "snake_case")]
pub enum FileType {
    /// A directory
    Dir,
    /// A regular file
    File,
    /// A symbolic link
    Symlink,
}

#[cfg(feature = "schemars")]
impl FileType {
    /// Produces the JSON schema describing this type (schemars feature only).
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(FileType)
    }
}
|
@ -1,404 +0,0 @@
|
||||
use super::{deserialize_u128_option, serialize_u128_option, FileType};
|
||||
use bitflags::bitflags;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
io,
|
||||
path::{Path, PathBuf},
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
/// Represents metadata about some path on a remote machine
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct Metadata {
    /// Canonicalized path to the file or directory, resolving symlinks, only included
    /// if flagged during the request
    pub canonicalized_path: Option<PathBuf>,

    /// Represents the type of the entry as a file/dir/symlink
    pub file_type: FileType,

    /// Size of the file/directory/symlink in bytes
    pub len: u64,

    /// Whether or not the file/directory/symlink is marked as unwriteable
    pub readonly: bool,

    /// Represents the last time (in milliseconds) when the file/directory/symlink was accessed;
    /// can be optional as certain systems don't support this
    // u128 is carried as an optional string on the wire since many serde
    // formats cannot represent it natively
    #[serde(serialize_with = "serialize_u128_option")]
    #[serde(deserialize_with = "deserialize_u128_option")]
    pub accessed: Option<u128>,

    /// Represents when (in milliseconds) the file/directory/symlink was created;
    /// can be optional as certain systems don't support this
    #[serde(serialize_with = "serialize_u128_option")]
    #[serde(deserialize_with = "deserialize_u128_option")]
    pub created: Option<u128>,

    /// Represents the last time (in milliseconds) when the file/directory/symlink was modified;
    /// can be optional as certain systems don't support this
    #[serde(serialize_with = "serialize_u128_option")]
    #[serde(deserialize_with = "deserialize_u128_option")]
    pub modified: Option<u128>,

    /// Represents metadata that is specific to a unix remote machine
    pub unix: Option<UnixMetadata>,

    /// Represents metadata that is specific to a windows remote machine
    pub windows: Option<WindowsMetadata>,
}
|
||||
|
||||
impl Metadata {
    /// Reads metadata about the file/directory/symlink at `path`.
    ///
    /// Uses `symlink_metadata`, so a symlink reports on itself rather than
    /// its target; set `resolve_file_type` to report the target's type
    /// instead. When `canonicalize` is set, the fully-resolved path is also
    /// captured.
    ///
    /// # Errors
    ///
    /// Fails if any of the underlying filesystem queries fail (e.g. the path
    /// does not exist or permissions are insufficient).
    pub async fn read(
        path: impl AsRef<Path>,
        canonicalize: bool,
        resolve_file_type: bool,
    ) -> io::Result<Self> {
        let metadata = tokio::fs::symlink_metadata(path.as_ref()).await?;
        let canonicalized_path = if canonicalize {
            Some(tokio::fs::canonicalize(path.as_ref()).await?)
        } else {
            None
        };

        // If asking for resolved file type and current type is symlink, then we want to refresh
        // our metadata to get the filetype for the resolved link
        let file_type = if resolve_file_type && metadata.file_type().is_symlink() {
            tokio::fs::metadata(path).await?.file_type()
        } else {
            metadata.file_type()
        };

        Ok(Self {
            canonicalized_path,
            // Timestamps are converted to milliseconds since UNIX epoch;
            // None when the platform/filesystem does not supply them
            accessed: metadata
                .accessed()
                .ok()
                .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
                .map(|d| d.as_millis()),
            created: metadata
                .created()
                .ok()
                .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
                .map(|d| d.as_millis()),
            modified: metadata
                .modified()
                .ok()
                .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
                .map(|d| d.as_millis()),
            len: metadata.len(),
            readonly: metadata.permissions().readonly(),
            file_type: if file_type.is_dir() {
                FileType::Dir
            } else if file_type.is_file() {
                FileType::File
            } else {
                FileType::Symlink
            },

            // Platform-specific metadata: at most one of `unix`/`windows` is
            // populated, selected at compile time
            #[cfg(unix)]
            unix: Some({
                use std::os::unix::prelude::*;
                let mode = metadata.mode();
                crate::data::UnixMetadata::from(mode)
            }),
            #[cfg(not(unix))]
            unix: None,

            #[cfg(windows)]
            windows: Some({
                use std::os::windows::prelude::*;
                let attributes = metadata.file_attributes();
                crate::data::WindowsMetadata::from(attributes)
            }),
            #[cfg(not(windows))]
            windows: None,
        })
    }
}
|
||||
|
||||
#[cfg(feature = "schemars")]
impl Metadata {
    /// Produces the JSON schema describing this type (schemars feature only).
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(Metadata)
    }
}
|
||||
|
||||
/// Represents unix-specific metadata about some path on a remote machine
///
/// Mirrors the standard rwx permission bits for owner/group/other; convert
/// to/from a raw mode word via the `From<u32>`/`From<UnixMetadata>` impls.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct UnixMetadata {
    /// Represents whether or not owner can read from the file
    pub owner_read: bool,

    /// Represents whether or not owner can write to the file
    pub owner_write: bool,

    /// Represents whether or not owner can execute the file
    pub owner_exec: bool,

    /// Represents whether or not associated group can read from the file
    pub group_read: bool,

    /// Represents whether or not associated group can write to the file
    pub group_write: bool,

    /// Represents whether or not associated group can execute the file
    pub group_exec: bool,

    /// Represents whether or not other can read from the file
    pub other_read: bool,

    /// Represents whether or not other can write to the file
    pub other_write: bool,

    /// Represents whether or not other can execute the file
    pub other_exec: bool,
}

#[cfg(feature = "schemars")]
impl UnixMetadata {
    /// Produces the JSON schema describing this type (schemars feature only).
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(UnixMetadata)
    }
}
|
||||
|
||||
impl From<u32> for UnixMetadata {
|
||||
/// Create from a unix mode bitset
|
||||
fn from(mode: u32) -> Self {
|
||||
let flags = UnixFilePermissionFlags::from_bits_truncate(mode);
|
||||
Self {
|
||||
owner_read: flags.contains(UnixFilePermissionFlags::OWNER_READ),
|
||||
owner_write: flags.contains(UnixFilePermissionFlags::OWNER_WRITE),
|
||||
owner_exec: flags.contains(UnixFilePermissionFlags::OWNER_EXEC),
|
||||
group_read: flags.contains(UnixFilePermissionFlags::GROUP_READ),
|
||||
group_write: flags.contains(UnixFilePermissionFlags::GROUP_WRITE),
|
||||
group_exec: flags.contains(UnixFilePermissionFlags::GROUP_EXEC),
|
||||
other_read: flags.contains(UnixFilePermissionFlags::OTHER_READ),
|
||||
other_write: flags.contains(UnixFilePermissionFlags::OTHER_WRITE),
|
||||
other_exec: flags.contains(UnixFilePermissionFlags::OTHER_EXEC),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UnixMetadata> for u32 {
|
||||
/// Convert to a unix mode bitset
|
||||
fn from(metadata: UnixMetadata) -> Self {
|
||||
let mut flags = UnixFilePermissionFlags::empty();
|
||||
|
||||
if metadata.owner_read {
|
||||
flags.insert(UnixFilePermissionFlags::OWNER_READ);
|
||||
}
|
||||
if metadata.owner_write {
|
||||
flags.insert(UnixFilePermissionFlags::OWNER_WRITE);
|
||||
}
|
||||
if metadata.owner_exec {
|
||||
flags.insert(UnixFilePermissionFlags::OWNER_EXEC);
|
||||
}
|
||||
|
||||
if metadata.group_read {
|
||||
flags.insert(UnixFilePermissionFlags::GROUP_READ);
|
||||
}
|
||||
if metadata.group_write {
|
||||
flags.insert(UnixFilePermissionFlags::GROUP_WRITE);
|
||||
}
|
||||
if metadata.group_exec {
|
||||
flags.insert(UnixFilePermissionFlags::GROUP_EXEC);
|
||||
}
|
||||
|
||||
if metadata.other_read {
|
||||
flags.insert(UnixFilePermissionFlags::OTHER_READ);
|
||||
}
|
||||
if metadata.other_write {
|
||||
flags.insert(UnixFilePermissionFlags::OTHER_WRITE);
|
||||
}
|
||||
if metadata.other_exec {
|
||||
flags.insert(UnixFilePermissionFlags::OTHER_EXEC);
|
||||
}
|
||||
|
||||
flags.bits
|
||||
}
|
||||
}
|
||||
|
||||
impl UnixMetadata {
|
||||
pub fn is_readonly(self) -> bool {
|
||||
!(self.owner_read || self.group_read || self.other_read)
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
    /// Standard unix permission bits (octal rwx triplets for
    /// owner/group/other), matching the low nine bits of a file mode.
    struct UnixFilePermissionFlags: u32 {
        const OWNER_READ = 0o400;
        const OWNER_WRITE = 0o200;
        const OWNER_EXEC = 0o100;
        const GROUP_READ = 0o40;
        const GROUP_WRITE = 0o20;
        const GROUP_EXEC = 0o10;
        const OTHER_READ = 0o4;
        const OTHER_WRITE = 0o2;
        const OTHER_EXEC = 0o1;
    }
}
|
||||
|
||||
/// Represents windows-specific metadata about some path on a remote machine
///
/// Each field corresponds to one file attribute flag; convert to/from the
/// raw attribute bitset via the `From<u32>`/`From<WindowsMetadata>` impls.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct WindowsMetadata {
    /// Represents whether or not a file or directory is an archive
    pub archive: bool,

    /// Represents whether or not a file or directory is compressed
    pub compressed: bool,

    /// Represents whether or not the file or directory is encrypted
    pub encrypted: bool,

    /// Represents whether or not a file or directory is hidden
    pub hidden: bool,

    /// Represents whether or not a directory or user data stream is configured with integrity
    pub integrity_stream: bool,

    /// Represents whether or not a file does not have other attributes set
    pub normal: bool,

    /// Represents whether or not a file or directory is not to be indexed by content indexing
    /// service
    pub not_content_indexed: bool,

    /// Represents whether or not a user data stream is not to be read by the background data
    /// integrity scanner
    pub no_scrub_data: bool,

    /// Represents whether or not the data of a file is not available immediately
    pub offline: bool,

    /// Represents whether or not a file or directory is not fully present locally
    pub recall_on_data_access: bool,

    /// Represents whether or not a file or directory has no physical representation on the local
    /// system (is virtual)
    pub recall_on_open: bool,

    /// Represents whether or not a file or directory has an associated reparse point, or a file is
    /// a symbolic link
    pub reparse_point: bool,

    /// Represents whether or not a file is a sparse file
    pub sparse_file: bool,

    /// Represents whether or not a file or directory is used partially or exclusively by the
    /// operating system
    pub system: bool,

    /// Represents whether or not a file is being used for temporary storage
    pub temporary: bool,
}

#[cfg(feature = "schemars")]
impl WindowsMetadata {
    /// Produces the JSON schema describing this type (schemars feature only).
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(WindowsMetadata)
    }
}
|
||||
|
||||
impl From<u32> for WindowsMetadata {
|
||||
/// Create from a windows file attribute bitset
|
||||
fn from(file_attributes: u32) -> Self {
|
||||
let flags = WindowsFileAttributeFlags::from_bits_truncate(file_attributes);
|
||||
Self {
|
||||
archive: flags.contains(WindowsFileAttributeFlags::ARCHIVE),
|
||||
compressed: flags.contains(WindowsFileAttributeFlags::COMPRESSED),
|
||||
encrypted: flags.contains(WindowsFileAttributeFlags::ENCRYPTED),
|
||||
hidden: flags.contains(WindowsFileAttributeFlags::HIDDEN),
|
||||
integrity_stream: flags.contains(WindowsFileAttributeFlags::INTEGRITY_SYSTEM),
|
||||
normal: flags.contains(WindowsFileAttributeFlags::NORMAL),
|
||||
not_content_indexed: flags.contains(WindowsFileAttributeFlags::NOT_CONTENT_INDEXED),
|
||||
no_scrub_data: flags.contains(WindowsFileAttributeFlags::NO_SCRUB_DATA),
|
||||
offline: flags.contains(WindowsFileAttributeFlags::OFFLINE),
|
||||
recall_on_data_access: flags.contains(WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS),
|
||||
recall_on_open: flags.contains(WindowsFileAttributeFlags::RECALL_ON_OPEN),
|
||||
reparse_point: flags.contains(WindowsFileAttributeFlags::REPARSE_POINT),
|
||||
sparse_file: flags.contains(WindowsFileAttributeFlags::SPARSE_FILE),
|
||||
system: flags.contains(WindowsFileAttributeFlags::SYSTEM),
|
||||
temporary: flags.contains(WindowsFileAttributeFlags::TEMPORARY),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WindowsMetadata> for u32 {
|
||||
/// Convert to a windows file attribute bitset
|
||||
fn from(metadata: WindowsMetadata) -> Self {
|
||||
let mut flags = WindowsFileAttributeFlags::empty();
|
||||
|
||||
if metadata.archive {
|
||||
flags.insert(WindowsFileAttributeFlags::ARCHIVE);
|
||||
}
|
||||
if metadata.compressed {
|
||||
flags.insert(WindowsFileAttributeFlags::COMPRESSED);
|
||||
}
|
||||
if metadata.encrypted {
|
||||
flags.insert(WindowsFileAttributeFlags::ENCRYPTED);
|
||||
}
|
||||
if metadata.hidden {
|
||||
flags.insert(WindowsFileAttributeFlags::HIDDEN);
|
||||
}
|
||||
if metadata.integrity_stream {
|
||||
flags.insert(WindowsFileAttributeFlags::INTEGRITY_SYSTEM);
|
||||
}
|
||||
if metadata.normal {
|
||||
flags.insert(WindowsFileAttributeFlags::NORMAL);
|
||||
}
|
||||
if metadata.not_content_indexed {
|
||||
flags.insert(WindowsFileAttributeFlags::NOT_CONTENT_INDEXED);
|
||||
}
|
||||
if metadata.no_scrub_data {
|
||||
flags.insert(WindowsFileAttributeFlags::NO_SCRUB_DATA);
|
||||
}
|
||||
if metadata.offline {
|
||||
flags.insert(WindowsFileAttributeFlags::OFFLINE);
|
||||
}
|
||||
if metadata.recall_on_data_access {
|
||||
flags.insert(WindowsFileAttributeFlags::RECALL_ON_DATA_ACCESS);
|
||||
}
|
||||
if metadata.recall_on_open {
|
||||
flags.insert(WindowsFileAttributeFlags::RECALL_ON_OPEN);
|
||||
}
|
||||
if metadata.reparse_point {
|
||||
flags.insert(WindowsFileAttributeFlags::REPARSE_POINT);
|
||||
}
|
||||
if metadata.sparse_file {
|
||||
flags.insert(WindowsFileAttributeFlags::SPARSE_FILE);
|
||||
}
|
||||
if metadata.system {
|
||||
flags.insert(WindowsFileAttributeFlags::SYSTEM);
|
||||
}
|
||||
if metadata.temporary {
|
||||
flags.insert(WindowsFileAttributeFlags::TEMPORARY);
|
||||
}
|
||||
|
||||
flags.bits
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
    /// Windows file attribute bits as raw numeric values.
    ///
    /// NOTE(review): `VIRTUAL` is declared but never surfaced through
    /// [`WindowsMetadata`] — confirm whether that is intentional.
    struct WindowsFileAttributeFlags: u32 {
        const ARCHIVE = 0x20;
        const COMPRESSED = 0x800;
        const ENCRYPTED = 0x4000;
        const HIDDEN = 0x2;
        const INTEGRITY_SYSTEM = 0x8000;
        const NORMAL = 0x80;
        const NOT_CONTENT_INDEXED = 0x2000;
        const NO_SCRUB_DATA = 0x20000;
        const OFFLINE = 0x1000;
        const RECALL_ON_DATA_ACCESS = 0x400000;
        const RECALL_ON_OPEN = 0x40000;
        const REPARSE_POINT = 0x400;
        const SPARSE_FILE = 0x200;
        const SYSTEM = 0x4;
        const TEMPORARY = 0x100;
        const VIRTUAL = 0x10000;
    }
}
|
@ -1,137 +0,0 @@
|
||||
use derive_more::{Display, Error};
|
||||
use portable_pty::PtySize as PortablePtySize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt, num::ParseIntError, str::FromStr};
|
||||
|
||||
/// Represents the size associated with a remote PTY
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct PtySize {
    /// Number of lines of text
    pub rows: u16,

    /// Number of columns of text
    pub cols: u16,

    /// Width of a cell in pixels. Note that some systems never fill this value and ignore it.
    // Defaults to 0 when absent from serialized input
    #[serde(default)]
    pub pixel_width: u16,

    /// Height of a cell in pixels. Note that some systems never fill this value and ignore it.
    // Defaults to 0 when absent from serialized input
    #[serde(default)]
    pub pixel_height: u16,
}
|
||||
|
||||
impl PtySize {
|
||||
/// Creates new size using just rows and columns
|
||||
pub fn from_rows_and_cols(rows: u16, cols: u16) -> Self {
|
||||
Self {
|
||||
rows,
|
||||
cols,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
impl PtySize {
    /// Produces the JSON schema describing this type (schemars feature only).
    pub fn root_schema() -> schemars::schema::RootSchema {
        schemars::schema_for!(PtySize)
    }
}
|
||||
|
||||
impl From<PortablePtySize> for PtySize {
    /// Field-for-field conversion from the `portable_pty` representation.
    fn from(size: PortablePtySize) -> Self {
        Self {
            rows: size.rows,
            cols: size.cols,
            pixel_width: size.pixel_width,
            pixel_height: size.pixel_height,
        }
    }
}

impl From<PtySize> for PortablePtySize {
    /// Field-for-field conversion back into the `portable_pty` representation.
    fn from(size: PtySize) -> Self {
        Self {
            rows: size.rows,
            cols: size.cols,
            pixel_width: size.pixel_width,
            pixel_height: size.pixel_height,
        }
    }
}
|
||||
|
||||
impl fmt::Display for PtySize {
|
||||
/// Prints out `rows,cols[,pixel_width,pixel_height]` where the
|
||||
/// pixel width and pixel height are only included if either
|
||||
/// one of them is not zero
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{},{}", self.rows, self.cols)?;
|
||||
if self.pixel_width > 0 || self.pixel_height > 0 {
|
||||
write!(f, ",{},{}", self.pixel_width, self.pixel_height)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PtySize {
    /// Defaults to the conventional 80x24 terminal with no pixel dimensions.
    fn default() -> Self {
        PtySize {
            rows: 24,
            cols: 80,
            pixel_width: 0,
            pixel_height: 0,
        }
    }
}
|
||||
|
||||
/// Errors that can arise when parsing a [`PtySize`] from a string.
#[derive(Clone, Debug, PartialEq, Eq, Display, Error)]
pub enum PtySizeParseError {
    /// No rows token was provided
    MissingRows,
    /// No columns token was provided
    MissingColumns,
    /// Rows token was not a valid u16
    InvalidRows(ParseIntError),
    /// Columns token was not a valid u16
    InvalidColumns(ParseIntError),
    /// Pixel width token was not a valid u16
    InvalidPixelWidth(ParseIntError),
    /// Pixel height token was not a valid u16
    InvalidPixelHeight(ParseIntError),
}
|
||||
|
||||
impl FromStr for PtySize {
    type Err = PtySizeParseError;

    /// Attempts to parse a str into PtySize using one of the following formats:
    ///
    /// * rows,cols (defaults to 0 for pixel_width & pixel_height)
    /// * rows,cols,pixel_width,pixel_height
    ///
    /// NOTE(review): tokens beyond the fourth are silently ignored — confirm
    /// whether trailing input should instead be an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut tokens = s.split(',');

        Ok(Self {
            rows: tokens
                .next()
                .ok_or(PtySizeParseError::MissingRows)?
                .trim()
                .parse()
                .map_err(PtySizeParseError::InvalidRows)?,
            cols: tokens
                .next()
                .ok_or(PtySizeParseError::MissingColumns)?
                .trim()
                .parse()
                .map_err(PtySizeParseError::InvalidColumns)?,
            // Pixel dimensions are optional; a missing token defaults to 0,
            // but a present-yet-invalid token is still an error
            pixel_width: tokens
                .next()
                .map(|s| s.trim().parse())
                .transpose()
                .map_err(PtySizeParseError::InvalidPixelWidth)?
                .unwrap_or(0),
            pixel_height: tokens
                .next()
                .map(|s| s.trim().parse())
                .transpose()
                .map_err(PtySizeParseError::InvalidPixelHeight)?
                .unwrap_or(0),
        })
    }
}
|
@ -1,45 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, path::PathBuf};
|
||||
|
||||
/// Represents information about a system
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
pub struct SystemInfo {
|
||||
/// Family of the operating system as described in
|
||||
/// https://doc.rust-lang.org/std/env/consts/constant.FAMILY.html
|
||||
pub family: String,
|
||||
|
||||
/// Name of the specific operating system as described in
|
||||
/// https://doc.rust-lang.org/std/env/consts/constant.OS.html
|
||||
pub os: String,
|
||||
|
||||
/// Architecture of the CPI as described in
|
||||
/// https://doc.rust-lang.org/std/env/consts/constant.ARCH.html
|
||||
pub arch: String,
|
||||
|
||||
/// Current working directory of the running server process
|
||||
pub current_dir: PathBuf,
|
||||
|
||||
/// Primary separator for path components for the current platform
|
||||
/// as defined in https://doc.rust-lang.org/std/path/constant.MAIN_SEPARATOR.html
|
||||
pub main_separator: char,
|
||||
}
|
||||
|
||||
#[cfg(feature = "schemars")]
|
||||
impl SystemInfo {
|
||||
pub fn root_schema() -> schemars::schema::RootSchema {
|
||||
schemars::schema_for!(SystemInfo)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SystemInfo {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
family: env::consts::FAMILY.to_string(),
|
||||
os: env::consts::OS.to_string(),
|
||||
arch: env::consts::ARCH.to_string(),
|
||||
current_dir: env::current_dir().unwrap_or_default(),
|
||||
main_separator: std::path::MAIN_SEPARATOR,
|
||||
}
|
||||
}
|
||||
}
|
@ -1,27 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub(crate) fn deserialize_u128_option<'de, D>(deserializer: D) -> Result<Option<u128>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
match Option::<String>::deserialize(deserializer)? {
|
||||
Some(s) => match s.parse::<u128>() {
|
||||
Ok(value) => Ok(Some(value)),
|
||||
Err(error) => Err(serde::de::Error::custom(format!(
|
||||
"Cannot convert to u128 with error: {:?}",
|
||||
error
|
||||
))),
|
||||
},
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn serialize_u128_option<S: serde::Serializer>(
|
||||
val: &Option<u128>,
|
||||
s: S,
|
||||
) -> Result<S::Ok, S::Error> {
|
||||
match val {
|
||||
Some(v) => format!("{}", *v).serialize(s),
|
||||
None => s.serialize_unit(),
|
||||
}
|
||||
}
|
@ -1,7 +0,0 @@
|
||||
mod client;
|
||||
mod data;
|
||||
mod server;
|
||||
|
||||
pub use client::*;
|
||||
pub use data::*;
|
||||
pub use server::*;
|
@ -1,761 +0,0 @@
|
||||
use super::data::{
|
||||
ConnectionId, ConnectionInfo, ConnectionList, Destination, Extra, ManagerRequest,
|
||||
ManagerResponse,
|
||||
};
|
||||
use crate::{DistantChannel, DistantClient, DistantMsg, DistantRequestData, DistantResponseData};
|
||||
use distant_net::{
|
||||
router, Auth, AuthServer, Client, IntoSplit, MpscTransport, OneshotListener, Request, Response,
|
||||
ServerExt, ServerRef, UntypedTransportRead, UntypedTransportWrite,
|
||||
};
|
||||
use log::*;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
mod config;
|
||||
pub use config::*;
|
||||
|
||||
mod ext;
|
||||
pub use ext::*;
|
||||
|
||||
router!(DistantManagerClientRouter {
|
||||
auth_transport: Request<Auth> => Response<Auth>,
|
||||
manager_transport: Response<ManagerResponse> => Request<ManagerRequest>,
|
||||
});
|
||||
|
||||
/// Represents a client that can connect to a remote distant manager
|
||||
pub struct DistantManagerClient {
|
||||
auth: Box<dyn ServerRef>,
|
||||
client: Client<ManagerRequest, ManagerResponse>,
|
||||
distant_clients: HashMap<ConnectionId, ClientHandle>,
|
||||
}
|
||||
|
||||
impl Drop for DistantManagerClient {
|
||||
fn drop(&mut self) {
|
||||
self.auth.abort();
|
||||
self.client.abort();
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a raw channel between a manager client and some remote server
|
||||
pub struct RawDistantChannel {
|
||||
pub transport: MpscTransport<
|
||||
Request<DistantMsg<DistantRequestData>>,
|
||||
Response<DistantMsg<DistantResponseData>>,
|
||||
>,
|
||||
forward_task: JoinHandle<()>,
|
||||
mailbox_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl RawDistantChannel {
|
||||
pub fn abort(&self) {
|
||||
self.forward_task.abort();
|
||||
self.mailbox_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for RawDistantChannel {
|
||||
type Target = MpscTransport<
|
||||
Request<DistantMsg<DistantRequestData>>,
|
||||
Response<DistantMsg<DistantResponseData>>,
|
||||
>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.transport
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for RawDistantChannel {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.transport
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientHandle {
|
||||
client: DistantClient,
|
||||
forward_task: JoinHandle<()>,
|
||||
mailbox_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Drop for ClientHandle {
|
||||
fn drop(&mut self) {
|
||||
self.forward_task.abort();
|
||||
self.mailbox_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl DistantManagerClient {
|
||||
/// Initializes a client using the provided [`UntypedTransport`]
|
||||
pub fn new<T>(config: DistantManagerClientConfig, transport: T) -> io::Result<Self>
|
||||
where
|
||||
T: IntoSplit + 'static,
|
||||
T::Read: UntypedTransportRead + 'static,
|
||||
T::Write: UntypedTransportWrite + 'static,
|
||||
{
|
||||
let DistantManagerClientRouter {
|
||||
auth_transport,
|
||||
manager_transport,
|
||||
..
|
||||
} = DistantManagerClientRouter::new(transport);
|
||||
|
||||
// Initialize our client with manager request/response transport
|
||||
let (writer, reader) = manager_transport.into_split();
|
||||
let client = Client::new(writer, reader)?;
|
||||
|
||||
// Initialize our auth handler with auth/auth transport
|
||||
let auth = AuthServer {
|
||||
on_challenge: config.on_challenge,
|
||||
on_verify: config.on_verify,
|
||||
on_info: config.on_info,
|
||||
on_error: config.on_error,
|
||||
}
|
||||
.start(OneshotListener::from_value(auth_transport.into_split()))?;
|
||||
|
||||
Ok(Self {
|
||||
auth,
|
||||
client,
|
||||
distant_clients: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Request that the manager launches a new server at the given `destination`
|
||||
/// with `extra` being passed for destination-specific details, returning the new
|
||||
/// `destination` of the spawned server to connect to
|
||||
pub async fn launch(
|
||||
&mut self,
|
||||
destination: impl Into<Destination>,
|
||||
extra: impl Into<Extra>,
|
||||
) -> io::Result<Destination> {
|
||||
let destination = Box::new(destination.into());
|
||||
let extra = extra.into();
|
||||
trace!("launch({}, {})", destination, extra);
|
||||
|
||||
let res = self
|
||||
.client
|
||||
.send(ManagerRequest::Launch { destination, extra })
|
||||
.await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Launched { destination } => Ok(destination),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Request that the manager establishes a new connection at the given `destination`
|
||||
/// with `extra` being passed for destination-specific details
|
||||
pub async fn connect(
|
||||
&mut self,
|
||||
destination: impl Into<Destination>,
|
||||
extra: impl Into<Extra>,
|
||||
) -> io::Result<ConnectionId> {
|
||||
let destination = Box::new(destination.into());
|
||||
let extra = extra.into();
|
||||
trace!("connect({}, {})", destination, extra);
|
||||
|
||||
let res = self
|
||||
.client
|
||||
.send(ManagerRequest::Connect { destination, extra })
|
||||
.await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Connected { id } => Ok(id),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Establishes a channel with the server represented by the `connection_id`,
|
||||
/// returning a [`DistantChannel`] acting as the connection
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Multiple calls to open a channel against the same connection will result in
|
||||
/// clones of the same [`DistantChannel`] rather than establishing a duplicate
|
||||
/// remote connection to the same server
|
||||
pub async fn open_channel(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
) -> io::Result<DistantChannel> {
|
||||
trace!("open_channel({})", connection_id);
|
||||
if let Some(handle) = self.distant_clients.get(&connection_id) {
|
||||
Ok(handle.client.clone_channel())
|
||||
} else {
|
||||
let RawDistantChannel {
|
||||
transport,
|
||||
forward_task,
|
||||
mailbox_task,
|
||||
} = self.open_raw_channel(connection_id).await?;
|
||||
let (writer, reader) = transport.into_split();
|
||||
let client = DistantClient::new(writer, reader)?;
|
||||
let channel = client.clone_channel();
|
||||
self.distant_clients.insert(
|
||||
connection_id,
|
||||
ClientHandle {
|
||||
client,
|
||||
forward_task,
|
||||
mailbox_task,
|
||||
},
|
||||
);
|
||||
Ok(channel)
|
||||
}
|
||||
}
|
||||
|
||||
/// Establishes a channel with the server represented by the `connection_id`,
|
||||
/// returning a [`Transport`] acting as the connection
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Multiple calls to open a channel against the same connection will result in establishing a
|
||||
/// duplicate remote connections to the same server, so take care when using this method
|
||||
pub async fn open_raw_channel(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
) -> io::Result<RawDistantChannel> {
|
||||
trace!("open_raw_channel({})", connection_id);
|
||||
let mut mailbox = self
|
||||
.client
|
||||
.mail(ManagerRequest::OpenChannel { id: connection_id })
|
||||
.await?;
|
||||
|
||||
// Wait for the first response, which should be channel confirmation
|
||||
let channel_id = match mailbox.next().await {
|
||||
Some(response) => match response.payload {
|
||||
ManagerResponse::ChannelOpened { id } => Ok(id),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
},
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"open_channel mailbox aborted",
|
||||
)),
|
||||
}?;
|
||||
|
||||
// Spawn reader and writer tasks to forward requests and replies
|
||||
// using our opened channel
|
||||
let (t1, t2) = MpscTransport::pair(1);
|
||||
let (mut writer, mut reader) = t1.into_split();
|
||||
let mailbox_task = tokio::spawn(async move {
|
||||
use distant_net::TypedAsyncWrite;
|
||||
while let Some(response) = mailbox.next().await {
|
||||
match response.payload {
|
||||
ManagerResponse::Channel { response, .. } => {
|
||||
if let Err(x) = writer.write(response).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
ManagerResponse::ChannelClosed { .. } => break,
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut manager_channel = self.client.clone_channel();
|
||||
let forward_task = tokio::spawn(async move {
|
||||
use distant_net::TypedAsyncRead;
|
||||
loop {
|
||||
match reader.read().await {
|
||||
Ok(Some(request)) => {
|
||||
// NOTE: In this situation, we do not expect a response to this
|
||||
// request (even if the server sends something back)
|
||||
if let Err(x) = manager_channel
|
||||
.fire(ManagerRequest::Channel {
|
||||
id: channel_id,
|
||||
request,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(RawDistantChannel {
|
||||
transport: t2,
|
||||
forward_task,
|
||||
mailbox_task,
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieves information about a specific connection
|
||||
pub async fn info(&mut self, id: ConnectionId) -> io::Result<ConnectionInfo> {
|
||||
trace!("info({})", id);
|
||||
let res = self.client.send(ManagerRequest::Info { id }).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Info(info) => Ok(info),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Kills the specified connection
|
||||
pub async fn kill(&mut self, id: ConnectionId) -> io::Result<()> {
|
||||
trace!("kill({})", id);
|
||||
let res = self.client.send(ManagerRequest::Kill { id }).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Killed => Ok(()),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a list of active connections
|
||||
pub async fn list(&mut self) -> io::Result<ConnectionList> {
|
||||
trace!("list()");
|
||||
let res = self.client.send(ManagerRequest::List).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::List(list) => Ok(list),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Requests that the manager shuts down
|
||||
pub async fn shutdown(&mut self) -> io::Result<()> {
|
||||
trace!("shutdown()");
|
||||
let res = self.client.send(ManagerRequest::Shutdown).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Shutdown => Ok(()),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::data::{Error, ErrorKind};
|
||||
use distant_net::{
|
||||
FramedTransport, InmemoryTransport, PlainCodec, UntypedTransportRead, UntypedTransportWrite,
|
||||
};
|
||||
|
||||
fn setup() -> (
|
||||
DistantManagerClient,
|
||||
FramedTransport<InmemoryTransport, PlainCodec>,
|
||||
) {
|
||||
let (t1, t2) = FramedTransport::pair(100);
|
||||
let client =
|
||||
DistantManagerClient::new(DistantManagerClientConfig::with_empty_prompts(), t1)
|
||||
.unwrap();
|
||||
(client, t2)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn test_error() -> Error {
|
||||
Error {
|
||||
kind: ErrorKind::Interrupted,
|
||||
description: "test error".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn test_io_error() -> io::Error {
|
||||
test_error().into()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Extra>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Extra>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_return_id_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
let expected_id = 999;
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Connected { id: expected_id },
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let id = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Extra>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(id, expected_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.info(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.info(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_return_connection_info_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let info = ConnectionInfo {
|
||||
id: 123,
|
||||
destination: "scheme://host".parse::<Destination>().unwrap(),
|
||||
extra: "key=value".parse::<Extra>().unwrap(),
|
||||
};
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Info(info)))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let info = client.info(123).await.unwrap();
|
||||
assert_eq!(info.id, 123);
|
||||
assert_eq!(
|
||||
info.destination,
|
||||
"scheme://host".parse::<Destination>().unwrap()
|
||||
);
|
||||
assert_eq!(info.extra, "key=value".parse::<Extra>().unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.list().await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.list().await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_return_connection_list_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let mut list = ConnectionList::new();
|
||||
list.insert(123, "scheme://host".parse::<Destination>().unwrap());
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::List(list)))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let list = client.list().await.unwrap();
|
||||
assert_eq!(list.len(), 1);
|
||||
assert_eq!(
|
||||
list.get(&123).expect("Connection list missing item"),
|
||||
&"scheme://host".parse::<Destination>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.kill(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.kill(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_return_success_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Killed))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
client.kill(123).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Connected { id: 0 },
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.shutdown().await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.shutdown().await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_should_return_success_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
client.shutdown().await.unwrap();
|
||||
}
|
||||
}
|
@ -1,85 +0,0 @@
|
||||
use distant_net::{AuthChallengeFn, AuthErrorFn, AuthInfoFn, AuthVerifyFn, AuthVerifyKind};
|
||||
use log::*;
|
||||
use std::io;
|
||||
|
||||
/// Configuration to use when creating a new [`DistantManagerClient`](super::DistantManagerClient)
|
||||
pub struct DistantManagerClientConfig {
|
||||
pub on_challenge: Box<AuthChallengeFn>,
|
||||
pub on_verify: Box<AuthVerifyFn>,
|
||||
pub on_info: Box<AuthInfoFn>,
|
||||
pub on_error: Box<AuthErrorFn>,
|
||||
}
|
||||
|
||||
impl DistantManagerClientConfig {
|
||||
/// Creates a new config with prompts that return empty strings
|
||||
pub fn with_empty_prompts() -> Self {
|
||||
Self::with_prompts(|_| Ok("".to_string()), |_| Ok("".to_string()))
|
||||
}
|
||||
|
||||
/// Creates a new config with two prompts
|
||||
///
|
||||
/// * `password_prompt` - used for prompting for a secret, and should not display what is typed
|
||||
/// * `text_prompt` - used for general text, and is okay to display what is typed
|
||||
pub fn with_prompts<PP, PT>(password_prompt: PP, text_prompt: PT) -> Self
|
||||
where
|
||||
PP: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
|
||||
PT: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
on_challenge: Box::new(move |questions, _extra| {
|
||||
trace!("[manager client] on_challenge({questions:?}, {_extra:?})");
|
||||
let mut answers = Vec::new();
|
||||
for question in questions.iter() {
|
||||
// Contains all prompt lines including same line
|
||||
let mut lines = question.text.split('\n').collect::<Vec<_>>();
|
||||
|
||||
// Line that is prompt on same line as answer
|
||||
let line = lines.pop().unwrap();
|
||||
|
||||
// Go ahead and display all other lines
|
||||
for line in lines.into_iter() {
|
||||
eprintln!("{}", line);
|
||||
}
|
||||
|
||||
// Get an answer from user input, or use a blank string as an answer
|
||||
// if we fail to get input from the user
|
||||
let answer = password_prompt(line).unwrap_or_default();
|
||||
|
||||
answers.push(answer);
|
||||
}
|
||||
answers
|
||||
}),
|
||||
on_verify: Box::new(move |kind, text| {
|
||||
trace!("[manager client] on_verify({kind}, {text})");
|
||||
match kind {
|
||||
AuthVerifyKind::Host => {
|
||||
eprintln!("{}", text);
|
||||
|
||||
match text_prompt("Enter [y/N]> ") {
|
||||
Ok(answer) => {
|
||||
trace!("Verify? Answer = '{answer}'");
|
||||
matches!(answer.trim(), "y" | "Y" | "yes" | "YES")
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Failed verification: {x}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
x => {
|
||||
error!("Unsupported verify kind: {x}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}),
|
||||
on_info: Box::new(|text| {
|
||||
trace!("[manager client] on_info({text})");
|
||||
println!("{}", text);
|
||||
}),
|
||||
on_error: Box::new(|kind, text| {
|
||||
trace!("[manager client] on_error({kind}, {text})");
|
||||
eprintln!("{}: {}", kind, text);
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,14 +0,0 @@
|
||||
mod tcp;
|
||||
pub use tcp::*;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
|
||||
#[cfg(windows)]
|
||||
mod windows;
|
||||
|
||||
#[cfg(windows)]
|
||||
pub use windows::*;
|
@ -1,50 +0,0 @@
|
||||
use crate::{DistantManagerClient, DistantManagerClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Codec, FramedTransport, TcpTransport};
|
||||
use std::{convert, net::SocketAddr};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait TcpDistantManagerClientExt {
|
||||
/// Connect to a remote TCP server using the provided information
|
||||
async fn connect<C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a remote TCP server, timing out after duration has passed
|
||||
async fn connect_timeout<C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(config, addr, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TcpDistantManagerClientExt for DistantManagerClient {
|
||||
/// Connect to a remote TCP server using the provided information
|
||||
async fn connect<C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let transport = TcpTransport::connect(addr).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Self::new(config, transport)
|
||||
}
|
||||
}
|
@ -1,54 +0,0 @@
|
||||
use crate::{DistantManagerClient, DistantManagerClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Codec, FramedTransport, UnixSocketTransport};
|
||||
use std::{convert, path::Path};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UnixSocketDistantManagerClientExt {
|
||||
/// Connect to a proxy unix socket
|
||||
async fn connect<P, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a proxy unix socket, timing out after duration has passed
|
||||
async fn connect_timeout<P, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(config, path, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl UnixSocketDistantManagerClientExt for DistantManagerClient {
|
||||
/// Connect to a proxy unix socket
|
||||
async fn connect<P, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let p = path.as_ref();
|
||||
let transport = UnixSocketTransport::connect(p).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Ok(DistantManagerClient::new(config, transport)?)
|
||||
}
|
||||
}
|
@ -1,91 +0,0 @@
|
||||
use crate::{DistantManagerClient, DistantManagerClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Codec, FramedTransport, WindowsPipeTransport};
|
||||
use std::{
|
||||
convert,
|
||||
ffi::{OsStr, OsString},
|
||||
};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait WindowsPipeDistantManagerClientExt {
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// using the given codec
|
||||
async fn connect<A, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// via `\\.\pipe\{name}` using the given codec
|
||||
async fn connect_local<N, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
name: N,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::connect(config, addr, codec).await
|
||||
}
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// using the given codec, timing out after duration has passed
|
||||
async fn connect_timeout<A, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(config, addr, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// via `\\.\pipe\{name}` using the given codec, timing out after duration has passed
|
||||
async fn connect_local_timeout<N, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
name: N,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::connect_timeout(config, addr, codec, duration).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WindowsPipeDistantManagerClientExt for DistantManagerClient {
|
||||
async fn connect<A, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let a = addr.as_ref();
|
||||
let transport = WindowsPipeTransport::connect(a).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Ok(DistantManagerClient::new(config, transport)?)
|
||||
}
|
||||
}
|
@ -1,2 +0,0 @@
|
||||
/// Represents extra data included for connections
|
||||
pub type Extra = crate::data::Map;
|
@ -1,5 +0,0 @@
|
||||
/// Id associated with an active connection
|
||||
pub type ConnectionId = u64;
|
||||
|
||||
/// Id associated with an open channel
|
||||
pub type ChannelId = u64;
|
@ -1,72 +0,0 @@
|
||||
use super::{ChannelId, ConnectionId, Destination, Extra};
|
||||
use crate::{DistantMsg, DistantRequestData};
|
||||
use distant_net::Request;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "clap", derive(clap::Subcommand))]
|
||||
#[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")]
|
||||
pub enum ManagerRequest {
|
||||
/// Launch a server using the manager
|
||||
Launch {
|
||||
// NOTE: Boxed per clippy's large_enum_variant warning
|
||||
destination: Box<Destination>,
|
||||
|
||||
/// Extra details specific to the connection
|
||||
#[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))]
|
||||
extra: Extra,
|
||||
},
|
||||
|
||||
/// Initiate a connection through the manager
|
||||
Connect {
|
||||
// NOTE: Boxed per clippy's large_enum_variant warning
|
||||
destination: Box<Destination>,
|
||||
|
||||
/// Extra details specific to the connection
|
||||
#[cfg_attr(feature = "clap", clap(short, long, action = clap::ArgAction::Append))]
|
||||
extra: Extra,
|
||||
},
|
||||
|
||||
/// Opens a channel for communication with a server
|
||||
#[cfg_attr(feature = "clap", clap(skip))]
|
||||
OpenChannel {
|
||||
/// Id of the connection
|
||||
id: ConnectionId,
|
||||
},
|
||||
|
||||
/// Sends data through channel
|
||||
#[cfg_attr(feature = "clap", clap(skip))]
|
||||
Channel {
|
||||
/// Id of the channel
|
||||
id: ChannelId,
|
||||
|
||||
/// Request to send to through the channel
|
||||
#[cfg_attr(feature = "clap", clap(skip = skipped_request()))]
|
||||
request: Request<DistantMsg<DistantRequestData>>,
|
||||
},
|
||||
|
||||
/// Closes an open channel
|
||||
#[cfg_attr(feature = "clap", clap(skip))]
|
||||
CloseChannel {
|
||||
/// Id of the channel to close
|
||||
id: ChannelId,
|
||||
},
|
||||
|
||||
/// Retrieve information about a specific connection
|
||||
Info { id: ConnectionId },
|
||||
|
||||
/// Kill a specific connection
|
||||
Kill { id: ConnectionId },
|
||||
|
||||
/// Retrieve list of connections being managed
|
||||
List,
|
||||
|
||||
/// Signals the manager to shutdown
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
/// Produces some default request, purely to satisfy clap
|
||||
#[cfg(feature = "clap")]
|
||||
fn skipped_request() -> Request<DistantMsg<DistantRequestData>> {
|
||||
Request::new(DistantMsg::Single(DistantRequestData::SystemInfo {}))
|
||||
}
|
@ -1,698 +0,0 @@
|
||||
use crate::{
|
||||
ChannelId, ConnectionId, ConnectionInfo, ConnectionList, Destination, Extra, ManagerRequest,
|
||||
ManagerResponse,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{
|
||||
router, Auth, AuthClient, Client, IntoSplit, Listener, MpscListener, Request, Response, Server,
|
||||
ServerCtx, ServerExt, UntypedTransportRead, UntypedTransportWrite,
|
||||
};
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io, sync::Arc};
|
||||
use tokio::{
|
||||
sync::{mpsc, Mutex, RwLock},
|
||||
task::JoinHandle,
|
||||
};
|
||||
|
||||
mod config;
|
||||
pub use config::*;
|
||||
|
||||
mod connection;
|
||||
pub use connection::*;
|
||||
|
||||
mod ext;
|
||||
pub use ext::*;
|
||||
|
||||
mod handler;
|
||||
pub use handler::*;
|
||||
|
||||
mod r#ref;
|
||||
pub use r#ref::*;
|
||||
|
||||
router!(DistantManagerRouter {
|
||||
auth_transport: Response<Auth> => Request<Auth>,
|
||||
manager_transport: Request<ManagerRequest> => Response<ManagerResponse>,
|
||||
});
|
||||
|
||||
/// Represents a manager of multiple distant server connections
|
||||
pub struct DistantManager {
|
||||
/// Receives authentication clients to feed into local data of server
|
||||
auth_client_rx: Mutex<mpsc::Receiver<AuthClient>>,
|
||||
|
||||
/// Configuration settings for the server
|
||||
config: DistantManagerConfig,
|
||||
|
||||
/// Mapping of connection id -> connection
|
||||
connections: RwLock<HashMap<ConnectionId, DistantManagerConnection>>,
|
||||
|
||||
/// Handlers for launch requests
|
||||
launch_handlers: Arc<RwLock<HashMap<String, BoxedLaunchHandler>>>,
|
||||
|
||||
/// Handlers for connect requests
|
||||
connect_handlers: Arc<RwLock<HashMap<String, BoxedConnectHandler>>>,
|
||||
|
||||
/// Primary task of server
|
||||
task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl DistantManager {
|
||||
/// Initializes a new instance of [`DistantManagerServer`] using the provided [`UntypedTransport`]
|
||||
pub fn start<L, T>(
|
||||
mut config: DistantManagerConfig,
|
||||
mut listener: L,
|
||||
) -> io::Result<DistantManagerRef>
|
||||
where
|
||||
L: Listener<Output = T> + 'static,
|
||||
T: IntoSplit + Send + 'static,
|
||||
T::Read: UntypedTransportRead + 'static,
|
||||
T::Write: UntypedTransportWrite + 'static,
|
||||
{
|
||||
let (conn_tx, mpsc_listener) = MpscListener::channel(config.connection_buffer_size);
|
||||
let (auth_client_tx, auth_client_rx) = mpsc::channel(1);
|
||||
|
||||
// Spawn task that uses our input listener to get both auth and manager events,
|
||||
// forwarding manager events to the internal mpsc listener
|
||||
let task = tokio::spawn(async move {
|
||||
while let Ok(transport) = listener.accept().await {
|
||||
let DistantManagerRouter {
|
||||
auth_transport,
|
||||
manager_transport,
|
||||
..
|
||||
} = DistantManagerRouter::new(transport);
|
||||
|
||||
let (writer, reader) = auth_transport.into_split();
|
||||
let client = match Client::new(writer, reader) {
|
||||
Ok(client) => client,
|
||||
Err(x) => {
|
||||
error!("Creating auth client failed: {}", x);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let auth_client = AuthClient::from(client);
|
||||
|
||||
// Forward auth client for new connection in server
|
||||
if auth_client_tx.send(auth_client).await.is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Forward connected and routed transport to server
|
||||
if conn_tx.send(manager_transport.into_split()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let launch_handlers = Arc::new(RwLock::new(config.launch_handlers.drain().collect()));
|
||||
let weak_launch_handlers = Arc::downgrade(&launch_handlers);
|
||||
let connect_handlers = Arc::new(RwLock::new(config.connect_handlers.drain().collect()));
|
||||
let weak_connect_handlers = Arc::downgrade(&connect_handlers);
|
||||
let server_ref = Self {
|
||||
auth_client_rx: Mutex::new(auth_client_rx),
|
||||
config,
|
||||
launch_handlers,
|
||||
connect_handlers,
|
||||
connections: RwLock::new(HashMap::new()),
|
||||
task,
|
||||
}
|
||||
.start(mpsc_listener)?;
|
||||
|
||||
Ok(DistantManagerRef {
|
||||
launch_handlers: weak_launch_handlers,
|
||||
connect_handlers: weak_connect_handlers,
|
||||
inner: server_ref,
|
||||
})
|
||||
}
|
||||
|
||||
/// Launches a new server at the specified `destination` using the given `extra` information
|
||||
/// and authentication client (if needed) to retrieve additional information needed to
|
||||
/// enter the destination prior to starting the server, returning the destination of the
|
||||
/// launched server
|
||||
async fn launch(
|
||||
&self,
|
||||
destination: Destination,
|
||||
extra: Extra,
|
||||
auth: Option<&mut AuthClient>,
|
||||
) -> io::Result<Destination> {
|
||||
let auth = auth.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Authentication client not initialized",
|
||||
)
|
||||
})?;
|
||||
|
||||
let scheme = match destination.scheme.as_deref() {
|
||||
Some(scheme) => {
|
||||
trace!("Using scheme {}", scheme);
|
||||
scheme
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"Using fallback scheme of {}",
|
||||
self.config.launch_fallback_scheme.as_str()
|
||||
);
|
||||
self.config.launch_fallback_scheme.as_str()
|
||||
}
|
||||
}
|
||||
.to_lowercase();
|
||||
|
||||
let credentials = {
|
||||
let lock = self.launch_handlers.read().await;
|
||||
let handler = lock.get(&scheme).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("No launch handler registered for {}", scheme),
|
||||
)
|
||||
})?;
|
||||
handler.launch(&destination, &extra, auth).await?
|
||||
};
|
||||
|
||||
Ok(credentials)
|
||||
}
|
||||
|
||||
/// Connects to a new server at the specified `destination` using the given `extra` information
|
||||
/// and authentication client (if needed) to retrieve additional information needed to
|
||||
/// establish the connection to the server
|
||||
async fn connect(
|
||||
&self,
|
||||
destination: Destination,
|
||||
extra: Extra,
|
||||
auth: Option<&mut AuthClient>,
|
||||
) -> io::Result<ConnectionId> {
|
||||
let auth = auth.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Authentication client not initialized",
|
||||
)
|
||||
})?;
|
||||
|
||||
let scheme = match destination.scheme.as_deref() {
|
||||
Some(scheme) => {
|
||||
trace!("Using scheme {}", scheme);
|
||||
scheme
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"Using fallback scheme of {}",
|
||||
self.config.connect_fallback_scheme.as_str()
|
||||
);
|
||||
self.config.connect_fallback_scheme.as_str()
|
||||
}
|
||||
}
|
||||
.to_lowercase();
|
||||
|
||||
let (writer, reader) = {
|
||||
let lock = self.connect_handlers.read().await;
|
||||
let handler = lock.get(&scheme).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("No connect handler registered for {}", scheme),
|
||||
)
|
||||
})?;
|
||||
handler.connect(&destination, &extra, auth).await?
|
||||
};
|
||||
|
||||
let connection = DistantManagerConnection::new(destination, extra, writer, reader);
|
||||
let id = connection.id;
|
||||
self.connections.write().await.insert(id, connection);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Retrieves information about the connection to the server with the specified `id`
|
||||
async fn info(&self, id: ConnectionId) -> io::Result<ConnectionInfo> {
|
||||
match self.connections.read().await.get(&id) {
|
||||
Some(connection) => Ok(ConnectionInfo {
|
||||
id: connection.id,
|
||||
destination: connection.destination.clone(),
|
||||
extra: connection.extra.clone(),
|
||||
}),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"No connection found",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a list of connections to servers
|
||||
async fn list(&self) -> io::Result<ConnectionList> {
|
||||
Ok(ConnectionList(
|
||||
self.connections
|
||||
.read()
|
||||
.await
|
||||
.values()
|
||||
.map(|conn| (conn.id, conn.destination.clone()))
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Kills the connection to the server with the specified `id`
|
||||
async fn kill(&self, id: ConnectionId) -> io::Result<()> {
|
||||
match self.connections.write().await.remove(&id) {
|
||||
Some(_) => Ok(()),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"No connection found",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DistantManagerServerConnection {
|
||||
/// Authentication client that manager can use when establishing a new connection
|
||||
/// and needing to get authentication details from the client to move forward
|
||||
auth_client: Option<Mutex<AuthClient>>,
|
||||
|
||||
/// Holds on to open channels feeding data back from a server to some connected client,
|
||||
/// enabling us to cancel the tasks on demand
|
||||
channels: RwLock<HashMap<ChannelId, DistantManagerChannel>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Server for DistantManager {
|
||||
type Request = ManagerRequest;
|
||||
type Response = ManagerResponse;
|
||||
type LocalData = DistantManagerServerConnection;
|
||||
|
||||
async fn on_accept(&self, local_data: &mut Self::LocalData) {
|
||||
local_data.auth_client = self
|
||||
.auth_client_rx
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
.await
|
||||
.map(Mutex::new);
|
||||
|
||||
// Enable jit handshake
|
||||
if let Some(auth_client) = local_data.auth_client.as_ref() {
|
||||
auth_client.lock().await.set_jit_handshake(true);
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
|
||||
let ServerCtx {
|
||||
connection_id,
|
||||
request,
|
||||
reply,
|
||||
local_data,
|
||||
} = ctx;
|
||||
|
||||
let response = match request.payload {
|
||||
ManagerRequest::Launch { destination, extra } => {
|
||||
let mut auth = match local_data.auth_client.as_ref() {
|
||||
Some(client) => Some(client.lock().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
match self.launch(*destination, extra, auth.as_deref_mut()).await {
|
||||
Ok(destination) => ManagerResponse::Launched { destination },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
}
|
||||
}
|
||||
ManagerRequest::Connect { destination, extra } => {
|
||||
let mut auth = match local_data.auth_client.as_ref() {
|
||||
Some(client) => Some(client.lock().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
match self.connect(*destination, extra, auth.as_deref_mut()).await {
|
||||
Ok(id) => ManagerResponse::Connected { id },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
}
|
||||
}
|
||||
ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) {
|
||||
Some(connection) => match connection.open_channel(reply.clone()).await {
|
||||
Ok(channel) => {
|
||||
let id = channel.id();
|
||||
local_data.channels.write().await.insert(id, channel);
|
||||
ManagerResponse::ChannelOpened { id }
|
||||
}
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
None => ManagerResponse::Error(
|
||||
io::Error::new(io::ErrorKind::NotConnected, "Connection does not exist").into(),
|
||||
),
|
||||
},
|
||||
ManagerRequest::Channel { id, request } => {
|
||||
match local_data.channels.read().await.get(&id) {
|
||||
// TODO: For now, we are NOT sending back a response to acknowledge
|
||||
// a successful channel send. We could do this in order for
|
||||
// the client to listen for a complete send, but is it worth it?
|
||||
Some(channel) => match channel.send(request).await {
|
||||
Ok(_) => return,
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
None => ManagerResponse::Error(
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"Channel is not open or does not exist",
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
ManagerRequest::CloseChannel { id } => {
|
||||
match local_data.channels.write().await.remove(&id) {
|
||||
Some(channel) => match channel.close().await {
|
||||
Ok(_) => ManagerResponse::ChannelClosed { id },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
None => ManagerResponse::Error(
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"Channel is not open or does not exist",
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
ManagerRequest::Info { id } => match self.info(id).await {
|
||||
Ok(info) => ManagerResponse::Info(info),
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::List => match self.list().await {
|
||||
Ok(list) => ManagerResponse::List(list),
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::Kill { id } => match self.kill(id).await {
|
||||
Ok(()) => ManagerResponse::Killed,
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::Shutdown => {
|
||||
if let Err(x) = reply.send(ManagerResponse::Shutdown).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
|
||||
// Clear out handler state in order to trigger drops
|
||||
self.launch_handlers.write().await.clear();
|
||||
self.connect_handlers.write().await.clear();
|
||||
|
||||
// Shutdown the primary server task
|
||||
self.task.abort();
|
||||
|
||||
// TODO: Perform a graceful shutdown instead of this?
|
||||
// Review https://tokio.rs/tokio/topics/shutdown
|
||||
std::process::exit(0);
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(x) = reply.send(response).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use distant_net::{
|
||||
AuthClient, FramedTransport, HeapAuthServer, InmemoryTransport, IntoSplit, MappedListener,
|
||||
OneshotListener, PlainCodec, ServerExt, ServerRef,
|
||||
};
|
||||
|
||||
/// Create a new server, bypassing the start loop
|
||||
fn setup() -> DistantManager {
|
||||
let (_, rx) = mpsc::channel(1);
|
||||
DistantManager {
|
||||
auth_client_rx: Mutex::new(rx),
|
||||
config: Default::default(),
|
||||
connections: RwLock::new(HashMap::new()),
|
||||
launch_handlers: Arc::new(RwLock::new(HashMap::new())),
|
||||
connect_handlers: Arc::new(RwLock::new(HashMap::new())),
|
||||
task: tokio::spawn(async move {}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a connected [`AuthClient`] with a launched auth server that blindly responds
|
||||
fn auth_client_server() -> (AuthClient, Box<dyn ServerRef>) {
|
||||
let (t1, t2) = FramedTransport::pair(1);
|
||||
let client = AuthClient::from(Client::from_framed_transport(t1).unwrap());
|
||||
|
||||
// Create a server that does nothing, but will support
|
||||
let server = HeapAuthServer {
|
||||
on_challenge: Box::new(|_, _| Vec::new()),
|
||||
on_verify: Box::new(|_, _| false),
|
||||
on_info: Box::new(|_| ()),
|
||||
on_error: Box::new(|_, _| ()),
|
||||
}
|
||||
.start(MappedListener::new(OneshotListener::from_value(t2), |t| {
|
||||
t.into_split()
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
(client, server)
|
||||
}
|
||||
|
||||
fn dummy_distant_writer_reader() -> (BoxedDistantWriter, BoxedDistantReader) {
|
||||
setup_distant_writer_reader().0
|
||||
}
|
||||
|
||||
/// Creates a writer & reader with a connected transport
|
||||
fn setup_distant_writer_reader() -> (
|
||||
(BoxedDistantWriter, BoxedDistantReader),
|
||||
FramedTransport<InmemoryTransport, PlainCodec>,
|
||||
) {
|
||||
let (t1, t2) = FramedTransport::pair(1);
|
||||
let (writer, reader) = t1.into_split();
|
||||
((Box::new(writer), Box::new(reader)), t2)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn launch_should_fail_if_destination_scheme_is_unsupported() {
|
||||
let server = setup();
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.launch(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn launch_should_fail_if_handler_tied_to_scheme_fails() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn LaunchHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
|
||||
Err(io::Error::new(io::ErrorKind::Other, "test failure"))
|
||||
});
|
||||
|
||||
server
|
||||
.launch_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.launch(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::Other);
|
||||
assert_eq!(err.to_string(), "test failure");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn launch_should_return_new_destination_on_success() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn LaunchHandler> = {
|
||||
Box::new(|_: &_, _: &_, _: &mut _| async {
|
||||
Ok("scheme2://host2".parse::<Destination>().unwrap())
|
||||
})
|
||||
};
|
||||
|
||||
server
|
||||
.launch_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "key=value".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let destination = server
|
||||
.launch(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
destination,
|
||||
"scheme2://host2".parse::<Destination>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_fail_if_destination_scheme_is_unsupported() {
|
||||
let server = setup();
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.connect(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_fail_if_handler_tied_to_scheme_fails() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn ConnectHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
|
||||
Err(io::Error::new(io::ErrorKind::Other, "test failure"))
|
||||
});
|
||||
|
||||
server
|
||||
.connect_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.connect(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::Other);
|
||||
assert_eq!(err.to_string(), "test failure");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_return_id_of_new_connection_on_success() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn ConnectHandler> =
|
||||
Box::new(|_: &_, _: &_, _: &mut _| async { Ok(dummy_distant_writer_reader()) });
|
||||
|
||||
server
|
||||
.connect_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let extra = "key=value".parse::<Extra>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let id = server
|
||||
.connect(destination, extra, Some(&mut auth))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let lock = server.connections.read().await;
|
||||
let connection = lock.get(&id).unwrap();
|
||||
assert_eq!(connection.id, id);
|
||||
assert_eq!(connection.destination, "scheme://host");
|
||||
assert_eq!(connection.extra, "key=value".parse().unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_fail_if_no_connection_found_for_specified_id() {
|
||||
let server = setup();
|
||||
|
||||
let err = server.info(999).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_return_information_about_established_connection() {
|
||||
let server = setup();
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"scheme://host".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id = connection.id;
|
||||
server.connections.write().await.insert(id, connection);
|
||||
|
||||
let info = server.info(id).await.unwrap();
|
||||
assert_eq!(
|
||||
info,
|
||||
ConnectionInfo {
|
||||
id,
|
||||
destination: "scheme://host".parse().unwrap(),
|
||||
extra: "key=value".parse().unwrap(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_return_empty_connection_list_if_no_established_connections() {
|
||||
let server = setup();
|
||||
|
||||
let list = server.list().await.unwrap();
|
||||
assert_eq!(list, ConnectionList(HashMap::new()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_return_a_list_of_established_connections() {
|
||||
let server = setup();
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"scheme://host".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id_1 = connection.id;
|
||||
server.connections.write().await.insert(id_1, connection);
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"other://host2".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id_2 = connection.id;
|
||||
server.connections.write().await.insert(id_2, connection);
|
||||
|
||||
let list = server.list().await.unwrap();
|
||||
assert_eq!(
|
||||
list.get(&id_1).unwrap(),
|
||||
&"scheme://host".parse::<Destination>().unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
list.get(&id_2).unwrap(),
|
||||
&"other://host2".parse::<Destination>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_fail_if_no_connection_found_for_specified_id() {
|
||||
let server = setup();
|
||||
|
||||
let err = server.kill(999).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() {
|
||||
let server = setup();
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"scheme://host".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id = connection.id;
|
||||
server.connections.write().await.insert(id, connection);
|
||||
|
||||
server.kill(id).await.unwrap();
|
||||
|
||||
let lock = server.connections.read().await;
|
||||
assert!(!lock.contains_key(&id), "Connection still exists");
|
||||
}
|
||||
}
|
@ -1,201 +0,0 @@
|
||||
use crate::{
|
||||
manager::{
|
||||
data::{ChannelId, ConnectionId, Destination, Extra},
|
||||
BoxedDistantReader, BoxedDistantWriter,
|
||||
},
|
||||
DistantMsg, DistantRequestData, DistantResponseData, ManagerResponse,
|
||||
};
|
||||
use distant_net::{Request, Response, ServerReply};
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
use tokio::{sync::mpsc, task::JoinHandle};
|
||||
|
||||
/// Represents a connection a distant manager has with some distant-compatible server
|
||||
pub struct DistantManagerConnection {
|
||||
pub id: ConnectionId,
|
||||
pub destination: Destination,
|
||||
pub extra: Extra,
|
||||
tx: mpsc::Sender<StateMachine>,
|
||||
reader_task: JoinHandle<()>,
|
||||
writer_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DistantManagerChannel {
|
||||
channel_id: ChannelId,
|
||||
tx: mpsc::Sender<StateMachine>,
|
||||
}
|
||||
|
||||
impl DistantManagerChannel {
|
||||
pub fn id(&self) -> ChannelId {
|
||||
self.channel_id
|
||||
}
|
||||
|
||||
pub async fn send(&self, request: Request<DistantMsg<DistantRequestData>>) -> io::Result<()> {
|
||||
let channel_id = self.channel_id;
|
||||
self.tx
|
||||
.send(StateMachine::Write {
|
||||
id: channel_id,
|
||||
request,
|
||||
})
|
||||
.await
|
||||
.map_err(|x| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
format!("channel {} send failed: {}", channel_id, x),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> io::Result<()> {
|
||||
let channel_id = self.channel_id;
|
||||
self.tx
|
||||
.send(StateMachine::Unregister { id: channel_id })
|
||||
.await
|
||||
.map_err(|x| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
format!("channel {} close failed: {}", channel_id, x),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
enum StateMachine {
|
||||
Register {
|
||||
id: ChannelId,
|
||||
reply: ServerReply<ManagerResponse>,
|
||||
},
|
||||
|
||||
Unregister {
|
||||
id: ChannelId,
|
||||
},
|
||||
|
||||
Read {
|
||||
response: Response<DistantMsg<DistantResponseData>>,
|
||||
},
|
||||
|
||||
Write {
|
||||
id: ChannelId,
|
||||
request: Request<DistantMsg<DistantRequestData>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl DistantManagerConnection {
    /// Creates a new managed connection for `destination`, spawning a reader task that
    /// forwards transport responses into the state machine and a writer task that
    /// routes events and writes requests out over the transport.
    pub fn new(
        destination: Destination,
        extra: Extra,
        mut writer: BoxedDistantWriter,
        mut reader: BoxedDistantReader,
    ) -> Self {
        // Random id purely distinguishes this connection in logs and lookups
        let connection_id = rand::random();
        // Capacity of 1 keeps the event queue tiny, applying backpressure from the
        // writer task back to channel users
        let (tx, mut rx) = mpsc::channel(1);
        let reader_task = {
            let tx = tx.clone();
            tokio::spawn(async move {
                loop {
                    match reader.read().await {
                        Ok(Some(response)) => {
                            // If the state machine receiver is gone, the connection
                            // is shutting down and there is nothing left to do
                            if tx.send(StateMachine::Read { response }).await.is_err() {
                                break;
                            }
                        }
                        // Transport closed cleanly
                        Ok(None) => break,
                        Err(x) => {
                            // NOTE(review): read errors are logged and skipped rather
                            // than ending the task -- confirm a persistent transport
                            // error cannot cause this loop to spin
                            error!("[Conn {}] {}", connection_id, x);
                            continue;
                        }
                    }
                }
            })
        };
        let writer_task = tokio::spawn(async move {
            // Maps channel id -> reply handle for every currently-open channel
            let mut registered = HashMap::new();
            while let Some(state_machine) = rx.recv().await {
                match state_machine {
                    StateMachine::Register { id, reply } => {
                        registered.insert(id, reply);
                    }
                    StateMachine::Unregister { id } => {
                        registered.remove(&id);
                    }
                    StateMachine::Read { mut response } => {
                        // Split {channel id}_{request id} back into pieces and
                        // update the origin id to match the request id only
                        let channel_id = match response.origin_id.split_once('_') {
                            Some((cid_str, oid_str)) => {
                                if let Ok(cid) = cid_str.parse::<ChannelId>() {
                                    response.origin_id = oid_str.to_string();
                                    cid
                                } else {
                                    // Channel id portion is not parseable; drop it
                                    continue;
                                }
                            }
                            // No channel marker present; drop the response
                            None => continue,
                        };

                        // Route the response to the channel that issued the request,
                        // if it is still registered; otherwise it is silently dropped
                        if let Some(reply) = registered.get(&channel_id) {
                            let response = ManagerResponse::Channel {
                                id: channel_id,
                                response,
                            };
                            if let Err(x) = reply.send(response).await {
                                error!("[Conn {}] {}", connection_id, x);
                            }
                        }
                    }
                    StateMachine::Write { id, request } => {
                        // Combine channel id with request id so we can properly forward
                        // the response containing this in the origin id
                        let request = Request {
                            id: format!("{}_{}", id, request.id),
                            payload: request.payload,
                        };
                        if let Err(x) = writer.write(request).await {
                            error!("[Conn {}] {}", connection_id, x);
                        }
                    }
                }
            }
        });

        Self {
            id: connection_id,
            destination,
            extra,
            tx,
            reader_task,
            writer_task,
        }
    }

    /// Opens a new channel over this connection, registering `reply` to receive
    /// every response routed to the channel.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::BrokenPipe`] if the connection's
    /// state machine is no longer receiving events.
    pub async fn open_channel(
        &self,
        reply: ServerReply<ManagerResponse>,
    ) -> io::Result<DistantManagerChannel> {
        // Random id distinguishes this channel among all channels on the connection
        let channel_id = rand::random();
        self.tx
            .send(StateMachine::Register {
                id: channel_id,
                reply,
            })
            .await
            .map_err(|x| {
                io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    format!("open_channel failed: {}", x),
                )
            })?;
        Ok(DistantManagerChannel {
            channel_id,
            tx: self.tx.clone(),
        })
    }
}
|
||||
|
||||
impl Drop for DistantManagerConnection {
    /// Aborts both background tasks so they do not outlive the connection;
    /// any in-flight read or write is cancelled.
    fn drop(&mut self) {
        self.reader_task.abort();
        self.writer_task.abort();
    }
}
|
@ -1,30 +0,0 @@
|
||||
use crate::{DistantManager, DistantManagerConfig};
|
||||
use distant_net::{
|
||||
Codec, FramedTransport, IntoSplit, MappedListener, PortRange, TcpListener, TcpServerRef,
|
||||
};
|
||||
use std::{io, net::IpAddr};
|
||||
|
||||
impl DistantManager {
|
||||
/// Start a new server by binding to the given IP address and one of the ports in the
|
||||
/// specified range, mapping all connections to use the given codec
|
||||
pub async fn start_tcp<P, C>(
|
||||
config: DistantManagerConfig,
|
||||
addr: IpAddr,
|
||||
port: P,
|
||||
codec: C,
|
||||
) -> io::Result<TcpServerRef>
|
||||
where
|
||||
P: Into<PortRange> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let listener = TcpListener::bind(addr, port).await?;
|
||||
let port = listener.port();
|
||||
|
||||
let listener = MappedListener::new(listener, move |transport| {
|
||||
let transport = FramedTransport::new(transport, codec.clone());
|
||||
transport.into_split()
|
||||
});
|
||||
let inner = DistantManager::start(config, listener)?;
|
||||
Ok(TcpServerRef::new(addr, port, Box::new(inner)))
|
||||
}
|
||||
}
|
@ -1,50 +0,0 @@
|
||||
use crate::{DistantManager, DistantManagerConfig};
|
||||
use distant_net::{
|
||||
Codec, FramedTransport, IntoSplit, MappedListener, UnixSocketListener, UnixSocketServerRef,
|
||||
};
|
||||
use std::{io, path::Path};
|
||||
|
||||
impl DistantManager {
|
||||
/// Start a new server using the specified path as a unix socket using default unix socket file
|
||||
/// permissions
|
||||
pub async fn start_unix_socket<P, C>(
|
||||
config: DistantManagerConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
) -> io::Result<UnixSocketServerRef>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
Self::start_unix_socket_with_permissions(
|
||||
config,
|
||||
path,
|
||||
codec,
|
||||
UnixSocketListener::default_unix_socket_file_permissions(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Start a new server using the specified path as a unix socket and `mode` as the unix socket
|
||||
/// file permissions
|
||||
pub async fn start_unix_socket_with_permissions<P, C>(
|
||||
config: DistantManagerConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
mode: u32,
|
||||
) -> io::Result<UnixSocketServerRef>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let listener = UnixSocketListener::bind_with_permissions(path, mode).await?;
|
||||
let path = listener.path().to_path_buf();
|
||||
|
||||
let listener = MappedListener::new(listener, move |transport| {
|
||||
let transport = FramedTransport::new(transport, codec.clone());
|
||||
transport.into_split()
|
||||
});
|
||||
let inner = DistantManager::start(config, listener)?;
|
||||
Ok(UnixSocketServerRef::new(path, Box::new(inner)))
|
||||
}
|
||||
}
|
@ -1,48 +0,0 @@
|
||||
use crate::{DistantManager, DistantManagerConfig};
|
||||
use distant_net::{
|
||||
Codec, FramedTransport, IntoSplit, MappedListener, WindowsPipeListener, WindowsPipeServerRef,
|
||||
};
|
||||
use std::{
|
||||
ffi::{OsStr, OsString},
|
||||
io,
|
||||
};
|
||||
|
||||
impl DistantManager {
|
||||
/// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec
|
||||
pub async fn start_local_named_pipe<N, C>(
|
||||
config: DistantManagerConfig,
|
||||
name: N,
|
||||
codec: C,
|
||||
) -> io::Result<WindowsPipeServerRef>
|
||||
where
|
||||
Self: Sized,
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::start_named_pipe(config, addr, codec).await
|
||||
}
|
||||
|
||||
/// Start a new server at the specified pipe address using the given codec
|
||||
pub async fn start_named_pipe<A, C>(
|
||||
config: DistantManagerConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
) -> io::Result<WindowsPipeServerRef>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let a = addr.as_ref();
|
||||
let listener = WindowsPipeListener::bind(a)?;
|
||||
let addr = listener.addr().to_os_string();
|
||||
|
||||
let listener = MappedListener::new(listener, move |transport| {
|
||||
let transport = FramedTransport::new(transport, codec.clone());
|
||||
transport.into_split()
|
||||
});
|
||||
let inner = DistantManager::start(config, listener)?;
|
||||
Ok(WindowsPipeServerRef::new(addr, Box::new(inner)))
|
||||
}
|
||||
}
|
@ -1,69 +0,0 @@
|
||||
use crate::{
|
||||
manager::data::{Destination, Extra},
|
||||
DistantMsg, DistantRequestData, DistantResponseData,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{AuthClient, Request, Response, TypedAsyncRead, TypedAsyncWrite};
|
||||
use std::{future::Future, io};
|
||||
|
||||
/// Boxed writer half used to send distant requests over a transport
pub type BoxedDistantWriter =
    Box<dyn TypedAsyncWrite<Request<DistantMsg<DistantRequestData>>> + Send>;
/// Boxed reader half used to receive distant responses over a transport
pub type BoxedDistantReader =
    Box<dyn TypedAsyncRead<Response<DistantMsg<DistantResponseData>>> + Send>;
/// Paired writer/reader halves for a single established connection
pub type BoxedDistantWriterReader = (BoxedDistantWriter, BoxedDistantReader);
/// Boxed handler used to launch a server at some destination
pub type BoxedLaunchHandler = Box<dyn LaunchHandler>;
/// Boxed handler used to connect to a server at some destination
pub type BoxedConnectHandler = Box<dyn ConnectHandler>;
|
||||
|
||||
/// Used to launch a server at the specified destination, returning some result as a vec of bytes
#[async_trait]
pub trait LaunchHandler: Send + Sync {
    /// Launches a server at `destination`, using `extra` for handler-specific options
    /// and `auth_client` for any authentication needed during launch. On success,
    /// returns the destination at which the launched server can be reached.
    async fn launch(
        &self,
        destination: &Destination,
        extra: &Extra,
        auth_client: &mut AuthClient,
    ) -> io::Result<Destination>;
}
|
||||
|
||||
// Blanket impl so any async closure/function with a matching signature can be used
// directly as a LaunchHandler without a dedicated struct
#[async_trait]
impl<F, R> LaunchHandler for F
where
    // Higher-ranked bound lets the closure borrow its arguments for any lifetime
    F: for<'a> Fn(&'a Destination, &'a Extra, &'a mut AuthClient) -> R + Send + Sync + 'static,
    R: Future<Output = io::Result<Destination>> + Send + 'static,
{
    async fn launch(
        &self,
        destination: &Destination,
        extra: &Extra,
        auth_client: &mut AuthClient,
    ) -> io::Result<Destination> {
        // Invoke the wrapped function and await its future
        self(destination, extra, auth_client).await
    }
}
|
||||
|
||||
/// Used to connect to a destination, returning a connected reader and writer pair
#[async_trait]
pub trait ConnectHandler: Send + Sync {
    /// Connects to the server at `destination`, using `extra` for handler-specific
    /// options and `auth_client` for any authentication needed while connecting.
    /// On success, returns the writer/reader halves of the established connection.
    async fn connect(
        &self,
        destination: &Destination,
        extra: &Extra,
        auth_client: &mut AuthClient,
    ) -> io::Result<BoxedDistantWriterReader>;
}
|
||||
|
||||
// Blanket impl so any async closure/function with a matching signature can be used
// directly as a ConnectHandler without a dedicated struct
#[async_trait]
impl<F, R> ConnectHandler for F
where
    // Higher-ranked bound lets the closure borrow its arguments for any lifetime
    F: for<'a> Fn(&'a Destination, &'a Extra, &'a mut AuthClient) -> R + Send + Sync + 'static,
    R: Future<Output = io::Result<BoxedDistantWriterReader>> + Send + 'static,
{
    async fn connect(
        &self,
        destination: &Destination,
        extra: &Extra,
        auth_client: &mut AuthClient,
    ) -> io::Result<BoxedDistantWriterReader> {
        // Invoke the wrapped function and await its future
        self(destination, extra, auth_client).await
    }
}
|
@ -1,73 +0,0 @@
|
||||
use super::{BoxedConnectHandler, BoxedLaunchHandler, ConnectHandler, LaunchHandler};
|
||||
use distant_net::{ServerRef, ServerState};
|
||||
use std::{collections::HashMap, io, sync::Weak};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Reference to a distant manager's server instance
pub struct DistantManagerRef {
    /// Mapping of "scheme" -> handler used to launch servers; held weakly so the
    /// manager itself retains ownership of the handlers
    pub(crate) launch_handlers: Weak<RwLock<HashMap<String, BoxedLaunchHandler>>>,

    /// Mapping of "scheme" -> handler used to connect to servers; held weakly so the
    /// manager itself retains ownership of the handlers
    pub(crate) connect_handlers: Weak<RwLock<HashMap<String, BoxedConnectHandler>>>,

    /// Underlying server reference that this type delegates to
    pub(crate) inner: Box<dyn ServerRef>,
}
|
||||
|
||||
impl DistantManagerRef {
    /// Registers a new [`LaunchHandler`] for the specified scheme (e.g. "distant" or "ssh")
    ///
    /// # Errors
    ///
    /// Fails if the manager that owns the handler map has already been dropped.
    pub async fn register_launch_handler(
        &self,
        scheme: impl Into<String>,
        handler: impl LaunchHandler + 'static,
    ) -> io::Result<()> {
        // Upgrade the weak reference; fails once the owning manager is gone
        let handlers = Weak::upgrade(&self.launch_handlers).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::Other,
                "Handler reference is no longer available",
            )
        })?;

        // Replaces any handler previously registered for the same scheme
        handlers
            .write()
            .await
            .insert(scheme.into(), Box::new(handler));

        Ok(())
    }

    /// Registers a new [`ConnectHandler`] for the specified scheme (e.g. "distant" or "ssh")
    ///
    /// # Errors
    ///
    /// Fails if the manager that owns the handler map has already been dropped.
    pub async fn register_connect_handler(
        &self,
        scheme: impl Into<String>,
        handler: impl ConnectHandler + 'static,
    ) -> io::Result<()> {
        // Upgrade the weak reference; fails once the owning manager is gone
        let handlers = Weak::upgrade(&self.connect_handlers).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::Other,
                "Handler reference is no longer available",
            )
        })?;

        // Replaces any handler previously registered for the same scheme
        handlers
            .write()
            .await
            .insert(scheme.into(), Box::new(handler));

        Ok(())
    }
}
|
||||
|
||||
// Delegate the ServerRef contract straight to the wrapped server reference
impl ServerRef for DistantManagerRef {
    fn state(&self) -> &ServerState {
        self.inner.state()
    }

    fn is_finished(&self) -> bool {
        self.inner.is_finished()
    }

    fn abort(&self) {
        self.inner.abort();
    }
}
|
@ -0,0 +1,325 @@
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use distant_core::{
|
||||
DistantApi, DistantApiServerHandler, DistantChannelExt, DistantClient, DistantCtx,
|
||||
};
|
||||
use distant_net::auth::{DummyAuthHandler, Verifier};
|
||||
use distant_net::client::Client;
|
||||
use distant_net::common::{InmemoryTransport, OneshotListener, Version};
|
||||
use distant_net::server::{Server, ServerRef};
|
||||
use distant_protocol::PROTOCOL_VERSION;
|
||||
|
||||
/// Stands up an inmemory client and server using the given api.
///
/// The two ends of an in-memory transport are split between a server (driving `api`)
/// and a client; both advertise the same protocol version. The returned `ServerRef`
/// must be kept alive for the duration of the test or the server shuts down.
async fn setup(api: impl DistantApi + Send + Sync + 'static) -> (DistantClient, ServerRef) {
    // Paired in-memory transports with a buffer of 100 messages
    let (t1, t2) = InmemoryTransport::pair(100);

    let server = Server::new()
        .handler(DistantApiServerHandler::new(api))
        // No verification needed for in-memory tests
        .verifier(Verifier::none())
        .version(Version::new(
            PROTOCOL_VERSION.major,
            PROTOCOL_VERSION.minor,
            PROTOCOL_VERSION.patch,
        ))
        .start(OneshotListener::from_value(t2))
        .expect("Failed to start server");

    let client: DistantClient = Client::build()
        .auth_handler(DummyAuthHandler)
        .connector(t1)
        .version(Version::new(
            PROTOCOL_VERSION.major,
            PROTOCOL_VERSION.minor,
            PROTOCOL_VERSION.patch,
        ))
        .connect()
        .await
        .expect("Failed to connect to server");

    (client, server)
}
|
||||
|
||||
/// Tests covering a single (non-batch) request/response exchange.
mod single {
    use test_log::test;

    use super::*;

    #[test(tokio::test)]
    async fn should_support_single_request_returning_error() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            // Always fail so the error path is exercised end to end
            async fn read_file(&self, _ctx: DistantCtx, _path: PathBuf) -> io::Result<Vec<u8>> {
                Err(io::Error::new(io::ErrorKind::NotFound, "test error"))
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        // Error kind and message must survive the round trip through the protocol
        let error = client.read_file(PathBuf::from("file")).await.unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::NotFound);
        assert_eq!(error.to_string(), "test error");
    }

    #[test(tokio::test)]
    async fn should_support_single_request_returning_success() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            // Always succeed with a fixed payload
            async fn read_file(&self, _ctx: DistantCtx, _path: PathBuf) -> io::Result<Vec<u8>> {
                Ok(b"hello world".to_vec())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        let contents = client.read_file(PathBuf::from("file")).await.unwrap();
        assert_eq!(contents, b"hello world");
    }
}
|
||||
|
||||
/// Tests covering batch requests executed in parallel (the default batch mode).
mod batch_parallel {
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use distant_net::common::Request;
    use distant_protocol::{Msg, Request as RequestPayload};
    use test_log::test;

    use super::*;

    #[test(tokio::test)]
    async fn should_support_multiple_requests_running_in_parallel() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            // Returns the current unix time (ms) as big-endian bytes; the "slow"
            // path sleeps first so completion timestamps reveal execution order
            async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
                if path.to_str().unwrap() == "slow" {
                    tokio::time::sleep(Duration::from_millis(500)).await;
                }

                let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
                Ok((time.as_millis() as u64).to_be_bytes().to_vec())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        // A fast request, a slow request, then another fast request
        let request = Request::new(Msg::batch([
            RequestPayload::FileRead {
                path: PathBuf::from("file1"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("slow"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("file2"),
            },
        ]));

        let response = client.send(request).await.unwrap();
        let payloads = response.payload.into_batch().unwrap();

        // Collect our times from the reading
        let mut times = Vec::new();
        for payload in payloads {
            match payload {
                distant_protocol::Response::Blob { data } => {
                    // Decode the big-endian u64 timestamp emitted by read_file
                    let mut buf = [0u8; 8];
                    buf.copy_from_slice(&data[..8]);
                    times.push(u64::from_be_bytes(buf));
                }
                x => panic!("Unexpected payload: {x:?}"),
            }
        }

        // Verify that these ran in parallel as the first and third requests should not be
        // over 500 milliseconds apart due to the sleep in the middle!
        let diff = times[0].abs_diff(times[2]);
        assert!(diff <= 500, "Sequential ordering detected");
    }

    #[test(tokio::test)]
    async fn should_run_all_requests_even_if_some_fail() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            // The "fail" path errors; everything else succeeds with empty bytes
            async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
                if path.to_str().unwrap() == "fail" {
                    return Err(io::Error::new(io::ErrorKind::Other, "test error"));
                }

                Ok(Vec::new())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        // Middle request fails; surrounding requests should still complete
        let request = Request::new(Msg::batch([
            RequestPayload::FileRead {
                path: PathBuf::from("file1"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("fail"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("file2"),
            },
        ]));

        let response = client.send(request).await.unwrap();
        let payloads = response.payload.into_batch().unwrap();

        // Should be a success, error, and success
        assert!(
            matches!(payloads[0], distant_protocol::Response::Blob { .. }),
            "Unexpected payloads[0]: {:?}",
            payloads[0]
        );
        assert!(
            matches!(
                &payloads[1],
                distant_protocol::Response::Error(distant_protocol::Error { kind, description })
                    if matches!(kind, distant_protocol::ErrorKind::Other) && description == "test error"
            ),
            "Unexpected payloads[1]: {:?}",
            payloads[1]
        );
        assert!(
            matches!(payloads[2], distant_protocol::Response::Blob { .. }),
            "Unexpected payloads[2]: {:?}",
            payloads[2]
        );
    }
}
|
||||
|
||||
/// Tests covering batch requests executed sequentially (opt-in via the
/// "sequence" header on the request).
mod batch_sequence {
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use distant_net::common::Request;
    use distant_protocol::{Msg, Request as RequestPayload};
    use test_log::test;

    use super::*;

    #[test(tokio::test)]
    async fn should_support_multiple_requests_running_in_sequence() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            // Returns the current unix time (ms) as big-endian bytes; the "slow"
            // path sleeps first so completion timestamps reveal execution order
            async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
                if path.to_str().unwrap() == "slow" {
                    tokio::time::sleep(Duration::from_millis(500)).await;
                }

                let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
                Ok((time.as_millis() as u64).to_be_bytes().to_vec())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        // A fast request, a slow request, then another fast request
        let mut request = Request::new(Msg::batch([
            RequestPayload::FileRead {
                path: PathBuf::from("file1"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("slow"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("file2"),
            },
        ]));

        // Mark as running in sequence
        request.header.insert("sequence", true);

        let response = client.send(request).await.unwrap();
        let payloads = response.payload.into_batch().unwrap();

        // Collect our times from the reading
        let mut times = Vec::new();
        for payload in payloads {
            match payload {
                distant_protocol::Response::Blob { data } => {
                    // Decode the big-endian u64 timestamp emitted by read_file
                    let mut buf = [0u8; 8];
                    buf.copy_from_slice(&data[..8]);
                    times.push(u64::from_be_bytes(buf));
                }
                x => panic!("Unexpected payload: {x:?}"),
            }
        }

        // Verify that these ran in sequence as the first and third requests should be
        // over 500 milliseconds apart due to the sleep in the middle!
        let diff = times[0].abs_diff(times[2]);
        assert!(diff > 500, "Parallel ordering detected");
    }

    #[test(tokio::test)]
    async fn should_interrupt_any_requests_following_a_failure() {
        struct TestDistantApi;

        #[async_trait]
        impl DistantApi for TestDistantApi {
            // The "fail" path errors; everything else succeeds with empty bytes
            async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
                if path.to_str().unwrap() == "fail" {
                    return Err(io::Error::new(io::ErrorKind::Other, "test error"));
                }

                Ok(Vec::new())
            }
        }

        let (mut client, _server) = setup(TestDistantApi).await;

        // Middle request fails; the trailing request should be interrupted
        let mut request = Request::new(Msg::batch([
            RequestPayload::FileRead {
                path: PathBuf::from("file1"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("fail"),
            },
            RequestPayload::FileRead {
                path: PathBuf::from("file2"),
            },
        ]));

        // Mark as running in sequence
        request.header.insert("sequence", true);

        let response = client.send(request).await.unwrap();
        let payloads = response.payload.into_batch().unwrap();

        // Should be a success, error, and interrupt
        assert!(
            matches!(payloads[0], distant_protocol::Response::Blob { .. }),
            "Unexpected payloads[0]: {:?}",
            payloads[0]
        );
        assert!(
            matches!(
                &payloads[1],
                distant_protocol::Response::Error(distant_protocol::Error { kind, description })
                    if matches!(kind, distant_protocol::ErrorKind::Other) && description == "test error"
            ),
            "Unexpected payloads[1]: {:?}",
            payloads[1]
        );
        assert!(
            matches!(
                &payloads[2],
                distant_protocol::Response::Error(distant_protocol::Error { kind, .. })
                    if matches!(kind, distant_protocol::ErrorKind::Interrupted)
            ),
            "Unexpected payloads[2]: {:?}",
            payloads[2]
        );
    }
}
|
@ -1,96 +0,0 @@
|
||||
use distant_core::{
|
||||
net::{FramedTransport, InmemoryTransport, IntoSplit, OneshotListener, PlainCodec},
|
||||
BoxedDistantReader, BoxedDistantWriter, Destination, DistantApiServer, DistantChannelExt,
|
||||
DistantManager, DistantManagerClient, DistantManagerClientConfig, DistantManagerConfig, Extra,
|
||||
};
|
||||
use std::io;
|
||||
|
||||
/// Creates a client transport and server listener for our tests
/// that are connected together
///
/// Both sides are framed with PlainCodec over an in-memory transport; the
/// listener yields exactly one connection (the paired server end).
async fn setup() -> (
    FramedTransport<InmemoryTransport, PlainCodec>,
    OneshotListener<FramedTransport<InmemoryTransport, PlainCodec>>,
) {
    // Paired in-memory transports with a buffer of 100 messages
    let (t1, t2) = InmemoryTransport::pair(100);

    let listener = OneshotListener::from_value(FramedTransport::new(t2, PlainCodec));
    let transport = FramedTransport::new(t1, PlainCodec);
    (transport, listener)
}
|
||||
|
||||
/// End-to-end exercise of the manager: register a connect handler, establish a
/// connection, list/inspect it, open a channel, kill it, and confirm errors
/// serialize for dead connections.
#[tokio::test]
async fn should_be_able_to_establish_a_single_connection_and_communicate() {
    let (transport, listener) = setup().await;

    let config = DistantManagerConfig::default();
    let manager_ref = DistantManager::start(config, listener).expect("Failed to start manager");

    // NOTE: To pass in a raw function, we HAVE to specify the types of the parameters manually,
    // otherwise we get a compilation error about lifetime mismatches
    manager_ref
        .register_connect_handler("scheme", |_: &_, _: &_, _: &mut _| async {
            use distant_core::net::ServerExt;
            let (t1, t2) = FramedTransport::pair(100);

            // Spawn a server on one end
            let _ = DistantApiServer::local()
                .unwrap()
                .start(OneshotListener::from_value(t2.into_split()))?;

            // Create a reader/writer pair on the other end
            let (writer, reader) = t1.into_split();
            let writer: BoxedDistantWriter = Box::new(writer);
            let reader: BoxedDistantReader = Box::new(reader);
            Ok((writer, reader))
        })
        .await
        .expect("Failed to register handler");

    let config = DistantManagerClientConfig::with_empty_prompts();
    let mut client =
        DistantManagerClient::new(config, transport).expect("Failed to connect to manager");

    // Test establishing a connection to some remote server
    let id = client
        .connect(
            "scheme://host".parse::<Destination>().unwrap(),
            "key=value".parse::<Extra>().unwrap(),
        )
        .await
        .expect("Failed to connect to a remote server");

    // Test retrieving list of connections
    let list = client
        .list()
        .await
        .expect("Failed to get list of connections");
    assert_eq!(list.len(), 1);
    assert_eq!(list.get(&id).unwrap().to_string(), "scheme://host");

    // Test retrieving information
    let info = client
        .info(id)
        .await
        .expect("Failed to get info about connection");
    assert_eq!(info.id, id);
    assert_eq!(info.destination.to_string(), "scheme://host");
    assert_eq!(info.extra, "key=value".parse::<Extra>().unwrap());

    // Create a new channel and request some data
    let mut channel = client
        .open_channel(id)
        .await
        .expect("Failed to open channel");
    let _ = channel
        .system_info()
        .await
        .expect("Failed to get system information");

    // Test killing a connection
    client.kill(id).await.expect("Failed to kill connection");

    // Test getting an error to ensure that serialization of that data works,
    // which we do by trying to access a connection that no longer exists
    let err = client.info(id).await.unwrap_err();
    assert_eq!(err.kind(), io::ErrorKind::NotConnected);
}
|
@ -1 +0,0 @@
|
||||
mod watch;
|
@ -1,68 +0,0 @@
|
||||
use crate::stress::utils;
|
||||
use distant_core::{DistantApiServer, DistantClient, LocalDistantApi};
|
||||
use distant_net::{
|
||||
PortRange, SecretKey, SecretKey32, TcpClientExt, TcpServerExt, XChaCha20Poly1305Codec,
|
||||
};
|
||||
use rstest::*;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
const LOG_PATH: &str = "/tmp/test.distant.server.log";
|
||||
|
||||
/// Test context bundling a connected distant client with a shutdown handle
/// for the background server powering it.
pub struct DistantClientCtx {
    /// Client connected to the locally-spawned server
    pub client: DistantClient,
    // Keeps the server task alive; dropping this sender signals it to shut down
    _done_tx: mpsc::Sender<()>,
}
|
||||
|
||||
impl DistantClientCtx {
    /// Spawns a local distant server on an ephemeral loopback port and connects
    /// a client to it, returning a context that keeps the server alive until drop.
    ///
    /// # Panics
    ///
    /// Panics if the server fails to start or the client cannot connect.
    pub async fn initialize() -> Self {
        let ip_addr = "127.0.0.1".parse().unwrap();
        // done_tx/done_rx keep the server task alive until the ctx is dropped
        let (done_tx, mut done_rx) = mpsc::channel::<()>(1);
        // started_tx reports back the bound port and secret key once running
        let (started_tx, mut started_rx) = mpsc::channel::<(u16, SecretKey32)>(1);

        tokio::spawn(async move {
            let logger = utils::init_logging(LOG_PATH);
            let key = SecretKey::default();
            let codec = XChaCha20Poly1305Codec::from(key.clone());

            if let Ok(api) = LocalDistantApi::initialize() {
                // Port 0 lets the OS pick any free port
                let port: PortRange = "0".parse().unwrap();
                let port = {
                    let server_ref = DistantApiServer::new(api)
                        .start(ip_addr, port, codec)
                        .await
                        .unwrap();
                    server_ref.port()
                };

                started_tx.send((port, key)).await.unwrap();
                // Block here until the ctx drops its sender half
                let _ = done_rx.recv().await;
            }

            logger.flush();
            logger.shutdown();
        });

        // Extract our server startup data if we succeeded
        let (port, key) = started_rx.recv().await.unwrap();

        // Now initialize our client
        let client = DistantClient::connect_timeout(
            format!("{}:{}", ip_addr, port).parse().unwrap(),
            XChaCha20Poly1305Codec::from(key),
            Duration::from_secs(1),
        )
        .await
        .unwrap();

        DistantClientCtx {
            client,
            _done_tx: done_tx,
        }
    }
}
|
||||
|
||||
/// rstest fixture yielding a fresh client/server context per test
#[fixture]
pub async fn ctx() -> DistantClientCtx {
    DistantClientCtx::initialize().await
}
|
@ -1,23 +0,0 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Initializes logging (should only call once)
|
||||
pub fn init_logging(path: impl Into<PathBuf>) -> flexi_logger::LoggerHandle {
|
||||
use flexi_logger::{FileSpec, LevelFilter, LogSpecification, Logger};
|
||||
let modules = &["distant", "distant_core", "distant_ssh2"];
|
||||
|
||||
// Disable logging for everything but our binary, which is based on verbosity
|
||||
let mut builder = LogSpecification::builder();
|
||||
builder.default(LevelFilter::Off);
|
||||
|
||||
// For each module, configure logging
|
||||
for module in modules {
|
||||
builder.module(module, LevelFilter::Trace);
|
||||
}
|
||||
|
||||
// Create our logger, but don't initialize yet
|
||||
let logger = Logger::with(builder.build())
|
||||
.format_for_files(flexi_logger::opt_format)
|
||||
.log_to_file(FileSpec::try_from(path).expect("Failed to create log file spec"));
|
||||
|
||||
logger.start().expect("Failed to initialize logger")
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
[package]
|
||||
name = "distant-local"
|
||||
description = "Library implementing distant API for local interactions"
|
||||
categories = ["network-programming"]
|
||||
version = "0.20.0"
|
||||
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
|
||||
edition = "2021"
|
||||
homepage = "https://github.com/chipsenkbeil/distant"
|
||||
repository = "https://github.com/chipsenkbeil/distant"
|
||||
readme = "README.md"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[features]
|
||||
default = ["macos-fsevent"]
|
||||
|
||||
# If specified, will use MacOS FSEvent for file watching
|
||||
macos-fsevent = ["notify/macos_fsevent"]
|
||||
|
||||
# If specified, will use MacOS kqueue for file watching
|
||||
macos-kqueue = ["notify/macos_kqueue"]
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.68"
|
||||
distant-core = { version = "=0.20.0", path = "../distant-core" }
|
||||
grep = "0.2.12"
|
||||
ignore = "0.4.20"
|
||||
log = "0.4.18"
|
||||
# NOTE(review): `macos_fsevent` is hard-coded here, which appears to make the
# optional `macos-kqueue` feature above ineffective (the fsevent backend is
# always compiled in) — confirm, and if so move backend selection entirely
# into [features].
notify = { version = "6.0.0", default-features = false, features = ["macos_fsevent"] }
|
||||
notify-debouncer-full = { version = "0.1.0", default-features = false }
|
||||
num_cpus = "1.15.0"
|
||||
portable-pty = "0.8.1"
|
||||
rand = { version = "0.8.5", features = ["getrandom"] }
|
||||
shell-words = "1.1.0"
|
||||
tokio = { version = "1.28.2", features = ["full"] }
|
||||
walkdir = "2.3.3"
|
||||
whoami = "1.4.0"
|
||||
winsplit = "0.1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
assert_fs = "1.0.13"
|
||||
env_logger = "0.10.0"
|
||||
indoc = "2.0.1"
|
||||
once_cell = "1.17.2"
|
||||
predicates = "3.0.3"
|
||||
rstest = "0.17.0"
|
||||
test-log = "0.2.11"
|
@ -0,0 +1,45 @@
|
||||
# distant local
|
||||
|
||||
[![Crates.io][distant_crates_img]][distant_crates_lnk] [![Docs.rs][distant_doc_img]][distant_doc_lnk] [![Rustc 1.70.0][distant_rustc_img]][distant_rustc_lnk]
|
||||
|
||||
[distant_crates_img]: https://img.shields.io/crates/v/distant-local.svg
|
||||
[distant_crates_lnk]: https://crates.io/crates/distant-local
|
||||
[distant_doc_img]: https://docs.rs/distant-local/badge.svg
|
||||
[distant_doc_lnk]: https://docs.rs/distant-local
|
||||
[distant_rustc_img]: https://img.shields.io/badge/distant_local-rustc_1.70+-lightgray.svg
|
||||
[distant_rustc_lnk]: https://blog.rust-lang.org/2023/06/01/Rust-1.70.0.html
|
||||
|
||||
## Details
|
||||
|
||||
The `distant-local` library acts as the primary implementation of a distant
|
||||
server that powers the CLI. The logic acts on the local machine of the server
|
||||
and is designed to be used as the foundation for distant operation handling.
|
||||
|
||||
## Installation
|
||||
|
||||
You can import the dependency by adding the following to your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
distant-local = "0.20"
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
```rust,no_run
|
||||
use distant_local::{Config, new_handler};
|
||||
|
||||
// Create a server API handler to be used with the server
|
||||
let handler = new_handler(Config::default()).unwrap();
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under either of
|
||||
|
||||
Apache License, Version 2.0, (LICENSE-APACHE or
|
||||
[apache-license][apache-license]) MIT license (LICENSE-MIT or
|
||||
[mit-license][mit-license]) at your option.
|
||||
|
||||
[apache-license]: http://www.apache.org/licenses/LICENSE-2.0
|
||||
[mit-license]: http://opensource.org/licenses/MIT
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,9 @@
|
||||
use crate::data::{ProcessId, PtySize};
|
||||
use std::{future::Future, pin::Pin};
|
||||
use tokio::{io, sync::mpsc};
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
use distant_core::protocol::{ProcessId, PtySize};
|
||||
use tokio::io;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
mod pty;
|
||||
pub use pty::*;
|
@ -1,4 +1,5 @@
|
||||
use tokio::{io, sync::mpsc};
|
||||
use tokio::io;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Exit status of a remote process
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
@ -0,0 +1,36 @@
|
||||
use std::io;
|
||||
|
||||
use crate::config::Config;
|
||||
|
||||
mod process;
|
||||
pub use process::*;
|
||||
|
||||
mod search;
|
||||
pub use search::*;
|
||||
|
||||
mod watcher;
|
||||
pub use watcher::*;
|
||||
|
||||
/// Holds global state state managed by the server
|
||||
pub struct GlobalState {
|
||||
/// State that holds information about processes running on the server
|
||||
pub process: ProcessState,
|
||||
|
||||
/// State that holds information about searches running on the server
|
||||
pub search: SearchState,
|
||||
|
||||
/// Watcher used for filesystem events
|
||||
pub watcher: WatcherState,
|
||||
}
|
||||
|
||||
impl GlobalState {
|
||||
pub fn initialize(config: Config) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
process: ProcessState::new(),
|
||||
search: SearchState::new(),
|
||||
watcher: WatcherBuilder::new()
|
||||
.with_config(config.watch)
|
||||
.initialize()?,
|
||||
})
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,429 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use distant_core::net::common::ConnectionId;
|
||||
use distant_core::protocol::{Change, ChangeDetails, ChangeDetailsAttribute, ChangeKind};
|
||||
use log::*;
|
||||
use notify::event::{AccessKind, AccessMode, MetadataKind, ModifyKind, RenameMode};
|
||||
use notify::{
|
||||
Config as WatcherConfig, Error as WatcherError, ErrorKind as WatcherErrorKind,
|
||||
Event as WatcherEvent, EventKind, PollWatcher, RecommendedWatcher, RecursiveMode, Watcher,
|
||||
};
|
||||
use notify_debouncer_full::{new_debouncer_opt, DebounceEventResult, Debouncer, FileIdMap};
|
||||
use tokio::sync::mpsc::error::TrySendError;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::config::WatchConfig;
|
||||
use crate::constants::SERVER_WATCHER_CAPACITY;
|
||||
|
||||
mod path;
|
||||
pub use path::*;
|
||||
|
||||
/// Builder for a watcher.
|
||||
#[derive(Default)]
|
||||
pub struct WatcherBuilder {
|
||||
config: WatchConfig,
|
||||
}
|
||||
|
||||
impl WatcherBuilder {
|
||||
/// Creates a new builder configured to use the native watcher using default configuration.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Swaps the configuration with the provided one.
|
||||
pub fn with_config(self, config: WatchConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Will create a watcher and initialize watched paths to be empty
|
||||
pub fn initialize(self) -> io::Result<WatcherState> {
|
||||
// NOTE: Cannot be something small like 1 as this seems to cause a deadlock sometimes
|
||||
// with a large volume of watch requests
|
||||
let (tx, rx) = mpsc::channel(SERVER_WATCHER_CAPACITY);
|
||||
|
||||
let watcher_config = WatcherConfig::default()
|
||||
.with_compare_contents(self.config.compare_contents)
|
||||
.with_poll_interval(self.config.poll_interval.unwrap_or(Duration::from_secs(30)));
|
||||
|
||||
macro_rules! process_event {
|
||||
($tx:ident, $evt:expr) => {
|
||||
match $tx.try_send(match $evt {
|
||||
Ok(x) => InnerWatcherMsg::Event { ev: x },
|
||||
Err(x) => InnerWatcherMsg::Error { err: x },
|
||||
}) {
|
||||
Ok(_) => (),
|
||||
Err(TrySendError::Full(_)) => {
|
||||
warn!(
|
||||
"Reached watcher capacity of {}! Dropping watcher event!",
|
||||
SERVER_WATCHER_CAPACITY,
|
||||
);
|
||||
}
|
||||
Err(TrySendError::Closed(_)) => {
|
||||
warn!("Skipping watch event because watcher channel closed");
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! new_debouncer {
|
||||
($watcher:ident, $tx:ident) => {{
|
||||
new_debouncer_opt::<_, $watcher, FileIdMap>(
|
||||
self.config.debounce_timeout,
|
||||
self.config.debounce_tick_rate,
|
||||
move |result: DebounceEventResult| match result {
|
||||
Ok(events) => {
|
||||
for x in events {
|
||||
process_event!($tx, Ok(x));
|
||||
}
|
||||
}
|
||||
Err(errors) => {
|
||||
for x in errors {
|
||||
process_event!($tx, Err(x));
|
||||
}
|
||||
}
|
||||
},
|
||||
FileIdMap::new(),
|
||||
watcher_config,
|
||||
)
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! spawn_task {
|
||||
($debouncer:expr) => {{
|
||||
WatcherState {
|
||||
channel: WatcherChannel { tx },
|
||||
task: tokio::spawn(watcher_task($debouncer, rx)),
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
let tx = tx.clone();
|
||||
if self.config.native {
|
||||
let result = {
|
||||
let tx = tx.clone();
|
||||
new_debouncer!(RecommendedWatcher, tx)
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(debouncer) => Ok(spawn_task!(debouncer)),
|
||||
Err(x) => {
|
||||
match x.kind {
|
||||
// notify-rs has a bug on Mac M1 with Docker and Linux, so we detect that error
|
||||
// and fall back to the poll watcher if this occurs
|
||||
//
|
||||
// https://github.com/notify-rs/notify/issues/423
|
||||
WatcherErrorKind::Io(x) if x.raw_os_error() == Some(38) => {
|
||||
warn!("Recommended watcher is unsupported! Falling back to polling watcher!");
|
||||
Ok(spawn_task!(new_debouncer!(PollWatcher, tx)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?))
|
||||
}
|
||||
_ => Err(io::Error::new(io::ErrorKind::Other, x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(spawn_task!(new_debouncer!(PollWatcher, tx)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds information related to watched paths on the server
|
||||
pub struct WatcherState {
|
||||
channel: WatcherChannel,
|
||||
task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Drop for WatcherState {
|
||||
/// Aborts the task that handles watcher path operations and management
|
||||
fn drop(&mut self) {
|
||||
self.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl WatcherState {
|
||||
/// Aborts the watcher task
|
||||
pub fn abort(&self) {
|
||||
self.task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for WatcherState {
|
||||
type Target = WatcherChannel;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.channel
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WatcherChannel {
|
||||
tx: mpsc::Sender<InnerWatcherMsg>,
|
||||
}
|
||||
|
||||
impl Default for WatcherChannel {
|
||||
/// Creates a new channel that is closed by default
|
||||
fn default() -> Self {
|
||||
let (tx, _) = mpsc::channel(1);
|
||||
Self { tx }
|
||||
}
|
||||
}
|
||||
|
||||
impl WatcherChannel {
|
||||
/// Watch a path for a specific connection denoted by the id within the registered path
|
||||
pub async fn watch(&self, registered_path: RegisteredPath) -> io::Result<()> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(InnerWatcherMsg::Watch {
|
||||
registered_path,
|
||||
cb,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal watcher task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to watch dropped"))?
|
||||
}
|
||||
|
||||
/// Unwatch a path for a specific connection denoted by the id
|
||||
pub async fn unwatch(&self, id: ConnectionId, path: impl AsRef<Path>) -> io::Result<()> {
|
||||
let (cb, rx) = oneshot::channel();
|
||||
let path = tokio::fs::canonicalize(path.as_ref())
|
||||
.await
|
||||
.unwrap_or_else(|_| path.as_ref().to_path_buf());
|
||||
self.tx
|
||||
.send(InnerWatcherMsg::Unwatch { id, path, cb })
|
||||
.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Internal watcher task closed"))?;
|
||||
rx.await
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to unwatch dropped"))?
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal message to pass to our task below to perform some action
|
||||
enum InnerWatcherMsg {
|
||||
Watch {
|
||||
registered_path: RegisteredPath,
|
||||
cb: oneshot::Sender<io::Result<()>>,
|
||||
},
|
||||
Unwatch {
|
||||
id: ConnectionId,
|
||||
path: PathBuf,
|
||||
cb: oneshot::Sender<io::Result<()>>,
|
||||
},
|
||||
Event {
|
||||
ev: WatcherEvent,
|
||||
},
|
||||
Error {
|
||||
err: WatcherError,
|
||||
},
|
||||
}
|
||||
|
||||
async fn watcher_task<W>(
|
||||
mut debouncer: Debouncer<W, FileIdMap>,
|
||||
mut rx: mpsc::Receiver<InnerWatcherMsg>,
|
||||
) where
|
||||
W: Watcher,
|
||||
{
|
||||
// TODO: Optimize this in some way to be more performant than
|
||||
// checking every path whenever an event comes in
|
||||
let mut registered_paths: Vec<RegisteredPath> = Vec::new();
|
||||
let mut path_cnt: HashMap<PathBuf, usize> = HashMap::new();
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
match msg {
|
||||
InnerWatcherMsg::Watch {
|
||||
registered_path,
|
||||
cb,
|
||||
} => {
|
||||
// Check if we are tracking the path across any connection
|
||||
if let Some(cnt) = path_cnt.get_mut(registered_path.path()) {
|
||||
// Increment the count of times we are watching that path
|
||||
*cnt += 1;
|
||||
|
||||
// Store the registered path in our collection without worry
|
||||
// since we are already watching a path that impacts this one
|
||||
registered_paths.push(registered_path);
|
||||
|
||||
// Send an okay because we always succeed in this case
|
||||
let _ = cb.send(Ok(()));
|
||||
} else {
|
||||
let res = debouncer
|
||||
.watcher()
|
||||
.watch(
|
||||
registered_path.path(),
|
||||
if registered_path.is_recursive() {
|
||||
RecursiveMode::Recursive
|
||||
} else {
|
||||
RecursiveMode::NonRecursive
|
||||
},
|
||||
)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x));
|
||||
|
||||
// If we succeeded, store our registered path and set the tracking cnt to 1
|
||||
if res.is_ok() {
|
||||
path_cnt.insert(registered_path.path().to_path_buf(), 1);
|
||||
registered_paths.push(registered_path);
|
||||
}
|
||||
|
||||
// Send the result of the watch, but don't worry if the channel was closed
|
||||
let _ = cb.send(res);
|
||||
}
|
||||
}
|
||||
InnerWatcherMsg::Unwatch { id, path, cb } => {
|
||||
// Check if we are tracking the path across any connection
|
||||
if let Some(cnt) = path_cnt.get(path.as_path()) {
|
||||
// Cycle through and remove all paths that match the given id and path,
|
||||
// capturing how many paths we removed
|
||||
let removed_cnt = {
|
||||
let old_len = registered_paths.len();
|
||||
registered_paths
|
||||
.retain(|p| p.id() != id || (p.path() != path && p.raw_path() != path));
|
||||
let new_len = registered_paths.len();
|
||||
old_len - new_len
|
||||
};
|
||||
|
||||
// 1. If we are now at zero cnt for our path, we want to actually unwatch the
|
||||
// path with our watcher
|
||||
// 2. If we removed nothing from our path list, we want to return an error
|
||||
// 3. Otherwise, we return okay because we succeeded
|
||||
if *cnt <= removed_cnt {
|
||||
let _ = cb.send(
|
||||
debouncer
|
||||
.watcher()
|
||||
.unwatch(&path)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x)),
|
||||
);
|
||||
} else if removed_cnt == 0 {
|
||||
// Send a failure as there was nothing to unwatch for this connection
|
||||
let _ = cb.send(Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("{path:?} is not being watched"),
|
||||
)));
|
||||
} else {
|
||||
// Send a success as we removed some paths
|
||||
let _ = cb.send(Ok(()));
|
||||
}
|
||||
} else {
|
||||
// Send a failure as there was nothing to unwatch
|
||||
let _ = cb.send(Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("{path:?} is not being watched"),
|
||||
)));
|
||||
}
|
||||
}
|
||||
InnerWatcherMsg::Event { ev } => {
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("System time before unix epoch")
|
||||
.as_secs();
|
||||
|
||||
let kind = match ev.kind {
|
||||
EventKind::Access(AccessKind::Read) => ChangeKind::Access,
|
||||
EventKind::Modify(ModifyKind::Metadata(_)) => ChangeKind::Attribute,
|
||||
EventKind::Access(AccessKind::Close(AccessMode::Write)) => {
|
||||
ChangeKind::CloseWrite
|
||||
}
|
||||
EventKind::Access(AccessKind::Close(_)) => ChangeKind::CloseNoWrite,
|
||||
EventKind::Create(_) => ChangeKind::Create,
|
||||
EventKind::Remove(_) => ChangeKind::Delete,
|
||||
EventKind::Modify(ModifyKind::Data(_)) => ChangeKind::Modify,
|
||||
EventKind::Access(AccessKind::Open(_)) => ChangeKind::Open,
|
||||
EventKind::Modify(ModifyKind::Name(_)) => ChangeKind::Rename,
|
||||
_ => ChangeKind::Unknown,
|
||||
};
|
||||
|
||||
for registered_path in registered_paths.iter() {
|
||||
// For rename both, we assume the paths is a pair that represents before and
|
||||
// after, so we want to grab the before and use it!
|
||||
let (paths, renamed): (&[PathBuf], Option<PathBuf>) = match ev.kind {
|
||||
EventKind::Modify(ModifyKind::Name(RenameMode::Both)) => (
|
||||
&ev.paths[0..1],
|
||||
if ev.paths.len() > 1 {
|
||||
ev.paths.last().cloned()
|
||||
} else {
|
||||
None
|
||||
},
|
||||
),
|
||||
_ => (&ev.paths, None),
|
||||
};
|
||||
|
||||
for path in paths {
|
||||
let attribute = match ev.kind {
|
||||
EventKind::Modify(ModifyKind::Metadata(MetadataKind::Ownership)) => {
|
||||
Some(ChangeDetailsAttribute::Ownership)
|
||||
}
|
||||
EventKind::Modify(ModifyKind::Metadata(MetadataKind::Permissions)) => {
|
||||
Some(ChangeDetailsAttribute::Permissions)
|
||||
}
|
||||
EventKind::Modify(ModifyKind::Metadata(MetadataKind::WriteTime)) => {
|
||||
Some(ChangeDetailsAttribute::Timestamp)
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Calculate a timestamp for creation & modification paths
|
||||
let details_timestamp = match ev.kind {
|
||||
EventKind::Create(_) => tokio::fs::symlink_metadata(path.as_path())
|
||||
.await
|
||||
.ok()
|
||||
.and_then(|m| m.created().ok())
|
||||
.and_then(|t| t.duration_since(UNIX_EPOCH).ok())
|
||||
.map(|d| d.as_secs()),
|
||||
EventKind::Modify(_) => tokio::fs::symlink_metadata(path.as_path())
|
||||
.await
|
||||
.ok()
|
||||
.and_then(|m| m.modified().ok())
|
||||
.and_then(|t| t.duration_since(UNIX_EPOCH).ok())
|
||||
.map(|d| d.as_secs()),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let change = Change {
|
||||
timestamp,
|
||||
kind,
|
||||
path: path.to_path_buf(),
|
||||
details: ChangeDetails {
|
||||
attribute,
|
||||
renamed: renamed.clone(),
|
||||
timestamp: details_timestamp,
|
||||
extra: ev.info().map(ToString::to_string),
|
||||
},
|
||||
};
|
||||
match registered_path.filter_and_send(change) {
|
||||
Ok(_) => (),
|
||||
Err(x) => error!(
|
||||
"[Conn {}] Failed to forward changes to paths: {}",
|
||||
registered_path.id(),
|
||||
x
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
InnerWatcherMsg::Error { err } => {
|
||||
let msg = err.to_string();
|
||||
error!("Watcher encountered an error {} for {:?}", msg, err.paths);
|
||||
|
||||
for registered_path in registered_paths.iter() {
|
||||
match registered_path.filter_and_send_error(
|
||||
&msg,
|
||||
&err.paths,
|
||||
!err.paths.is_empty(),
|
||||
) {
|
||||
Ok(_) => (),
|
||||
Err(x) => error!(
|
||||
"[Conn {}] Failed to forward changes to paths: {}",
|
||||
registered_path.id(),
|
||||
x
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq)]
|
||||
pub struct Config {
|
||||
pub watch: WatchConfig,
|
||||
}
|
||||
|
||||
/// Configuration specifically for watching files and directories.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct WatchConfig {
|
||||
pub native: bool,
|
||||
pub poll_interval: Option<Duration>,
|
||||
pub compare_contents: bool,
|
||||
pub debounce_timeout: Duration,
|
||||
pub debounce_tick_rate: Option<Duration>,
|
||||
}
|
||||
|
||||
impl Default for WatchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
native: true,
|
||||
poll_interval: None,
|
||||
compare_contents: false,
|
||||
debounce_timeout: Duration::from_millis(500),
|
||||
debounce_tick_rate: None,
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
use std::time::Duration;
|
||||
|
||||
/// Capacity associated with the server's file watcher to pass events outbound
|
||||
pub const SERVER_WATCHER_CAPACITY: usize = 10000;
|
||||
|
||||
/// Represents the maximum size (in bytes) that data will be read from pipes
|
||||
/// per individual `read` call
|
||||
///
|
||||
/// Current setting is 16k size
|
||||
pub const MAX_PIPE_CHUNK_SIZE: usize = 16384;
|
||||
|
||||
/// Duration in milliseconds to sleep between reading stdout/stderr chunks
|
||||
/// to avoid sending many small messages to clients
|
||||
pub const READ_PAUSE_DURATION: Duration = Duration::from_millis(1);
|
@ -0,0 +1,20 @@
|
||||
#![doc = include_str!("../README.md")]
|
||||
|
||||
#[doc = include_str!("../README.md")]
|
||||
#[cfg(doctest)]
|
||||
pub struct ReadmeDoctests;
|
||||
|
||||
mod api;
|
||||
mod config;
|
||||
mod constants;
|
||||
pub use api::Api;
|
||||
pub use config::*;
|
||||
use distant_core::DistantApiServerHandler;
|
||||
|
||||
/// Implementation of [`DistantApiServerHandler`] using [`Api`].
|
||||
pub type Handler = DistantApiServerHandler<Api>;
|
||||
|
||||
/// Initializes a new [`Handler`].
|
||||
pub fn new_handler(config: Config) -> std::io::Result<Handler> {
|
||||
Ok(Handler::new(Api::initialize(config)?))
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
use assert_fs::prelude::*;
|
||||
use distant_core::DistantChannelExt;
|
||||
use rstest::*;
|
||||
use test_log::test;
|
||||
|
||||
use crate::stress::fixtures::*;
|
||||
|
||||
// 64KB is maximum TCP packet size
|
||||
const MAX_TCP_PACKET_BYTES: usize = 65535;
|
||||
|
||||
// 640KB should be big enough to cause problems
|
||||
const LARGE_FILE_LEN: usize = MAX_TCP_PACKET_BYTES * 10;
|
||||
|
||||
#[rstest]
|
||||
#[test(tokio::test)]
|
||||
async fn should_handle_large_files(#[future] ctx: DistantClientCtx) {
|
||||
let ctx = ctx.await;
|
||||
let mut channel = ctx.client.clone_channel();
|
||||
|
||||
let root = assert_fs::TempDir::new().unwrap();
|
||||
|
||||
// Generate data
|
||||
eprintln!("Creating random data of size: {LARGE_FILE_LEN}");
|
||||
let mut data = Vec::with_capacity(LARGE_FILE_LEN);
|
||||
for i in 0..LARGE_FILE_LEN {
|
||||
data.push(i as u8);
|
||||
}
|
||||
|
||||
// Create our large file to read, write, and append
|
||||
let file = root.child("large_file.dat");
|
||||
eprintln!("Writing random file: {:?}", file.path());
|
||||
file.write_binary(&data)
|
||||
.expect("Failed to write large file");
|
||||
|
||||
// Perform the read
|
||||
eprintln!("Reading file using distant");
|
||||
let mut new_data = channel
|
||||
.read_file(file.path())
|
||||
.await
|
||||
.expect("Failed to read large file");
|
||||
assert_eq!(new_data, data, "Data mismatch");
|
||||
|
||||
// Perform the write after modifying one byte
|
||||
eprintln!("Writing file using distant");
|
||||
new_data[LARGE_FILE_LEN - 1] = new_data[LARGE_FILE_LEN - 1].overflowing_add(1).0;
|
||||
channel
|
||||
.write_file(file.path(), new_data.clone())
|
||||
.await
|
||||
.expect("Failed to write large file");
|
||||
let data = tokio::fs::read(file.path())
|
||||
.await
|
||||
.expect("Failed to read large file");
|
||||
assert_eq!(new_data, data, "Data was not written correctly");
|
||||
|
||||
// Perform append
|
||||
eprintln!("Appending to file using distant");
|
||||
channel
|
||||
.append_file(file.path(), vec![1, 2, 3])
|
||||
.await
|
||||
.expect("Failed to append to large file");
|
||||
let new_data = tokio::fs::read(file.path())
|
||||
.await
|
||||
.expect("Failed to read large file");
|
||||
assert_eq!(new_data[new_data.len() - 3..], [1, 2, 3]);
|
||||
}
|
@ -0,0 +1,2 @@
|
||||
mod large_file;
|
||||
mod watch;
|
@ -0,0 +1,70 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
use distant_core::net::auth::{DummyAuthHandler, Verifier};
|
||||
use distant_core::net::client::{Client, TcpConnector};
|
||||
use distant_core::net::common::PortRange;
|
||||
use distant_core::net::server::Server;
|
||||
use distant_core::{DistantApiServerHandler, DistantClient};
|
||||
use distant_local::Api;
|
||||
use rstest::*;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub struct DistantClientCtx {
|
||||
pub client: DistantClient,
|
||||
_done_tx: mpsc::Sender<()>,
|
||||
}
|
||||
|
||||
impl DistantClientCtx {
|
||||
pub async fn initialize() -> Self {
|
||||
let ip_addr = "127.0.0.1".parse().unwrap();
|
||||
let (done_tx, mut done_rx) = mpsc::channel::<()>(1);
|
||||
let (started_tx, mut started_rx) = mpsc::channel::<u16>(1);
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Ok(api) = Api::initialize(Default::default()) {
|
||||
let port: PortRange = "0".parse().unwrap();
|
||||
let port = {
|
||||
let handler = DistantApiServerHandler::new(api);
|
||||
let server_ref = Server::new()
|
||||
.handler(handler)
|
||||
.verifier(Verifier::none())
|
||||
.into_tcp_builder()
|
||||
.start(ip_addr, port)
|
||||
.await
|
||||
.unwrap();
|
||||
server_ref.port()
|
||||
};
|
||||
|
||||
started_tx.send(port).await.unwrap();
|
||||
let _ = done_rx.recv().await;
|
||||
}
|
||||
});
|
||||
|
||||
// Extract our server startup data if we succeeded
|
||||
let port = started_rx.recv().await.unwrap();
|
||||
|
||||
// Now initialize our client
|
||||
let client: DistantClient = Client::build()
|
||||
.auth_handler(DummyAuthHandler)
|
||||
.connect_timeout(Duration::from_secs(1))
|
||||
.connector(TcpConnector::new(
|
||||
format!("{}:{}", ip_addr, port)
|
||||
.parse::<SocketAddr>()
|
||||
.unwrap(),
|
||||
))
|
||||
.connect()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
DistantClientCtx {
|
||||
client,
|
||||
_done_tx: done_tx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
pub async fn ctx() -> DistantClientCtx {
|
||||
DistantClientCtx::initialize().await
|
||||
}
|
@ -1,3 +1,2 @@
|
||||
mod distant;
|
||||
mod fixtures;
|
||||
mod utils;
|
@ -1,122 +0,0 @@
|
||||
use derive_more::Display;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
mod client;
|
||||
pub use client::*;
|
||||
|
||||
mod handshake;
|
||||
pub use handshake::*;
|
||||
|
||||
mod server;
|
||||
pub use server::*;
|
||||
|
||||
/// Represents authentication messages that can be sent over the wire
|
||||
///
|
||||
/// NOTE: Must use serde's content attribute with the tag attribute. Just the tag attribute will
|
||||
/// cause deserialization to fail
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type", content = "data")]
|
||||
pub enum Auth {
|
||||
/// Represents a request to perform an authentication handshake,
|
||||
/// providing the public key and salt from one side in order to
|
||||
/// derive the shared key
|
||||
#[serde(rename = "auth_handshake")]
|
||||
Handshake {
|
||||
/// Bytes of the public key
|
||||
#[serde(with = "serde_bytes")]
|
||||
public_key: PublicKeyBytes,
|
||||
|
||||
/// Randomly generated salt
|
||||
#[serde(with = "serde_bytes")]
|
||||
salt: Salt,
|
||||
},
|
||||
|
||||
/// Represents the bytes of an encrypted message
|
||||
///
|
||||
/// Underneath, will be one of either [`AuthRequest`] or [`AuthResponse`]
|
||||
#[serde(rename = "auth_msg")]
|
||||
Msg {
|
||||
#[serde(with = "serde_bytes")]
|
||||
encrypted_payload: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Represents authentication messages that act as initiators such as providing
|
||||
/// a challenge, verifying information, presenting information, or highlighting an error
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthRequest {
|
||||
/// Represents a challenge comprising a series of questions to be presented
|
||||
Challenge {
|
||||
questions: Vec<AuthQuestion>,
|
||||
extra: HashMap<String, String>,
|
||||
},
|
||||
|
||||
/// Represents an ask to verify some information
|
||||
Verify { kind: AuthVerifyKind, text: String },
|
||||
|
||||
/// Represents some information to be presented
|
||||
Info { text: String },
|
||||
|
||||
/// Represents some error that occurred
|
||||
Error { kind: AuthErrorKind, text: String },
|
||||
}
|
||||
|
||||
/// Represents authentication messages that are responses to auth requests such
|
||||
/// as answers to challenges or verifying information
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthResponse {
|
||||
/// Represents the answers to a previously-asked challenge
|
||||
Challenge { answers: Vec<String> },
|
||||
|
||||
/// Represents the answer to a previously-asked verify
|
||||
Verify { valid: bool },
|
||||
}
|
||||
|
||||
/// Represents the type of verification being requested
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[non_exhaustive]
|
||||
pub enum AuthVerifyKind {
|
||||
/// An ask to verify the host such as with SSH
|
||||
#[display(fmt = "host")]
|
||||
Host,
|
||||
}
|
||||
|
||||
/// Represents a single question in a challenge
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct AuthQuestion {
|
||||
/// The text of the question
|
||||
pub text: String,
|
||||
|
||||
/// Any extra information specific to a particular auth domain
|
||||
/// such as including a username and instructions for SSH authentication
|
||||
pub extra: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl AuthQuestion {
|
||||
/// Creates a new question without any extra data
|
||||
pub fn new(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
text: text.into(),
|
||||
extra: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the type of error encountered during authentication
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AuthErrorKind {
|
||||
/// When the answer(s) to a challenge do not pass authentication
|
||||
FailedChallenge,
|
||||
|
||||
/// When verification during authentication fails
|
||||
/// (e.g. a host is not allowed or blocked)
|
||||
FailedVerification,
|
||||
|
||||
/// When the error is unknown
|
||||
Unknown,
|
||||
}
|
@ -1,817 +0,0 @@
|
||||
use crate::{
|
||||
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Client,
|
||||
Codec, Handshake, XChaCha20Poly1305Codec,
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
|
||||
pub struct AuthClient {
|
||||
inner: Client<Auth, Auth>,
|
||||
codec: Option<XChaCha20Poly1305Codec>,
|
||||
jit_handshake: bool,
|
||||
}
|
||||
|
||||
impl From<Client<Auth, Auth>> for AuthClient {
|
||||
fn from(client: Client<Auth, Auth>) -> Self {
|
||||
Self {
|
||||
inner: client,
|
||||
codec: None,
|
||||
jit_handshake: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthClient {
|
||||
/// Sends a request to the server to establish an encrypted connection
|
||||
pub async fn handshake(&mut self) -> io::Result<()> {
|
||||
let handshake = Handshake::default();
|
||||
|
||||
let response = self
|
||||
.inner
|
||||
.send(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let key = handshake.handshake(public_key, salt)?;
|
||||
self.codec.replace(XChaCha20Poly1305Codec::new(&key));
|
||||
Ok(())
|
||||
}
|
||||
Auth::Msg { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected encrypted message during handshake",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a handshake only if jit is enabled and no handshake has succeeded yet
|
||||
async fn jit_handshake(&mut self) -> io::Result<()> {
|
||||
if self.will_jit_handshake() && !self.is_ready() {
|
||||
self.handshake().await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if client has successfully performed a handshake
|
||||
/// and is ready to communicate with the server
|
||||
pub fn is_ready(&self) -> bool {
|
||||
self.codec.is_some()
|
||||
}
|
||||
|
||||
/// Returns true if this client will perform a handshake just-in-time (JIT) prior to making a
|
||||
/// request in the scenario where the client has not already performed a handshake
|
||||
#[inline]
|
||||
pub fn will_jit_handshake(&self) -> bool {
|
||||
self.jit_handshake
|
||||
}
|
||||
|
||||
/// Sets the jit flag on this client with `true` indicating that this client will perform a
|
||||
/// handshake just-in-time (JIT) prior to making a request in the scenario where the client has
|
||||
/// not already performed a handshake
|
||||
#[inline]
|
||||
pub fn set_jit_handshake(&mut self, flag: bool) {
|
||||
self.jit_handshake = flag;
|
||||
}
|
||||
|
||||
/// Provides a challenge to the server and returns the answers to the questions
|
||||
/// asked by the client
|
||||
pub async fn challenge(
|
||||
&mut self,
|
||||
questions: Vec<AuthQuestion>,
|
||||
extra: HashMap<String, String>,
|
||||
) -> io::Result<Vec<String>> {
|
||||
trace!(
|
||||
"AuthClient::challenge(questions = {:?}, extra = {:?})",
|
||||
questions,
|
||||
extra
|
||||
);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Challenge { questions, extra };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match self.decrypt_and_deserialize(&encrypted_payload)? {
|
||||
AuthResponse::Challenge { answers } => Ok(answers),
|
||||
AuthResponse::Verify { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected verify response during challenge",
|
||||
)),
|
||||
}
|
||||
}
|
||||
Auth::Handshake { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected handshake during challenge",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides a verification request to the server and returns whether or not
|
||||
/// the server approved
|
||||
pub async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool> {
|
||||
trace!("AuthClient::verify(kind = {:?}, text = {:?})", kind, text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Verify { kind, text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match self.decrypt_and_deserialize(&encrypted_payload)? {
|
||||
AuthResponse::Verify { valid } => Ok(valid),
|
||||
AuthResponse::Challenge { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected challenge response during verify",
|
||||
)),
|
||||
}
|
||||
}
|
||||
Auth::Handshake { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected handshake during verify",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides information to the server to use as it pleases with no response expected
|
||||
pub async fn info(&mut self, text: String) -> io::Result<()> {
|
||||
trace!("AuthClient::info(text = {:?})", text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Info { text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
self.inner.fire(Auth::Msg { encrypted_payload }).await
|
||||
}
|
||||
|
||||
/// Provides an error to the server to use as it pleases with no response expected
|
||||
pub async fn error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
|
||||
trace!("AuthClient::error(kind = {:?}, text = {:?})", kind, text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Error { kind, text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
self.inner.fire(Auth::Msg { encrypted_payload }).await
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt(&mut self, payload: &AuthRequest) -> io::Result<Vec<u8>> {
|
||||
let codec = self.codec.as_mut().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (client encrypt message)",
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize(&mut self, payload: &[u8]) -> io::Result<AuthResponse> {
|
||||
let codec = self.codec.as_mut().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (client decrypt message)",
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Client, FramedTransport, Request, Response, TypedAsyncRead, TypedAsyncWrite};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
const TIMEOUT_MILLIS: u64 = 100;
|
||||
|
||||
    #[tokio::test]
    async fn handshake_should_fail_if_get_unexpected_response_from_server() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move { client.handshake().await });

        // Get the request, but send a bad response: a handshake must be
        // answered with Auth::Handshake, so an Auth::Msg should error the client
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Handshake { .. } => server
                .write(Response::new(
                    request.id,
                    Auth::Msg {
                        encrypted_payload: Vec::new(),
                    },
                ))
                .await
                .unwrap(),
            _ => panic!("Server received unexpected payload"),
        }

        let result = task.await.unwrap();
        assert!(result.is_err(), "Handshake succeeded unexpectedly")
    }
|
||||
|
||||
    #[tokio::test]
    async fn challenge_should_fail_if_handshake_not_finished() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move { client.challenge(Vec::new(), HashMap::new()).await });

        // Wait for a request, failing if we get one as the failure
        // should have prevented sending anything; a closed transport
        // (Ok(None)) is the expected outcome here
        tokio::select! {
            x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
                match x {
                    Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
                    Ok(None) => {},
                    Err(x) => panic!("Unexpectedly failed on server side: {}", x),
                }
            },
            _ = wait_ms(TIMEOUT_MILLIS) => {
                panic!("Should have gotten server closure as part of client exit");
            }
        }

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Challenge succeeded unexpectedly")
    }
|
||||
|
||||
    #[tokio::test]
    async fn challenge_should_fail_if_receive_wrong_response() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .challenge(
                    vec![
                        AuthQuestion::new("question1".to_string()),
                        AuthQuestion {
                            text: "question2".to_string(),
                            extra: vec![("key2".to_string(), "value2".to_string())]
                                .into_iter()
                                .collect(),
                        },
                    ],
                    vec![("key".to_string(), "value".to_string())]
                        .into_iter()
                        .collect(),
                )
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        // by completing the key exchange from the server side
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a challenge request and send back wrong response
        // (a verify response instead of challenge answers)
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Challenge { .. } => {
                        server
                            .write(Response::new(
                                request.id,
                                Auth::Msg {
                                    encrypted_payload: serialize_and_encrypt(
                                        &mut codec,
                                        &AuthResponse::Verify { valid: true },
                                    )
                                    .unwrap(),
                                },
                            ))
                            .await
                            .unwrap();
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Challenge succeeded unexpectedly")
    }
|
||||
|
||||
    #[tokio::test]
    async fn challenge_should_return_answers_received_from_server() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .challenge(
                    vec![
                        AuthQuestion::new("question1".to_string()),
                        AuthQuestion {
                            text: "question2".to_string(),
                            extra: vec![("key2".to_string(), "value2".to_string())]
                                .into_iter()
                                .collect(),
                        },
                    ],
                    vec![("key".to_string(), "value".to_string())]
                        .into_iter()
                        .collect(),
                )
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        // by completing the key exchange from the server side
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a challenge request, confirm the questions/extra arrived
        // intact, and send back a proper challenge response with answers
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Challenge { questions, extra } => {
                        assert_eq!(
                            questions,
                            vec![
                                AuthQuestion::new("question1".to_string()),
                                AuthQuestion {
                                    text: "question2".to_string(),
                                    extra: vec![("key2".to_string(), "value2".to_string())]
                                        .into_iter()
                                        .collect(),
                                },
                            ],
                        );

                        assert_eq!(
                            extra,
                            vec![("key".to_string(), "value".to_string())]
                                .into_iter()
                                .collect(),
                        );

                        server
                            .write(Response::new(
                                request.id,
                                Auth::Msg {
                                    encrypted_payload: serialize_and_encrypt(
                                        &mut codec,
                                        &AuthResponse::Challenge {
                                            answers: vec![
                                                "answer1".to_string(),
                                                "answer2".to_string(),
                                            ],
                                        },
                                    )
                                    .unwrap(),
                                },
                            ))
                            .await
                            .unwrap();
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got the right results
        let answers = task.await.unwrap().unwrap();
        assert_eq!(answers, vec!["answer1".to_string(), "answer2".to_string()]);
    }
|
||||
|
||||
    #[tokio::test]
    async fn verify_should_fail_if_handshake_not_finished() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client
                .verify(AuthVerifyKind::Host, "some text".to_string())
                .await
        });

        // Wait for a request, failing if we get one as the failure
        // should have prevented sending anything; a closed transport
        // (Ok(None)) is the expected outcome here
        tokio::select! {
            x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
                match x {
                    Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
                    Ok(None) => {},
                    Err(x) => panic!("Unexpectedly failed on server side: {}", x),
                }
            },
            _ = wait_ms(TIMEOUT_MILLIS) => {
                panic!("Should have gotten server closure as part of client exit");
            }
        }

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Verify succeeded unexpectedly")
    }
|
||||
|
||||
    #[tokio::test]
    async fn verify_should_fail_if_receive_wrong_response() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .verify(AuthVerifyKind::Host, "some text".to_string())
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        // by completing the key exchange from the server side
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a verify request and send back wrong response
        // (a challenge response instead of a verify result)
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Verify { .. } => {
                        server
                            .write(Response::new(
                                request.id,
                                Auth::Msg {
                                    encrypted_payload: serialize_and_encrypt(
                                        &mut codec,
                                        &AuthResponse::Challenge {
                                            answers: Vec::new(),
                                        },
                                    )
                                    .unwrap(),
                                },
                            ))
                            .await
                            .unwrap();
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Verify succeeded unexpectedly")
    }
|
||||
|
||||
    #[tokio::test]
    async fn verify_should_return_valid_bool_received_from_server() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .verify(AuthVerifyKind::Host, "some text".to_string())
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        // by completing the key exchange from the server side
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for a verify request, confirm kind/text arrived intact,
        // and send back a proper verify response with valid = true
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Verify { kind, text } => {
                        assert_eq!(kind, AuthVerifyKind::Host);
                        assert_eq!(text, "some text");

                        server
                            .write(Response::new(
                                request.id,
                                Auth::Msg {
                                    encrypted_payload: serialize_and_encrypt(
                                        &mut codec,
                                        &AuthResponse::Verify { valid: true },
                                    )
                                    .unwrap(),
                                },
                            ))
                            .await
                            .unwrap();
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got the right results
        let valid = task.await.unwrap().unwrap();
        assert!(valid, "Got verify response, but valid was set incorrectly");
    }
|
||||
|
||||
    #[tokio::test]
    async fn info_should_fail_if_handshake_not_finished() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move { client.info("some text".to_string()).await });

        // Wait for a request, failing if we get one as the failure
        // should have prevented sending anything; a closed transport
        // (Ok(None)) is the expected outcome here
        tokio::select! {
            x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
                match x {
                    Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
                    Ok(None) => {},
                    Err(x) => panic!("Unexpectedly failed on server side: {}", x),
                }
            },
            _ = wait_ms(TIMEOUT_MILLIS) => {
                panic!("Should have gotten server closure as part of client exit");
            }
        }

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Info succeeded unexpectedly")
    }
|
||||
|
||||
    #[tokio::test]
    async fn info_should_send_the_server_a_request_but_not_wait_for_a_response() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client.info("some text".to_string()).await
        });

        // Wait for a handshake request and set up our encryption codec
        // by completing the key exchange from the server side
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for the info request; no response is written back since
        // info is fire-and-forget on the client side
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Info { text } => {
                        assert_eq!(text, "some text");
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got the right results
        task.await.unwrap().unwrap();
    }
|
||||
|
||||
    #[tokio::test]
    async fn error_should_fail_if_handshake_not_finished() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client
                .error(AuthErrorKind::FailedChallenge, "some text".to_string())
                .await
        });

        // Wait for a request, failing if we get one as the failure
        // should have prevented sending anything; a closed transport
        // (Ok(None)) is the expected outcome here
        tokio::select! {
            x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
                match x {
                    Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
                    Ok(None) => {},
                    Err(x) => panic!("Unexpectedly failed on server side: {}", x),
                }
            },
            _ = wait_ms(TIMEOUT_MILLIS) => {
                panic!("Should have gotten server closure as part of client exit");
            }
        }

        // Verify that we got an error with the method
        let result = task.await.unwrap();
        assert!(result.is_err(), "Error succeeded unexpectedly")
    }
|
||||
|
||||
    #[tokio::test]
    async fn error_should_send_the_server_a_request_but_not_wait_for_a_response() {
        let (t, mut server) = FramedTransport::make_test_pair();
        let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());

        // We start a separate task for the client to avoid blocking since
        // we also need to receive the client's request and respond
        let task = tokio::spawn(async move {
            client.handshake().await.unwrap();
            client
                .error(AuthErrorKind::FailedChallenge, "some text".to_string())
                .await
        });

        // Wait for a handshake request and set up our encryption codec
        // by completing the key exchange from the server side
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        let mut codec = match request.payload {
            Auth::Handshake { public_key, salt } => {
                let handshake = Handshake::default();
                let key = handshake.handshake(public_key, salt).unwrap();
                server
                    .write(Response::new(
                        request.id,
                        Auth::Handshake {
                            public_key: handshake.pk_bytes(),
                            salt: *handshake.salt(),
                        },
                    ))
                    .await
                    .unwrap();
                XChaCha20Poly1305Codec::new(&key)
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Wait for the error request; no response is written back since
        // error is fire-and-forget on the client side
        let request: Request<Auth> = server.read().await.unwrap().unwrap();
        match request.payload {
            Auth::Msg { encrypted_payload } => {
                match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
                    AuthRequest::Error { kind, text } => {
                        assert_eq!(kind, AuthErrorKind::FailedChallenge);
                        assert_eq!(text, "some text");
                    }
                    _ => panic!("Server received wrong request type"),
                }
            }
            _ => panic!("Server received unexpected payload"),
        };

        // Verify that we got the right results
        task.await.unwrap().unwrap();
    }
|
||||
|
||||
async fn wait_ms(ms: u64) {
|
||||
use std::time::Duration;
|
||||
tokio::time::sleep(Duration::from_millis(ms)).await;
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt<T: Serialize>(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &T,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize<T: DeserializeOwned>(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &[u8],
|
||||
) -> io::Result<T> {
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<T>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue