Replace decode() function with an explicit state machine

This commit is contained in:
Dirkjan Ochtman 2023-02-22 11:15:06 +01:00
parent 93d7c3d572
commit b276ee173f
3 changed files with 84 additions and 54 deletions

View File

@ -341,7 +341,7 @@ pub fn borrow_cow_str<'xml>(
} }
if let Some(value) = deserializer.take_str()? { if let Some(value) = deserializer.take_str()? {
*into = Some(decode(value)); *into = Some(decode(value)?);
}; };
deserializer.ignore()?; deserializer.ignore()?;
@ -357,7 +357,7 @@ pub fn borrow_cow_slice_u8<'xml>(
} }
if let Some(value) = deserializer.take_str()? { if let Some(value) = deserializer.take_str()? {
*into = Some(match decode(value) { *into = Some(match decode(value)? {
Cow::Borrowed(v) => Cow::Borrowed(v.as_bytes()), Cow::Borrowed(v) => Cow::Borrowed(v.as_bytes()),
Cow::Owned(v) => Cow::Owned(v.into_bytes()), Cow::Owned(v) => Cow::Owned(v.into_bytes()),
}); });

View File

@ -260,7 +260,7 @@ impl<'xml> FromXml<'xml> for String {
} }
match deserializer.take_str()? { match deserializer.take_str()? {
Some(value) => *into = Some(decode(value).into_owned()), Some(value) => *into = Some(decode(value)?.into_owned()),
None => return Ok(()), None => return Ok(()),
} }
@ -292,7 +292,7 @@ impl<'xml> FromXml<'xml> for &'xml str {
None => return Ok(()), None => return Ok(()),
}; };
match decode(value) { match decode(value)? {
Cow::Borrowed(str) => *into = Some(str), Cow::Borrowed(str) => *into = Some(str),
Cow::Owned(_) => { Cow::Owned(_) => {
return Err(Error::UnexpectedValue(format!( return Err(Error::UnexpectedValue(format!(
@ -484,63 +484,69 @@ fn encode(input: &str) -> Result<Cow<'_, str>, Error> {
Ok(Cow::Owned(result)) Ok(Cow::Owned(result))
} }
// NOTE(review): this span comes from a side-by-side diff, so a line may hold the removed (left)
// implementation followed by the added (right) one; it is not compilable as written.
//
// Added implementation of `decode`, as far as the right-hand column shows:
// - walks the input bytes with a two-state machine (`DecodeState::Normal` / `DecodeState::Entity`);
// - on `;` maps the names `amp`, `apos`, `gt`, `lt`, `quot` to `&`, `'`, `>`, `<`, `"`, first copying
//   the pending literal text `input[last_end..start]` into `result`;
// - returns `Error::InvalidEntity` for any other name, or when a fifth byte arrives before `;`;
// - returns `Cow::Borrowed(input)` when `result` is still empty after the loop (nothing was decoded),
//   otherwise `Cow::Owned(result)` with the remaining tail of the input appended.
//
// NOTE(review): `state` is not inspected after the loop, so an input that ends inside an entity
// (e.g. "foo &amp") appears to be passed through unchanged rather than rejected — confirm intended.
// NOTE(review): `String::with_capacity(input.len())` runs before it is known whether any entity
// exists, so the borrowed path presumably pays for the allocation too — consider allocating lazily.
pub(crate) fn decode(input: &str) -> Cow<'_, str> { pub(crate) fn decode(input: &str) -> Result<Cow<'_, str>, Error> {
let mut result = String::with_capacity(input.len()); let mut result = String::with_capacity(input.len());
let input_len = input.len(); let (mut state, mut last_end) = (DecodeState::Normal, 0);
for (i, &b) in input.as_bytes().iter().enumerate() {
// use a state machine to find entities
state = match (state, b) {
(DecodeState::Normal, b'&') => DecodeState::Entity([0; 4], 0),
(DecodeState::Normal, _) => DecodeState::Normal,
(DecodeState::Entity(chars, len), b';') => {
let decoded = match chars[..len] {
[b'a', b'm', b'p'] => '&',
[b'a', b'p', b'o', b's'] => '\'',
[b'g', b't'] => '>',
[b'l', b't'] => '<',
[b'q', b'u', b'o', b't'] => '"',
_ => {
return Err(Error::InvalidEntity(
String::from_utf8_lossy(&chars[..len]).into_owned(),
))
}
};
let mut last_end = 0; let start = i - (len + 1); // current position - (length of entity characters + 1 for '&')
while input_len - last_end >= 4 { if last_end < start {
if input.is_char_boundary(last_end + 4) { // Unwrap should be safe: `last_end` and `start` must be at character boundaries.
match &input[last_end..(last_end + 4)] { result.push_str(input.get(last_end..start).unwrap());
"&lt;" => {
result.push('<');
last_end += 4;
continue;
} }
"&gt;" => {
result.push('>'); last_end = i + 1;
last_end += 4; result.push(decoded);
continue; DecodeState::Normal
}
(DecodeState::Entity(mut chars, len), b) => {
if len >= 4 {
let mut bytes = Vec::with_capacity(5);
bytes.extend(&chars[..len]);
bytes.push(b);
return Err(Error::InvalidEntity(
String::from_utf8_lossy(&bytes).into_owned(),
));
}
chars[len] = b;
DecodeState::Entity(chars, len + 1)
} }
_ => (),
}; };
} }
if input_len - last_end >= 5 && input.is_char_boundary(last_end + 5) { Ok(match result.is_empty() {
if &input[last_end..(last_end + 5)] == "&amp;" { true => Cow::Borrowed(input),
result.push('&'); false => {
last_end += 5; // Unwrap should be safe: `last_end` and `input.len()` must be at character boundaries.
continue; result.push_str(input.get(last_end..input.len()).unwrap());
}
if input_len - last_end >= 6 && input.is_char_boundary(last_end + 6) {
match &input[last_end..(last_end + 6)] {
"&apos;" => {
result.push('\'');
last_end += 6;
continue;
}
"&quot;" => {
result.push('"');
last_end += 6;
continue;
}
_ => (),
};
}
}
result.push_str(input.get(last_end..last_end + 1).unwrap());
last_end += 1;
}
result.push_str(input.get(last_end..).unwrap());
if result.len() == input.len() {
return Cow::Borrowed(input);
}
Cow::Owned(result) Cow::Owned(result)
} }
})
}
/// Scanner state used by `decode` while it walks the input byte by byte.
#[derive(Debug)]
enum DecodeState {
/// Outside any entity; `decode` stays in this state until it meets `&`.
Normal,
/// Inside an entity, after `&`: a buffer for up to four name bytes and the count stored so far.
Entity([u8; 4], usize),
}
impl<'xml, T: FromXml<'xml>> FromXml<'xml> for Vec<T> { impl<'xml, T: FromXml<'xml>> FromXml<'xml> for Vec<T> {
#[inline] #[inline]
@ -771,3 +777,25 @@ impl<'xml> FromXml<'xml> for IpAddr {
const KIND: Kind = Kind::Scalar; const KIND: Kind = Kind::Scalar;
} }
#[cfg(test)]
mod tests {
    use super::decode;

    /// Checks `decode` on plain text, each predefined entity, and malformed entities.
    #[test]
    fn test_decode() {
        // Inputs that must decode successfully, paired with the expected text.
        let accepted: &[(&str, &str)] = &[
            ("foo", "foo"),
            ("foo &amp; bar", "foo & bar"),
            ("foo &lt; bar", "foo < bar"),
            ("foo &gt; bar", "foo > bar"),
            ("foo &quot; bar", "foo \" bar"),
            ("foo &apos; bar", "foo ' bar"),
            // An escaped ampersand must not be decoded a second time.
            ("foo &amp;lt; bar", "foo &lt; bar"),
            // Entities at the very start and very end of the input.
            ("&amp; foo", "& foo"),
            ("foo &amp;", "foo &"),
            // Multi-byte characters on both sides of an entity.
            ("cbdtéda&amp;sü", "cbdtéda&sü"),
        ];
        for &(raw, expected) in accepted {
            assert_eq!(decode(raw).unwrap(), expected);
        }

        // An unknown name, an over-long name, and an entity cut off by a multi-byte character.
        let rejected: &[&str] = &["&foo;", "&foobar;", "cbdtéd&ampü"];
        for &raw in rejected {
            assert!(decode(raw).is_err());
        }
    }
}

View File

@ -96,6 +96,8 @@ impl<T> FromXmlOwned for T where T: for<'xml> FromXml<'xml> {}
pub enum Error { pub enum Error {
#[error("format: {0}")] #[error("format: {0}")]
Format(#[from] fmt::Error), Format(#[from] fmt::Error),
#[error("invalid entity: {0}")]
InvalidEntity(String),
#[error("parse: {0}")] #[error("parse: {0}")]
Parse(#[from] xmlparser::Error), Parse(#[from] xmlparser::Error),
#[error("other: {0}")] #[error("other: {0}")]