mirror of https://github.com/rwf2/Rocket.git
Implement #2871 by matching on outcome
This commit is contained in:
parent
3bf9ef02d6
commit
e70b79f59d
|
@ -521,12 +521,13 @@ impl<'r, T: FromRequest<'r>> FromRequest<'r> for Result<T, T::Error> {
|
||||||
|
|
||||||
#[crate::async_trait]
|
#[crate::async_trait]
|
||||||
impl<'r, T: FromRequest<'r>> FromRequest<'r> for Option<T> {
|
impl<'r, T: FromRequest<'r>> FromRequest<'r> for Option<T> {
|
||||||
type Error = Infallible;
|
type Error = T::Error;
|
||||||
|
|
||||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
match T::from_request(request).await {
|
match T::from_request(request).await {
|
||||||
Success(val) => Success(Some(val)),
|
Success(val) => Success(Some(val)),
|
||||||
Error(_) | Forward(_) => Success(None),
|
Forward(_) => Success(None),
|
||||||
|
Error((status, error)) => Error((status, error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -256,7 +256,7 @@ impl<'r> FromRequest<'r> for FlashMessage<'r> {
|
||||||
Ok(i) if i <= kv.len() => Ok(Flash::named(&kv[..i], &kv[i..], req)),
|
Ok(i) if i <= kv.len() => Ok(Flash::named(&kv[..i], &kv[i..], req)),
|
||||||
_ => Err(())
|
_ => Err(())
|
||||||
}
|
}
|
||||||
}).or_error(Status::BadRequest)
|
}).or_forward(Status::BadRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
#[macro_use]
|
||||||
|
extern crate rocket;
|
||||||
|
|
||||||
|
use std::num::ParseIntError;
|
||||||
|
|
||||||
|
use rocket::{outcome::IntoOutcome, request::{FromRequest, Outcome}, Request};
|
||||||
|
use rocket_http::{Header, Status};
|
||||||
|
|
||||||
|
pub struct SessionId {
|
||||||
|
session_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for SessionId {
|
||||||
|
type Error = ParseIntError;
|
||||||
|
|
||||||
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ParseIntError> {
|
||||||
|
let session_id_string = request.headers().get("Session-Id").next()
|
||||||
|
.or_forward(Status::BadRequest);
|
||||||
|
session_id_string.and_then(|v| v.parse()
|
||||||
|
.map(|id| SessionId { session_id: id })
|
||||||
|
.or_error(Status::BadRequest))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/mandatory")]
|
||||||
|
fn get_data_with_mandatory_header(header: SessionId) -> String {
|
||||||
|
format!("GET for session {:}", header.session_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/optional")]
|
||||||
|
fn get_data_with_opt_header(opt_header: Option<SessionId>) -> String {
|
||||||
|
if let Some(id) = opt_header {
|
||||||
|
format!("GET for session {:}", id.session_id)
|
||||||
|
} else {
|
||||||
|
format!("GET for new session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn read_optional_header() {
|
||||||
|
let rocket = rocket::build().mount(
|
||||||
|
"/",
|
||||||
|
routes![get_data_with_opt_header, get_data_with_mandatory_header]);
|
||||||
|
let client = rocket::local::blocking::Client::debug(rocket).unwrap();
|
||||||
|
|
||||||
|
// If we supply the header, the handler sees it
|
||||||
|
let response = client.get("/optional")
|
||||||
|
.header(Header::new("session-id", "1234567")).dispatch();
|
||||||
|
assert_eq!(response.into_string().unwrap(), "GET for session 1234567".to_string());
|
||||||
|
|
||||||
|
// If no header, means that the handler sees a None
|
||||||
|
let response = client.get("/optional").dispatch();
|
||||||
|
assert_eq!(response.into_string().unwrap(), "GET for new session".to_string());
|
||||||
|
|
||||||
|
// If we supply a malformed header, the handler will not be called, but the request will fail
|
||||||
|
let response = client.get("/optional")
|
||||||
|
.header(Header::new("session-id", "Xw23")).dispatch();
|
||||||
|
assert_eq!(response.status(), Status::BadRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn read_mandatory_header() {
|
||||||
|
let rocket = rocket::build().mount(
|
||||||
|
"/",
|
||||||
|
routes![get_data_with_opt_header, get_data_with_mandatory_header]);
|
||||||
|
let client = rocket::local::blocking::Client::debug(rocket).unwrap();
|
||||||
|
|
||||||
|
// If the header is missing, it's a bad request (extra info would be nice, though)
|
||||||
|
let response = client.get("/mandatory").dispatch();
|
||||||
|
assert_eq!(response.status(), Status::BadRequest);
|
||||||
|
|
||||||
|
// If the header is malformed, it's a bad request too (extra info would be nice, though)
|
||||||
|
let response = client.get("/mandatory")
|
||||||
|
.header(Header::new("session-id", "Xw23")).dispatch();
|
||||||
|
assert_eq!(response.status(), Status::BadRequest);
|
||||||
|
|
||||||
|
// If the header is fine, just do the stuff
|
||||||
|
let response = client.get("/mandatory")
|
||||||
|
.header(Header::new("session-id", "64535")).dispatch();
|
||||||
|
assert_eq!(response.status(), Status::Ok);
|
||||||
|
assert_eq!(response.into_string().unwrap(), "GET for session 64535".to_string());
|
||||||
|
}
|
Loading…
Reference in New Issue