diff --git a/instant-xml/src/de.rs b/instant-xml/src/de.rs index 7d4c268..47ae278 100644 --- a/instant-xml/src/de.rs +++ b/instant-xml/src/de.rs @@ -341,7 +341,7 @@ pub fn borrow_cow_str<'xml>( } if let Some(value) = deserializer.take_str()? { - *into = Some(decode(value)); + *into = Some(decode(value)?); }; deserializer.ignore()?; @@ -357,7 +357,7 @@ pub fn borrow_cow_slice_u8<'xml>( } if let Some(value) = deserializer.take_str()? { - *into = Some(match decode(value) { + *into = Some(match decode(value)? { Cow::Borrowed(v) => Cow::Borrowed(v.as_bytes()), Cow::Owned(v) => Cow::Owned(v.into_bytes()), }); diff --git a/instant-xml/src/impls.rs b/instant-xml/src/impls.rs index c33dd5e..099810d 100644 --- a/instant-xml/src/impls.rs +++ b/instant-xml/src/impls.rs @@ -260,7 +260,7 @@ impl<'xml> FromXml<'xml> for String { } match deserializer.take_str()? { - Some(value) => *into = Some(decode(value).into_owned()), + Some(value) => *into = Some(decode(value)?.into_owned()), None => return Ok(()), } @@ -292,7 +292,7 @@ impl<'xml> FromXml<'xml> for &'xml str { None => return Ok(()), }; - match decode(value) { + match decode(value)? { Cow::Borrowed(str) => *into = Some(str), Cow::Owned(_) => { return Err(Error::UnexpectedValue(format!( @@ -484,62 +484,68 @@ fn encode(input: &str) -> Result, Error> { Ok(Cow::Owned(result)) } -pub(crate) fn decode(input: &str) -> Cow<'_, str> { +pub(crate) fn decode(input: &str) -> Result, Error> { let mut result = String::with_capacity(input.len()); - let input_len = input.len(); - - let mut last_end = 0; - while input_len - last_end >= 4 { - if input.is_char_boundary(last_end + 4) { - match &input[last_end..(last_end + 4)] { - "<" => { - result.push('<'); - last_end += 4; - continue; - } - ">" => { - result.push('>'); - last_end += 4; - continue; - } - _ => (), - }; - } - - if input_len - last_end >= 5 && input.is_char_boundary(last_end + 5) { - if &input[last_end..(last_end + 5)] == "&" { - result.push('&'); - last_end += 5; - continue; - } - - if input_len - last_end >= 6 && input.is_char_boundary(last_end + 6) { - match &input[last_end..(last_end + 6)] { - "'" => { - result.push('\''); - last_end += 6; - continue; + let (mut state, mut last_end) = (DecodeState::Normal, 0); + for (i, &b) in input.as_bytes().iter().enumerate() { + // use a state machine to find entities + state = match (state, b) { + (DecodeState::Normal, b'&') => DecodeState::Entity([0; 4], 0), + (DecodeState::Normal, _) => DecodeState::Normal, + (DecodeState::Entity(chars, len), b';') => { + let decoded = match chars[..len] { + [b'a', b'm', b'p'] => '&', + [b'a', b'p', b'o', b's'] => '\'', + [b'g', b't'] => '>', + [b'l', b't'] => '<', + [b'q', b'u', b'o', b't'] => '"', + _ => { + return Err(Error::InvalidEntity( + String::from_utf8_lossy(&chars[..len]).into_owned(), + )) } - """ => { - result.push('"'); - last_end += 6; - continue; - } - _ => (), }; + + let start = i - (len + 1); // current position - (length of entity characters + 1 for '&') + if last_end < start { + // Unwrap should be safe: `last_end` and `start` must be at character boundaries. + result.push_str(input.get(last_end..start).unwrap()); + } + + last_end = i + 1; + result.push(decoded); + DecodeState::Normal } + (DecodeState::Entity(mut chars, len), b) => { + if len >= 4 { + let mut bytes = Vec::with_capacity(5); + bytes.extend(&chars[..len]); + bytes.push(b); + return Err(Error::InvalidEntity( + String::from_utf8_lossy(&bytes).into_owned(), + )); + } + + chars[len] = b; + DecodeState::Entity(chars, len + 1) + } + }; + } + + Ok(match result.is_empty() { + true => Cow::Borrowed(input), + false => { + // Unwrap should be safe: `last_end` and `input.len()` must be at character boundaries. + result.push_str(input.get(last_end..input.len()).unwrap()); + Cow::Owned(result) } + }) +} - result.push_str(input.get(last_end..last_end + 1).unwrap()); - last_end += 1; - } - - result.push_str(input.get(last_end..).unwrap()); - if result.len() == input.len() { - return Cow::Borrowed(input); - } - - Cow::Owned(result) +#[derive(Debug)] +enum DecodeState { + Normal, + Entity([u8; 4], usize), } impl<'xml, T: FromXml<'xml>> FromXml<'xml> for Vec { @@ -771,3 +777,25 @@ impl<'xml> FromXml<'xml> for IpAddr { const KIND: Kind = Kind::Scalar; } + +#[cfg(test)] +mod tests { + use super::decode; + + #[test] + fn test_decode() { + assert_eq!(decode("foo").unwrap(), "foo"); + assert_eq!(decode("foo & bar").unwrap(), "foo & bar"); + assert_eq!(decode("foo < bar").unwrap(), "foo < bar"); + assert_eq!(decode("foo > bar").unwrap(), "foo > bar"); + assert_eq!(decode("foo " bar").unwrap(), "foo \" bar"); + assert_eq!(decode("foo ' bar").unwrap(), "foo ' bar"); + assert_eq!(decode("foo &lt; bar").unwrap(), "foo < bar"); + assert_eq!(decode("& foo").unwrap(), "& foo"); + assert_eq!(decode("foo &").unwrap(), "foo &"); + assert_eq!(decode("cbdtéda&sü").unwrap(), "cbdtéda&sü"); + assert!(decode("&foo;").is_err()); + assert!(decode("&foobar;").is_err()); + assert!(decode("cbdtéd&ü").is_err()); + } +} diff --git a/instant-xml/src/lib.rs b/instant-xml/src/lib.rs index 66efba0..aed6d8e 100644 --- a/instant-xml/src/lib.rs +++ b/instant-xml/src/lib.rs @@ -96,6 +96,8 @@ impl FromXmlOwned for T where T: for<'xml> FromXml<'xml> {} pub enum Error { #[error("format: {0}")] Format(#[from] fmt::Error), + #[error("invalid entity: {0}")] + InvalidEntity(String), #[error("parse: {0}")] Parse(#[from] xmlparser::Error), #[error("other: {0}")]