Enforce using 'MsgPack<T>' to deserialize.

This commit enforces using 'MsgPack<T>', and not 'MsgPack<T, Foo>' or
'Compact<T>', to deserialize MsgPack-encoded data. It also simplifies
the round-trip msgpack test and removes the dev-dependency on `rmp`.
This commit is contained in:
Sergio Benitez 2024-08-09 23:10:33 -07:00
parent 0998b37aeb
commit 39ed4a4909
3 changed files with 45 additions and 90 deletions

View File

@ -140,4 +140,3 @@ version_check = "0.9.1"
tokio = { version = "1", features = ["macros", "io-std"] }
figment = { version = "0.10.17", features = ["test"] }
pretty_assertions = "1"
rmp = "0.8"

View File

@ -171,7 +171,7 @@ impl<T, const COMPACT: bool> MsgPack<T, COMPACT> {
}
}
impl<'r, T: Deserialize<'r>, const COMPACT: bool> MsgPack<T, COMPACT> {
impl<'r, T: Deserialize<'r>> MsgPack<T> {
fn from_bytes(buf: &'r [u8]) -> Result<Self, Error> {
rmp_serde::from_slice(buf).map(MsgPack)
}
@ -192,7 +192,7 @@ impl<'r, T: Deserialize<'r>, const COMPACT: bool> MsgPack<T, COMPACT> {
}
#[crate::async_trait]
impl<'r, T: Deserialize<'r>, const COMPACT: bool> FromData<'r> for MsgPack<T, COMPACT> {
impl<'r, T: Deserialize<'r>> FromData<'r> for MsgPack<T> {
type Error = Error;
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> {
@ -233,9 +233,7 @@ impl<'r, T: Serialize, const COMPACT: bool> Responder<'r, 'static> for MsgPack<T
}
#[crate::async_trait]
impl<'v, T, const COMPACT: bool> form::FromFormField<'v> for MsgPack<T, COMPACT>
where T: Deserialize<'v> + Send
{
impl<'v, T: Deserialize<'v> + Send> form::FromFormField<'v> for MsgPack<T> {
// TODO: To implement `from_value`, we need to the raw string so we can
// decode it into bytes as opposed to a string as it won't be UTF-8.

View File

@ -1,113 +1,71 @@
#![cfg(feature = "msgpack")]
use std::borrow::Cow;
use rocket::{Rocket, Build};
use rocket::serde::msgpack::{MsgPack, Compact};
use rocket::serde::msgpack::{self, MsgPack, Compact};
use rocket::local::blocking::Client;
#[derive(serde::Serialize, serde::Deserialize, PartialEq, Eq)]
struct Person {
name: String,
struct Person<'r> {
name: &'r str,
age: u8,
gender: Gender,
}
#[derive(serde::Serialize, serde::Deserialize, PartialEq, Eq)]
#[serde(tag = "gender")]
enum Gender {
Male,
Female,
NonBinary,
}
#[rocket::post("/age_named", data = "<person>")]
fn named(person: MsgPack<Person>) -> MsgPack<Person> {
let person = Person { age: person.age + 1, ..person.into_inner() };
MsgPack(person)
#[rocket::post("/named", data = "<person>")]
fn named(person: MsgPack<Person<'_>>) -> MsgPack<Person<'_>> {
person
}
#[rocket::post("/age_compact", data = "<person>")]
fn compact(person: MsgPack<Person>) -> Compact<Person> {
let person = Person { age: person.age + 1, ..person.into_inner() };
MsgPack(person)
#[rocket::post("/compact", data = "<person>")]
fn compact(person: MsgPack<Person<'_>>) -> Compact<Person<'_>> {
MsgPack(person.into_inner())
}
fn rocket() -> Rocket<Build> {
rocket::build()
.mount("/", rocket::routes![named, compact])
rocket::build().mount("/", rocket::routes![named, compact])
}
fn read_string(buf: &mut rmp::decode::Bytes) -> String {
let mut string_buf = vec![0; 32]; // Awful but we're just testing.
rmp::decode::read_str(buf, &mut string_buf).unwrap().to_string()
}
#[test]
fn check_named_roundtrip() {
let client = Client::debug(rocket()).unwrap();
let person = Person {
name: "Cal".to_string(),
// The object we're going to roundtrip through the API.
const OBJECT: Person<'static> = Person {
name: "Cal",
age: 17,
gender: Gender::NonBinary,
};
let response = client
.post("/age_named")
.body(rmp_serde::to_vec_named(&person).unwrap())
.dispatch()
.into_bytes()
.unwrap();
let mut bytes = rmp::decode::Bytes::new(&response);
assert_eq!(rmp::decode::read_map_len(&mut bytes).unwrap(), 3);
assert_eq!(&read_string(&mut bytes), "name");
assert_eq!(&read_string(&mut bytes), "Cal");
assert_eq!(&read_string(&mut bytes), "age");
assert_eq!(rmp::decode::read_int::<u8, _>(&mut bytes).unwrap(), 18);
assert_eq!(&read_string(&mut bytes), "gender");
// Enums are complicated in serde. In this test, they're encoded like this:
// (JSON equivalent) `{ "gender": "NonBinary" }`, where that object is itself
// the value of the `gender` key in the outer object. `#[serde(flatten)]`
// on the `gender` key in the outer object fixes this, but it prevents `rmp`
// from using compact mode, which would break the test.
assert_eq!(rmp::decode::read_map_len(&mut bytes).unwrap(), 1);
assert_eq!(&read_string(&mut bytes), "gender");
assert_eq!(&read_string(&mut bytes), "NonBinary");
let response_from_compact = client
.post("/age_named")
.body(rmp_serde::to_vec(&person).unwrap())
.dispatch()
.into_bytes()
.unwrap();
assert_eq!(response, response_from_compact);
}
// [ "Cal", 17, "NonBinary" ]
const COMPACT_BYTES: &[u8] = &[
147, 163, 67, 97, 108, 17, 169, 78, 111, 110, 66, 105, 110, 97, 114, 121
];
// { "name": "Cal", "age": 17, "gender": "NonBinary" }
const NAMED_BYTES: &[u8] = &[
131, 164, 110, 97, 109, 101, 163, 67, 97, 108, 163, 97, 103, 101, 17, 166,
103, 101, 110, 100, 101, 114, 169, 78, 111, 110, 66, 105, 110, 97, 114, 121
];
#[test]
fn check_compact_roundtrip() {
fn check_roundtrip() {
let client = Client::debug(rocket()).unwrap();
let person = Person {
name: "Maeve".to_string(),
age: 15,
gender: Gender::Female,
};
let response = client
.post("/age_compact")
.body(rmp_serde::to_vec(&person).unwrap())
.dispatch()
.into_bytes()
.unwrap();
let mut bytes = rmp::decode::Bytes::new(&response);
assert_eq!(rmp::decode::read_array_len(&mut bytes).unwrap(), 3);
assert_eq!(&read_string(&mut bytes), "Maeve");
assert_eq!(rmp::decode::read_int::<u8, _>(&mut bytes).unwrap(), 16);
// Equivalent to the named representation, gender here is encoded like this:
// `[ "Female" ]`.
assert_eq!(rmp::decode::read_array_len(&mut bytes).unwrap(), 1);
assert_eq!(&read_string(&mut bytes), "Female");
let inputs: &[(&'static str, Cow<'static, [u8]>)] = &[
("objpack", msgpack::to_vec(&OBJECT).unwrap().into()),
("named bytes", NAMED_BYTES.into()),
("compact bytes", COMPACT_BYTES.into()),
];
let response_from_named = client
.post("/age_compact")
.body(rmp_serde::to_vec_named(&person).unwrap())
.dispatch()
.into_bytes()
.unwrap();
assert_eq!(response, response_from_named);
for (name, input) in inputs {
let compact = client.post("/compact").body(input).dispatch();
assert_eq!(compact.into_bytes().unwrap(), COMPACT_BYTES, "{name} mismatch");
let named = client.post("/named").body(input).dispatch();
assert_eq!(named.into_bytes().unwrap(), NAMED_BYTES, "{name} mismatch");
}
}