Create a new struct to hold receiver logic

v5-api
Dominik Nakamura 2 years ago
parent a536ee7f34
commit 04f4b75ee0
No known key found for this signature in database
GPG Key ID: E4C6A749B2491910

@ -79,7 +79,7 @@ pub struct Client {
/// A list of currently waiting requests to get a response back. The key is the string version
/// of a request ID and the value is a oneshot sender that allows to send the response back to
/// the other end that waits for the response.
receivers: Arc<Mutex<ReceiverList>>,
receivers: Arc<ReceiverList>,
/// A list of awaiting [`reidentify`](Self::reidentify) requests, waiting for confirmation. As
/// these requests don't carry any kind of ID, they're handled sequentially and must be tracked
/// separate from normal requests.
@ -97,8 +97,50 @@ pub struct Client {
/// Shorthand for the writer side of a web-socket stream that has been split into reader and writer.
type MessageWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Shorthand for the list of ongoing requests that wait for a response.
type ReceiverList = HashMap<u64, oneshot::Sender<(Status, serde_json::Value)>>;
/// Wrapper for the list of ongoing requests that wait for response.
#[derive(Default)]
struct ReceiverList(Mutex<HashMap<u64, oneshot::Sender<(Status, serde_json::Value)>>>);
impl ReceiverList {
/// Add a new receiver to the wait list, that will be notified once a request with the given
/// ID is received.
async fn add(&self, id: u64) -> oneshot::Receiver<(Status, serde_json::Value)> {
let (tx, rx) = oneshot::channel();
self.0.lock().await.insert(id, tx);
rx
}
/// Remove a previously added receiver. Used to free up resources, in case sending the request
/// failed.
async fn remove(&self, id: u64) {
self.0.lock().await.remove(&id);
}
/// Notify a waiting receiver with the response to a request.
async fn notify(&self, response: RequestResponse) -> Result<(), InnerError> {
let RequestResponse {
request_type: _,
request_id,
request_status,
response_data,
} = response;
let request_id = request_id
.parse()
.map_err(|e| InnerError::InvalidRequestId(e, request_id))?;
if let Some(tx) = self.0.lock().await.remove(&request_id) {
tx.send((request_status, response_data)).ok();
}
Ok(())
}
/// Reset the list, cancelling any outstanding receivers.
async fn reset(&self) {
self.0.lock().await.clear();
}
}
/// Wrapper around a thread-safe queue to park and notify re-identify listener.
#[derive(Default)]
@ -118,6 +160,11 @@ impl ReidentifyReceiverList {
tx.send(identified).ok();
}
}
/// Reset the list, cancelling any outstanding receivers.
async fn reset(&self) {
self.0.lock().await.clear();
}
}
/// Default broadcast capacity used when not overwritten by the user.
@ -218,7 +265,7 @@ impl Client {
let (mut write, mut read) = socket.split();
let receivers = Arc::new(Mutex::new(HashMap::<_, oneshot::Sender<_>>::new()));
let receivers = Arc::new(ReceiverList::default());
let receivers2 = Arc::clone(&receivers);
let reidentify_receivers = Arc::new(ReidentifyReceiverList::default());
@ -257,20 +304,9 @@ impl Client {
.map_err(InnerError::DeserializeMessage)?;
match message {
ServerMessage::RequestResponse(RequestResponse {
request_type: _,
request_id,
request_status,
response_data,
}) => {
let request_id = request_id
.parse()
.map_err(|e| InnerError::InvalidRequestId(e, request_id))?;
debug!("got message with id {}", request_id);
if let Some(tx) = receivers2.lock().await.remove(&request_id) {
tx.send((request_status, response_data)).ok();
}
ServerMessage::RequestResponse(response) => {
trace!("got message with id {}", response.request_id);
receivers2.notify(response).await?;
}
#[cfg(feature = "events")]
ServerMessage::Event(event) => {
@ -296,7 +332,8 @@ impl Client {
// clear all outstanding receivers to stop them from waiting forever on responses
// they'll never receive.
receivers2.lock().await.clear();
receivers2.reset().await;
reidentify_receivers2.reset().await;
});
let write = Mutex::new(write);
@ -356,8 +393,7 @@ impl Client {
});
let json = serde_json::to_string(&req).map_err(Error::SerializeMessage)?;
let (tx, rx) = oneshot::channel();
self.receivers.lock().await.insert(id, tx);
let rx = self.receivers.add(id).await;
debug!("sending message: {}", json);
let write_result = self
@ -369,7 +405,7 @@ impl Client {
.map_err(Error::Send);
if let Err(e) = write_result {
self.receivers.lock().await.remove(&id);
self.receivers.remove(id).await;
return Err(e);
}

Loading…
Cancel
Save