Remove unsafe 'from_utf8_unchecked'; improve form parsing.

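The 'FormItems' iterator now successfully parses items with an empty key
or an empty value, as well as keys without values (no '='). Items whose
key and value are both empty are skipped rather than treated as errors.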
This commit is contained in:
Sergio Benitez 2018-02-14 14:02:56 -08:00
parent 8b1aaed0ce
commit de8e1978c5
4 changed files with 84 additions and 61 deletions

View File

@ -55,14 +55,15 @@ fn check_bad_form(form_str: &str, status: Status) {
#[test]
fn test_bad_form_abnromal_inputs() {
check_bad_form("&", Status::BadRequest);
check_bad_form("=", Status::BadRequest);
check_bad_form("&&&===&", Status::BadRequest);
check_bad_form("&&&=hi==&", Status::BadRequest);
}
#[test]
fn test_bad_form_missing_fields() {
let bad_inputs: [&str; 6] = [
let bad_inputs: [&str; 8] = [
"&",
"=",
"username=Sergio",
"password=pass",
"age=30",

View File

@ -36,8 +36,12 @@ use http::RawStr;
/// completion. The iterator attempts to be lenient. In particular, it allows
/// the following oddball behavior:
///
/// * A single trailing `&` character is allowed.
/// * Empty values are allowed.
/// * Trailing and consecutive `&` characters are allowed.
/// * Empty keys and/or values are allowed.
///
/// Additionally, the iterator skips items with both an empty key _and_ an empty
/// value: at least one of the two must be non-empty to be returned from this
/// iterator.
///
/// # Examples
///
@ -46,8 +50,8 @@ use http::RawStr;
/// ```rust
/// use rocket::request::FormItems;
///
/// // prints "greeting = hello" then "username = jake"
/// let form_string = "greeting=hello&username=jake";
/// // prints "greeting = hello", "username = jake", and "done = "
/// let form_string = "greeting=hello&username=jake&done";
/// for (key, value) in FormItems::from(form_string) {
/// println!("{} = {}", key, value);
/// }
@ -58,7 +62,7 @@ use http::RawStr;
/// ```rust
/// use rocket::request::FormItems;
///
/// let form_string = "greeting=hello&username=jake";
/// let form_string = "greeting=hello&username=jake&done";
/// let mut items = FormItems::from(form_string);
///
/// let next = items.next().unwrap();
@ -69,6 +73,10 @@ use http::RawStr;
/// assert_eq!(next.0, "username");
/// assert_eq!(next.1, "jake");
///
/// let next = items.next().unwrap();
/// assert_eq!(next.0, "done");
/// assert_eq!(next.1, "");
///
/// assert_eq!(items.next(), None);
/// assert!(items.completed());
/// ```
@ -101,7 +109,7 @@ impl<'f> FormItems<'f> {
/// ```rust
/// use rocket::request::FormItems;
///
/// let mut items = FormItems::from("a=b&=d");
/// let mut items = FormItems::from("a=b&==d");
/// let key_values: Vec<_> = items.by_ref().collect();
///
/// assert_eq!(key_values.len(), 1);
@ -136,12 +144,13 @@ impl<'f> FormItems<'f> {
/// ```rust
/// use rocket::request::FormItems;
///
/// let mut items = FormItems::from("a=b&=d");
/// let mut items = FormItems::from("a=b&=d=");
///
/// assert!(items.next().is_some());
/// assert_eq!(items.completed(), false);
/// assert_eq!(items.exhaust(), false);
/// assert_eq!(items.completed(), false);
/// assert!(items.next().is_none());
/// ```
#[inline]
pub fn exhaust(&mut self) -> bool {
@ -200,10 +209,7 @@ impl<'f> From<&'f str> for FormItems<'f> {
/// `x-www-form-urlencoded` form `string`.
#[inline(always)]
fn from(string: &'f str) -> FormItems<'f> {
FormItems {
string: string.into(),
next_index: 0
}
FormItems::from(RawStr::from_str(string))
}
}
@ -211,27 +217,32 @@ impl<'f> Iterator for FormItems<'f> {
type Item = (&'f RawStr, &'f RawStr);
fn next(&mut self) -> Option<Self::Item> {
loop {
let s = &self.string[self.next_index..];
let (key, rest) = match memchr2(b'=', b'&', s.as_bytes()) {
Some(i) if s.as_bytes()[i] == b'=' => (&s[..i], &s[(i + 1)..]),
Some(_) => return None,
None => return None,
};
if key.is_empty() {
if s.is_empty() {
return None;
}
let (value, consumed) = match rest.find('&') {
Some(index) => (&rest[..index], index + 1),
None => (rest, rest.len()),
let (key, rest, key_consumed) = match memchr2(b'=', b'&', s.as_bytes()) {
Some(i) if s.as_bytes()[i] == b'=' => (&s[..i], &s[(i + 1)..], i + 1),
Some(i) => (&s[..i], &s[i..], i),
None => (s, &s[s.len()..], s.len())
};
self.next_index += key.len() + 1 + consumed;
Some((key.into(), value.into()))
}
}
let (value, val_consumed) = match memchr2(b'=', b'&', rest.as_bytes()) {
Some(i) if rest.as_bytes()[i] == b'=' => return None,
Some(i) => (&rest[..i], i + 1),
None => (rest, rest.len())
};
self.next_index += key_consumed + val_consumed;
match (key.is_empty(), value.is_empty()) {
(true, true) => continue,
_ => return Some((key.into(), value.into()))
}
}
}
}
#[cfg(test)]
mod test {
@ -246,22 +257,23 @@ mod test {
let mut items = FormItems::from(string);
let results: Vec<_> = items.by_ref().collect();
if let Some(expected) = expected {
assert_eq!(expected.len(), results.len());
assert_eq!(expected.len(), results.len(),
"expected {:?}, got {:?} for {:?}", expected, results, string);
for i in 0..results.len() {
let (expected_key, actual_key) = (expected[i].0, results[i].0);
let (expected_val, actual_val) = (expected[i].1, results[i].1);
assert!(actual_key == expected_key,
"key [{}] mismatch: expected {}, got {}",
i, expected_key, actual_key);
"key [{}] mismatch for {}: expected {}, got {}",
i, string, expected_key, actual_key);
assert!(actual_val == expected_val,
"val [{}] mismatch: expected {}, got {}",
i, expected_val, actual_val);
"val [{}] mismatch for {}: expected {}, got {}",
i, string, expected_val, actual_val);
}
} else {
assert!(!items.exhaust());
assert!(!items.exhaust(), "{} unexpectedly parsed successfully", string);
}
}
@ -270,11 +282,10 @@ mod test {
check_form!("username=user&password=pass",
&[("username", "user"), ("password", "pass")]);
check_form!("user=user&user=pass",
&[("user", "user"), ("user", "pass")]);
check_form!("user=&password=pass",
&[("user", ""), ("password", "pass")]);
check_form!("user=user&user=pass", &[("user", "user"), ("user", "pass")]);
check_form!("user=&password=pass", &[("user", ""), ("password", "pass")]);
check_form!("user&password=pass", &[("user", ""), ("password", "pass")]);
check_form!("foo&bar", &[("foo", ""), ("bar", "")]);
check_form!("a=b", &[("a", "b")]);
check_form!("value=Hello+World", &[("value", "Hello+World")]);
@ -282,14 +293,27 @@ mod test {
check_form!("user=", &[("user", "")]);
check_form!("user=&", &[("user", "")]);
check_form!("a=b&a=", &[("a", "b"), ("a", "")]);
check_form!("user=&password", &[("user", ""), ("password", "")]);
check_form!("a=b&a", &[("a", "b"), ("a", "")]);
check_form!(@bad "user=&password");
check_form!(@bad "user=x&&");
check_form!(@bad "a=b&a");
check_form!(@bad "=");
check_form!(@bad "&");
check_form!(@bad "=&");
check_form!(@bad "&=&");
check_form!(@bad "=&=");
check_form!("user=x&&", &[("user", "x")]);
check_form!("user=x&&&&pass=word", &[("user", "x"), ("pass", "word")]);
check_form!("user=x&&&&pass=word&&&x=z&d&&&e",
&[("user", "x"), ("pass", "word"), ("x", "z"), ("d", ""), ("e", "")]);
check_form!("=&a=b&&=", &[("a", "b")]);
check_form!("=b", &[("", "b")]);
check_form!("=b&=c", &[("", "b"), ("", "c")]);
check_form!("=", &[]);
check_form!("&=&", &[]);
check_form!("&", &[]);
check_form!("=&=", &[]);
check_form!(@bad "=b&==");
check_form!(@bad "==");
check_form!(@bad "=k=");
check_form!(@bad "=abc=");
check_form!(@bad "=abc=cd");
}
}

View File

@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::str::from_utf8_unchecked;
use std::str::from_utf8;
use std::cmp::min;
use std::io::{self, Write};
use std::mem;
@ -176,20 +176,18 @@ impl Rocket {
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
let is_form = req.content_type().map_or(false, |ct| ct.is_form());
if is_form && req.method() == Method::Post && data_len >= min_len {
// We're only using this for comparison and throwing it away
// afterwards, so it doesn't matter if we have invalid UTF8.
let form =
unsafe { from_utf8_unchecked(&data.peek()[..min(data_len, max_len)]) };
if let Ok(form) = from_utf8(&data.peek()[..min(data_len, max_len)]) {
let method: Option<Result<Method, _>> = FormItems::from(form)
.filter(|&(key, _)| key.as_str() == "_method")
.map(|(_, value)| value.parse())
.next();
if let Some((key, value)) = FormItems::from(form).next() {
if key == "_method" {
if let Ok(method) = value.parse() {
if let Some(Ok(method)) = method {
req.set_method(method);
}
}
}
}
}
#[inline]
pub(crate) fn dispatch<'s, 'r>(

View File

@ -59,7 +59,7 @@ mod limits_tests {
.header(ContentType::Form)
.dispatch();
assert_eq!(response.status(), Status::BadRequest);
assert_eq!(response.status(), Status::UnprocessableEntity);
}
#[test]