Implement #2871 by matching on outcome

This commit is contained in:
Jesper Steen Møller 2024-10-03 00:32:08 +02:00
parent 3bf9ef02d6
commit e70b79f59d
3 changed files with 88 additions and 4 deletions

View File

@ -521,12 +521,13 @@ impl<'r, T: FromRequest<'r>> FromRequest<'r> for Result<T, T::Error> {
#[crate::async_trait] #[crate::async_trait]
impl<'r, T: FromRequest<'r>> FromRequest<'r> for Option<T> { impl<'r, T: FromRequest<'r>> FromRequest<'r> for Option<T> {
type Error = Infallible; type Error = T::Error;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match T::from_request(request).await { match T::from_request(request).await {
Success(val) => Success(Some(val)), Success(val) => Success(Some(val)),
Error(_) | Forward(_) => Success(None), Forward(_) => Success(None),
Error((status, error)) => Error((status, error)),
} }
} }
} }

View File

@ -256,7 +256,7 @@ impl<'r> FromRequest<'r> for FlashMessage<'r> {
Ok(i) if i <= kv.len() => Ok(Flash::named(&kv[..i], &kv[i..], req)), Ok(i) if i <= kv.len() => Ok(Flash::named(&kv[..i], &kv[i..], req)),
_ => Err(()) _ => Err(())
} }
}).or_error(Status::BadRequest) }).or_forward(Status::BadRequest)
} }
} }

View File

@ -0,0 +1,83 @@
#[macro_use]
extern crate rocket;
use std::num::ParseIntError;
use rocket::{outcome::IntoOutcome, request::{FromRequest, Outcome}, Request};
use rocket_http::{Header, Status};
pub struct SessionId {
session_id: u64,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for SessionId {
type Error = ParseIntError;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ParseIntError> {
let session_id_string = request.headers().get("Session-Id").next()
.or_forward(Status::BadRequest);
session_id_string.and_then(|v| v.parse()
.map(|id| SessionId { session_id: id })
.or_error(Status::BadRequest))
}
}
#[get("/mandatory")]
fn get_data_with_mandatory_header(header: SessionId) -> String {
format!("GET for session {:}", header.session_id)
}
#[get("/optional")]
fn get_data_with_opt_header(opt_header: Option<SessionId>) -> String {
if let Some(id) = opt_header {
format!("GET for session {:}", id.session_id)
} else {
format!("GET for new session")
}
}
#[test]
fn read_optional_header() {
let rocket = rocket::build().mount(
"/",
routes![get_data_with_opt_header, get_data_with_mandatory_header]);
let client = rocket::local::blocking::Client::debug(rocket).unwrap();
// If we supply the header, the handler sees it
let response = client.get("/optional")
.header(Header::new("session-id", "1234567")).dispatch();
assert_eq!(response.into_string().unwrap(), "GET for session 1234567".to_string());
// If no header, means that the handler sees a None
let response = client.get("/optional").dispatch();
assert_eq!(response.into_string().unwrap(), "GET for new session".to_string());
// If we supply a malformed header, the handler will not be called, but the request will fail
let response = client.get("/optional")
.header(Header::new("session-id", "Xw23")).dispatch();
assert_eq!(response.status(), Status::BadRequest);
}
#[test]
fn read_mandatory_header() {
let rocket = rocket::build().mount(
"/",
routes![get_data_with_opt_header, get_data_with_mandatory_header]);
let client = rocket::local::blocking::Client::debug(rocket).unwrap();
// If the header is missing, it's a bad request (extra info would be nice, though)
let response = client.get("/mandatory").dispatch();
assert_eq!(response.status(), Status::BadRequest);
// If the header is malformed, it's a bad request too (extra info would be nice, though)
let response = client.get("/mandatory")
.header(Header::new("session-id", "Xw23")).dispatch();
assert_eq!(response.status(), Status::BadRequest);
// If the header is fine, just do the stuff
let response = client.get("/mandatory")
.header(Header::new("session-id", "64535")).dispatch();
assert_eq!(response.status(), Status::Ok);
assert_eq!(response.into_string().unwrap(), "GET for session 64535".to_string());
}