use std::{io, time::Duration}; use std::task::{Poll, Context}; use std::pin::Pin; use bytes::BytesMut; use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::time::{sleep, Sleep}; use futures::stream::Stream; use futures::future::{self, Future, FutureExt}; use crate::http::hyper::Bytes; pin_project! { pub struct ReaderStream { #[pin] reader: Option, buf: BytesMut, cap: usize, } } impl Stream for ReaderStream { type Item = std::io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use tokio_util::io::poll_read_buf; let mut this = self.as_mut().project(); let reader = match this.reader.as_pin_mut() { Some(r) => r, None => return Poll::Ready(None), }; if this.buf.capacity() == 0 { this.buf.reserve(*this.cap); } match poll_read_buf(reader, cx, &mut this.buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(err)) => { self.project().reader.set(None); Poll::Ready(Some(Err(err))) } Poll::Ready(Ok(0)) => { self.project().reader.set(None); Poll::Ready(None) } Poll::Ready(Ok(_)) => { let chunk = this.buf.split(); Poll::Ready(Some(Ok(chunk.freeze()))) } } } } pub trait AsyncReadExt: AsyncRead + Sized { fn into_bytes_stream(self, cap: usize) -> ReaderStream { ReaderStream { reader: Some(self), cap, buf: BytesMut::with_capacity(cap) } } } impl AsyncReadExt for T { } pub trait PollExt { fn map_err_ext(self, f: F) -> Poll>> where F: FnOnce(E) -> U; } impl PollExt for Poll>> { /// Changes the error value of this `Poll` with the closure provided. fn map_err_ext(self, f: F) -> Poll>> where F: FnOnce(E) -> U { match self { Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(t))), Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(f(e)))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } } pin_project! { /// Stream for the [`chain`](super::AsyncReadExt::chain) method. #[must_use = "streams do nothing unless polled"] pub struct Chain { #[pin] first: T, #[pin] second: U, done_first: bool, } } impl Chain { pub(crate) fn new(first: T, second: U) -> Self { Self { first, second, done_first: false } } } impl Chain { /// Gets references to the underlying readers in this `Chain`. pub fn get_ref(&self) -> (&T, &U) { (&self.first, &self.second) } } impl AsyncRead for Chain { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let me = self.project(); if !*me.done_first { let init_rem = buf.remaining(); futures::ready!(me.first.poll_read(cx, buf))?; if buf.remaining() == init_rem { *me.done_first = true; } else { return Poll::Ready(Ok(())); } } me.second.poll_read(cx, buf) } } enum State { /// I/O has not been cancelled. Proceed as normal. Active, /// I/O has been cancelled. See if we can finish before the timer expires. Grace(Pin>), /// Grace period elapsed. Shutdown the connection, waiting for the timer /// until we force close. Mercy(Pin>), /// We failed to shutdown and are force-closing the connection. Terminated, /// We successfully shutdown the connection. Inactive, } pin_project! { /// I/O that can be cancelled when a future `F` resolves. #[must_use = "futures do nothing unless polled"] pub struct CancellableIo { #[pin] io: I, #[pin] trigger: future::Fuse, state: State, grace: Duration, mercy: Duration, } } impl CancellableIo { pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self { CancellableIo { io, grace, mercy, trigger: trigger.fuse(), state: State::Active } } /// Returns `Ok(true)` if connection processing should continue. fn poll_trigger_then( self: Pin<&mut Self>, cx: &mut Context<'_>, io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll>, ) -> Poll> { let mut me = self.project(); // CORRECTNESS: _EVERY_ branch must reset `state`! If `state` is // unchanged in a branch, that branch _must_ `break`! No `return`! let mut state = std::mem::replace(me.state, State::Active); let result = loop { match state { State::Active => { if me.trigger.as_mut().poll(cx).is_ready() { state = State::Grace(Box::pin(sleep(*me.grace))); } else { state = State::Active; break io(me.io, cx); } } State::Grace(mut sleep) => { if sleep.as_mut().poll(cx).is_ready() { if let Some(deadline) = sleep.deadline().checked_add(*me.mercy) { sleep.as_mut().reset(deadline); state = State::Mercy(sleep); } else { state = State::Terminated; } } else { state = State::Grace(sleep); break io(me.io, cx); } }, State::Mercy(mut sleep) => { if sleep.as_mut().poll(cx).is_ready() { state = State::Terminated; continue; } match me.io.as_mut().poll_shutdown(cx) { Poll::Ready(Err(e)) => { state = State::Terminated; break Poll::Ready(Err(e)); } Poll::Ready(Ok(())) => { state = State::Inactive; break Poll::Ready(Err(gone())); } Poll::Pending => { state = State::Mercy(sleep); break Poll::Pending; } } }, State::Terminated => { // Just in case, as a last ditch effort. Ignore pending. state = State::Terminated; let _ = me.io.as_mut().poll_shutdown(cx); break Poll::Ready(Err(time_out())); }, State::Inactive => { state = State::Inactive; break Poll::Ready(Err(gone())); } } }; *me.state = state; result } } fn time_out() -> io::Error { io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out") } fn gone() -> io::Error { io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated") } impl AsyncRead for CancellableIo { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf)) } } impl AsyncWrite for CancellableIo { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf)) } fn poll_flush( mut self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll> { self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx)) } fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll> { self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx)) } fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs)) } fn is_write_vectored(&self) -> bool { self.io.is_write_vectored() } } use crate::http::private::{Listener, Connection}; impl Connection for CancellableIo { fn remote_addr(&self) -> Option { self.io.remote_addr() } } pin_project! { pub struct CancellableListener { pub trigger: F, #[pin] pub listener: L, pub grace: Duration, pub mercy: Duration, } } impl CancellableListener { pub fn new(trigger: F, listener: L, grace: u64, mercy: u64) -> Self { let (grace, mercy) = (Duration::from_secs(grace), Duration::from_secs(mercy)); CancellableListener { trigger, listener, grace, mercy } } } impl Listener for CancellableListener { type Connection = CancellableIo; fn local_addr(&self) -> Option { self.listener.local_addr() } fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll> { self.as_mut().project().listener .poll_accept(cx) .map(|res| res.map(|conn| { CancellableIo::new(self.trigger.clone(), conn, self.grace, self.mercy) })) } } pub trait StreamExt: Sized + Stream { fn join(self, other: U) -> Join where U: Stream; } impl StreamExt for S { fn join(self, other: U) -> Join where U: Stream { Join::new(self, other) } } pin_project! { /// Stream returned by the [`join`](super::StreamExt::join) method. pub struct Join { #[pin] a: T, #[pin] b: U, // When `true`, poll `a` first, otherwise, `poll` b`. toggle: bool, // Set when either `a` or `b` return `None`. done: bool, } } impl Join { pub(super) fn new(a: T, b: U) -> Join where T: Stream, U: Stream, { Join { a, b, toggle: false, done: false, } } fn poll_next>( first: Pin<&mut A>, second: Pin<&mut B>, done: &mut bool, cx: &mut Context<'_>, ) -> Poll> { match first.poll_next(cx) { Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) } Poll::Pending => match second.poll_next(cx) { Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) } Poll::Pending => Poll::Pending } } } } impl Stream for Join where T: Stream, U: Stream, { type Item = T::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.done { return Poll::Ready(None); } let me = self.project(); *me.toggle = !*me.toggle; match *me.toggle { true => Self::poll_next(me.a, me.b, me.done, cx), false => Self::poll_next(me.b, me.a, me.done, cx), } } fn size_hint(&self) -> (usize, Option) { let (left_low, left_high) = self.a.size_hint(); let (right_low, right_high) = self.b.size_hint(); let low = left_low.saturating_add(right_low); let high = match (left_high, right_high) { (Some(h1), Some(h2)) => h1.checked_add(h2), _ => None, }; (low, high) } }