diff --git a/codegen/tests/run-pass/segments.rs b/codegen/tests/run-pass/segments.rs index fb22fd49..8e7ec45a 100644 --- a/codegen/tests/run-pass/segments.rs +++ b/codegen/tests/run-pass/segments.rs @@ -4,7 +4,7 @@ extern crate rocket; use std::path::PathBuf; -use std::str::Utf8Error; +use rocket::http::uri::SegmentError; #[post("//")] fn get(a: String, b: PathBuf) -> String { @@ -12,7 +12,7 @@ fn get(a: String, b: PathBuf) -> String { } #[post("//")] -fn get2(a: String, b: Result) -> String { +fn get2(a: String, b: Result) -> String { format!("{}/{}", a, b.unwrap().to_string_lossy()) } diff --git a/lib/src/http/uri.rs b/lib/src/http/uri.rs index fe01b854..f873efe3 100644 --- a/lib/src/http/uri.rs +++ b/lib/src/http/uri.rs @@ -373,6 +373,20 @@ impl<'a> Iterator for Segments<'a> { // } } +/// Errors which can occur when attempting to interpret a segment string as a +/// valid path segment. +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum SegmentError { + /// The segment contained invalid UTF8 characters when percent decoded. + Utf8(Utf8Error), + /// The segment started with the wrapped invalid character. + BadStart(char), + /// The segment contained the wrapped invalid character. + BadChar(char), + /// The segment ended with the wrapped invalid character. + BadEnd(char), +} + #[cfg(test)] mod tests { use super::URI; diff --git a/lib/src/request/param.rs b/lib/src/request/param.rs index da85f3b6..385dc4a8 100644 --- a/lib/src/request/param.rs +++ b/lib/src/request/param.rs @@ -1,9 +1,9 @@ -use std::str::{Utf8Error, FromStr}; +use std::str::FromStr; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}; use std::path::PathBuf; use std::fmt::Debug; -use http::uri::{URI, Segments}; +use http::uri::{URI, Segments, SegmentError}; /// Trait to convert a dynamic path segment string to a concrete value. /// @@ -274,6 +274,7 @@ pub trait FromSegments<'a>: Sized { impl<'a> FromSegments<'a> for Segments<'a> { type Error = (); + fn from_segments(segments: Segments<'a>) -> Result, ()> { Ok(segments) } @@ -281,19 +282,46 @@ impl<'a> FromSegments<'a> for Segments<'a> { /// Creates a `PathBuf` from a `Segments` iterator. The returned `PathBuf` is /// percent-decoded. If a segment is equal to "..", the previous segment (if -/// any) is skipped. For security purposes, any other segments that begin with -/// "*" or "." are ignored. If a percent-decoded segment results in invalid -/// UTF8, an `Err` is returned. +/// any) is skipped. +/// +/// For security purposes, if a segment meets any of the following conditions, +/// an `Err` is returned indicating the condition met: +/// +/// * Decoded segment starts with any of: `.`, `*` +/// * Decoded segment ends with any of: `:`, `>`, `<` +/// * Decoded segment contains any of: `/` +/// * On Windows, decoded segment contains any of: '\' +/// * Percent-encoding results in invalid UTF8. +/// +/// As a result of these conditions, a `PathBuf` derived via `FromSegments` is +/// safe to interpolate within, or use as a suffix of, a path without additional +/// checks. impl<'a> FromSegments<'a> for PathBuf { - type Error = Utf8Error; + type Error = SegmentError; - fn from_segments(segments: Segments<'a>) -> Result { + fn from_segments(segments: Segments<'a>) -> Result { let mut buf = PathBuf::new(); for segment in segments { - let decoded = URI::percent_decode(segment.as_bytes())?; + let decoded = URI::percent_decode(segment.as_bytes()) + .map_err(|e| SegmentError::Utf8(e))?; + if decoded == ".." { buf.pop(); - } else if !(decoded.starts_with('.') || decoded.starts_with('*')) { + } else if decoded.starts_with('.') { + return Err(SegmentError::BadStart('.')) + } else if decoded.starts_with('*') { + return Err(SegmentError::BadStart('*')) + } else if decoded.ends_with(':') { + return Err(SegmentError::BadEnd(':')) + } else if decoded.ends_with('>') { + return Err(SegmentError::BadEnd('>')) + } else if decoded.ends_with('<') { + return Err(SegmentError::BadEnd('<')) + } else if decoded.contains('/') { + return Err(SegmentError::BadChar('/')) + } else if cfg!(windows) && decoded.contains('\\') { + return Err(SegmentError::BadChar('\\')) + } else { buf.push(&*decoded) } }