Last active
July 25, 2022 21:23
-
-
Save cletustboone/036409522c609f460004d6708c6f4062 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{error::Error, fmt::Display, task::Poll, time::Duration}; | |
use futures::Future; | |
use pin_project::pin_project; | |
use tokio::time::{sleep, Sleep}; | |
use tower::{BoxError, Layer, Service}; | |
use tracing::{instrument, warn}; | |
use crate::gateway::GatewayRequest; | |
/// Middleware that wraps another service and gates requests behind a
/// capability check.
#[derive(Debug, Clone)]
pub struct CapabilityCheck<S> {
    /// The wrapped service that receives the request.
    pub inner: S,
}

impl<S> CapabilityCheck<S> {
    /// Builds a `CapabilityCheck` around the given `inner` service.
    pub fn new(inner: S) -> Self {
        CapabilityCheck { inner }
    }
}
/// Stateless layer that wraps a service in a `CapabilityCheck`.
///
/// The layer carries no configuration, so it is cheap to copy and can be
/// created with either [`CapabilityCheckLayer::new`] or `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapabilityCheckLayer {}

impl CapabilityCheckLayer {
    /// Creates a new layer; equivalent to `CapabilityCheckLayer::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
impl<S> Layer<S> for CapabilityCheckLayer { | |
type Service = CapabilityCheck<S>; | |
fn layer(&self, inner: S) -> Self::Service { | |
CapabilityCheck::new(inner) | |
} | |
} | |
// Implement tower::Service for CapabilityCheck.
// `call` hands the request to the inner service right away, but the returned
// future first waits out a fake async delay, then checks the request, and only
// polls the inner service's response future if the capability check passes.
impl<S> Service<GatewayRequest> for CapabilityCheck<S> | |
where | |
S: Service<GatewayRequest> + std::fmt::Debug, | |
GatewayRequest: std::fmt::Debug, | |
S::Error: Into<BoxError>, | |
{ | |
type Response = S::Response; | |
type Error = BoxError; | |
type Future = CapabilityCheckFuture<S::Future>; | |
fn poll_ready( | |
&mut self, | |
_cx: &mut std::task::Context<'_>, | |
) -> std::task::Poll<Result<(), Self::Error>> { | |
Poll::Ready(Ok(())) | |
} | |
#[instrument(name = "cap check")] | |
fn call(&mut self, req: GatewayRequest) -> Self::Future { | |
let delay = sleep(Duration::from_millis(200)); | |
let response_future = self.inner.call(req.clone()); | |
CapabilityCheckFuture { | |
response_future, | |
delay, | |
req, | |
} | |
} | |
} | |
#[pin_project] | |
pub struct CapabilityCheckFuture<F> { | |
#[pin] | |
response_future: F, | |
#[pin] | |
delay: Sleep, | |
req: GatewayRequest, | |
} | |
impl<F, Response, Error> Future for CapabilityCheckFuture<F> | |
where | |
F: Future<Output = Result<Response, Error>>, | |
Error: Into<BoxError>, | |
{ | |
type Output = Result<Response, BoxError>; | |
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { | |
let this = self.project(); | |
match this.delay.poll(cx) { | |
Poll::Ready(_) => { | |
// Now that we've waited, let's check the request. No "foo" for you! | |
if this.req.req == "foo".to_string() { | |
warn!(request = %this.req.req, "failed capability check"); | |
let error = Box::new(CapabilityError(())); | |
return Poll::Ready(Err(error)); | |
} | |
// If the request passed, let's poll the inner future and return | |
// its result whenever that happens | |
match this.response_future.poll(cx) { | |
Poll::Ready(result) => { | |
let result = result.map_err(Into::into); | |
return Poll::Ready(result); | |
} | |
Poll::Pending => {} | |
} | |
} | |
Poll::Pending => {} | |
} | |
Poll::Pending | |
} | |
} | |
/// Error reported when a request fails the capability check.
///
/// The private unit field keeps the type from being constructed with a
/// struct literal outside this module.
#[derive(Debug, Default)]
pub struct CapabilityError(());

/// Human-readable message suitable for handing back to the client.
impl Display for CapabilityError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `str`'s `Display` impl pads via the formatter, so width, fill and
        // precision flags are honoured.
        Display::fmt("incapable", f)
    }
}

/// The default `Error` methods are sufficient here. Implementing `Error` is
/// what lets this type be boxed into an outer service's `BoxError`.
impl Error for CapabilityError {}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment