diff --git a/contrib/lib/src/msgpack.rs b/contrib/lib/src/msgpack.rs index 13f91043..7a87f6f4 100644 --- a/contrib/lib/src/msgpack.rs +++ b/contrib/lib/src/msgpack.rs @@ -14,14 +14,16 @@ //! features = ["msgpack"] //! ``` -use std::io::Read; use std::ops::{Deref, DerefMut}; +use futures::io::AsyncReadExt; + use rocket::request::Request; use rocket::outcome::Outcome::*; -use rocket::data::{Outcome, Transform, Transform::*, Transformed, Data, FromData}; -use rocket::response::{self, Responder, content}; +use rocket::data::{Data, FromData, FromDataFuture, Transform::*, TransformFuture, Transformed}; use rocket::http::Status; +use rocket::response::{self, content, Responder}; +use rocket::AsyncReadExt as _; use serde::Serialize; use serde::de::Deserialize; @@ -40,7 +42,7 @@ pub use rmp_serde::decode::Error; /// request body. /// /// ```rust -/// # #![feature(proc_macro_hygiene)] +/// # #![feature(proc_macro_hygiene, async_await)] /// # #[macro_use] extern crate rocket; /// # extern crate rocket_contrib; /// # type User = usize; @@ -64,7 +66,7 @@ pub use rmp_serde::decode::Error; /// response is set to `application/msgpack` automatically. /// /// ```rust -/// # #![feature(proc_macro_hygiene)] +/// # #![feature(proc_macro_hygiene, async_await)] /// # #[macro_use] extern crate rocket; /// # extern crate rocket_contrib; /// # type User = usize; @@ -119,45 +121,52 @@ impl<'a, T: Deserialize<'a>> FromData<'a> for MsgPack { type Owned = Vec; type Borrowed = [u8]; - fn transform(r: &Request<'_>, d: Data) -> Transform> { - let mut buf = Vec::new(); + fn transform(r: &Request<'_>, d: Data) -> TransformFuture<'a, Self::Owned, Self::Error> { let size_limit = r.limits().get("msgpack").unwrap_or(LIMIT); - match d.open().take(size_limit).read_to_end(&mut buf) { - Ok(_) => Borrowed(Success(buf)), - Err(e) => Borrowed(Failure((Status::BadRequest, Error::InvalidDataRead(e)))) - } + + Box::pin(async move { + let mut buf = Vec::new(); + let mut reader = d.open().take(size_limit); + match reader.read_to_end(&mut buf).await { + Ok(_) => Borrowed(Success(buf)), + Err(e) => Borrowed(Failure((Status::BadRequest, Error::InvalidDataRead(e)))), + } + }) } - fn from_data(_: &Request<'_>, o: Transformed<'a, Self>) -> Outcome { + fn from_data(_: &Request<'_>, o: Transformed<'a, Self>) -> FromDataFuture<'a, Self, Self::Error> { use self::Error::*; - let buf = try_outcome!(o.borrowed()); - match rmp_serde::from_slice(&buf) { - Ok(val) => Success(MsgPack(val)), - Err(e) => { - error_!("Couldn't parse MessagePack body: {:?}", e); - match e { - TypeMismatch(_) | OutOfRange | LengthMismatch(_) => { - Failure((Status::UnprocessableEntity, e)) + Box::pin(async move { + let buf = try_outcome!(o.borrowed()); + match rmp_serde::from_slice(&buf) { + Ok(val) => Success(MsgPack(val)), + Err(e) => { + error_!("Couldn't parse MessagePack body: {:?}", e); + match e { + TypeMismatch(_) | OutOfRange | LengthMismatch(_) => { + Failure((Status::UnprocessableEntity, e)) + } + _ => Failure((Status::BadRequest, e)), } - _ => Failure((Status::BadRequest, e)) } } - } + }) } } /// Serializes the wrapped value into MessagePack. Returns a response with /// Content-Type `MsgPack` and a fixed-size body with the serialization. If /// serialization fails, an `Err` of `Status::InternalServerError` is returned. -impl Responder<'static> for MsgPack { - fn respond_to(self, req: &Request<'_>) -> response::Result<'static> { - rmp_serde::to_vec(&self.0).map_err(|e| { - error_!("MsgPack failed to serialize: {:?}", e); - Status::InternalServerError - }).and_then(|buf| { - content::MsgPack(buf).respond_to(req) - }) +impl<'r, T: Serialize> Responder<'r> for MsgPack { + fn respond_to(self, req: &'r Request<'_>) -> response::ResultFuture<'r> { + match rmp_serde::to_vec(&self.0) { + Ok(buf) => content::MsgPack(buf).respond_to(req), + Err(e) => Box::pin(async move { + error_!("MsgPack failed to serialize: {:?}", e); + Err(Status::InternalServerError) + }), + } } } diff --git a/examples/msgpack/src/tests.rs b/examples/msgpack/src/tests.rs index 0c7e3461..5aa35c7b 100644 --- a/examples/msgpack/src/tests.rs +++ b/examples/msgpack/src/tests.rs @@ -16,7 +16,7 @@ fn msgpack_get() { assert_eq!(res.content_type(), Some(ContentType::MsgPack)); // Check that the message is `[1, "Hello, world!"]` - assert_eq!(&res.body_bytes().unwrap(), + assert_eq!(&res.body_bytes_wait().unwrap(), &[146, 1, 173, 72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33]); } @@ -30,5 +30,5 @@ fn msgpack_post() { .dispatch(); assert_eq!(res.status(), Status::Ok); - assert_eq!(res.body_string(), Some("Goodbye, world!".into())); + assert_eq!(res.body_string_wait(), Some("Goodbye, world!".into())); }