From 3ad6917ba0dd30361c17b99699f5edcd8732831f Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Tue, 6 Sep 2022 12:26:23 +0200 Subject: [PATCH] Separate Context out of Deserializer --- instant-xml-macros/src/de.rs | 73 +++-- instant-xml/src/de.rs | 522 ++++++++++++++--------------------- instant-xml/src/impls.rs | 70 +++-- instant-xml/src/lib.rs | 18 +- instant-xml/tests/de-ns.rs | 2 +- 5 files changed, 290 insertions(+), 395 deletions(-) diff --git a/instant-xml-macros/src/de.rs b/instant-xml-macros/src/de.rs index 831715b..2540ec3 100644 --- a/instant-xml-macros/src/de.rs +++ b/instant-xml-macros/src/de.rs @@ -115,7 +115,7 @@ impl Deserializer { let name = ident.to_string(); let mut out = TokenStream::new(); out.extend(quote!( - fn deserialize(deserializer: &mut ::instant_xml::Deserializer<'xml>) -> Result { + fn deserialize<'cx>(deserializer: &'cx mut ::instant_xml::Deserializer<'cx, 'xml>) -> Result { use ::instant_xml::de::{XmlRecord, Deserializer, Visitor}; use ::instant_xml::Error; @@ -137,31 +137,34 @@ impl Deserializer { impl #xml_impl_generics Visitor<'xml> for StructVisitor #xml_ty_generics #xml_where_clause { type Value = #ident #ty_generics; - fn visit_struct( - deserializer: &mut ::instant_xml::Deserializer<'xml>, - ) -> Result { - use ::instant_xml::de::Node; - + fn visit_struct<'cx>( + deserializer: &'cx mut ::instant_xml::Deserializer<'cx, 'xml>, + ) -> Result { #declare_values - while let Some(attr) = deserializer.peek_next_attribute()? { - let attr = { - #attributes_consts - match attr.id { - #attributes_names - _ => __Attributes::__Ignore - } + loop { + let node = match deserializer.next() { + Some(result) => result?, + None => break, }; - match attr { - #attr_type_match - __Attributes::__Ignore => {} - } - } - - while let Some(node) = deserializer.peek_next_tag()? { match node { - Node::Open { ns, name } => { - let id = ::instant_xml::Id { ns, name }; + XmlRecord::Attribute(attr) => { + let id = deserializer.attribute_id(&attr)?; + let field = { + #attributes_consts + match id { + #attributes_names + _ => __Attributes::__Ignore + } + }; + + match field { + #attr_type_match + __Attributes::__Ignore => {} + } + } + XmlRecord::Open(data) => { + let id = deserializer.element_id(&data)?; let element = { #elements_consts match id { @@ -173,31 +176,22 @@ impl Deserializer { match element { #elem_type_match __Elements::__Ignore => { - deserializer.ignore(id)?; + let mut nested = deserializer.nested(data); + nested.ignore()?; } } } - Node::Close { name } => { - if name == #name { - break; - } - }, - Node::Text { text } => panic!("Unexpected element"), + _ => return Err(Error::UnexpectedState), } } Ok(Self::Value { #return_val - }) + }) } } - #namespaces_map; - deserializer.deserialize_struct::( - #name, - #default_namespace, - &namespaces_map - ) + StructVisitor::visit_struct(deserializer) } )); @@ -277,7 +271,8 @@ impl Deserializer { panic!("duplicated value"); } - #enum_name = Some(<#no_lifetime_type>::deserialize(deserializer)?); + let mut nested = deserializer.nested(data); + #enum_name = Some(<#no_lifetime_type>::deserialize(&mut nested)?); }, )); } else { @@ -287,8 +282,8 @@ impl Deserializer { panic!("duplicated value"); } - deserializer.set_next_type_as_attribute()?; - #enum_name = Some(<#no_lifetime_type>::deserialize(deserializer)?); + let mut nested = deserializer.for_attr(attr); + #enum_name = Some(<#no_lifetime_type>::deserialize(&mut nested)?); }, )); } diff --git a/instant-xml/src/de.rs b/instant-xml/src/de.rs index c7f13bf..9a44b87 100644 --- a/instant-xml/src/de.rs +++ b/instant-xml/src/de.rs @@ -1,68 +1,134 @@ -use std::collections::HashMap; -use std::iter::Peekable; +use std::collections::{BTreeMap, VecDeque}; use super::{Error, Id}; use xmlparser::{ElementEnd, Token, Tokenizer}; -pub struct Deserializer<'xml> { - parser: Peekable>, - def_namespaces: HashMap<&'xml str, &'xml str>, - pub parser_namespaces: HashMap<&'xml str, &'xml str>, - def_default_namespace: &'xml str, - parser_default_namespace: &'xml str, - tag_attributes: Vec>, - next_type: EntityType, +pub struct Deserializer<'cx, 'xml> { + pub(crate) local: &'xml str, + prefix: Option<&'xml str>, + level: usize, + done: bool, + context: &'cx mut Context<'xml>, } -impl<'xml> Deserializer<'xml> { - pub fn new(input: &'xml str) -> Self { +impl<'cx, 'xml> Deserializer<'cx, 'xml> { + pub(crate) fn new(data: TagData<'xml>, context: &'cx mut Context<'xml>) -> Self { + let level = context.stack.len(); + context.stack.push(data.level); + Self { - parser: XmlParser::new(input).peekable(), - def_namespaces: std::collections::HashMap::new(), - parser_namespaces: std::collections::HashMap::new(), - def_default_namespace: "", - parser_default_namespace: "", - tag_attributes: Vec::new(), - next_type: EntityType::Element, + local: data.key, + prefix: data.prefix, + level, + done: false, + context, } } - pub fn peek_next_tag(&mut self) -> Result>, Error> { - let record = match self.parser.peek() { - Some(Ok(record)) => record, - Some(Err(err)) => return Err(err.clone()), - None => return Ok(None), - }; - - Ok(Some(match record { - XmlRecord::Open(TagData { - key, ns, prefix, .. - }) => { - let ns = match (ns, prefix) { - (_, Some(prefix)) => match self.parser_namespaces.get(prefix) { - Some(ns) => ns, - None => return Err(Error::WrongNamespace), - }, - (Some(ns), None) => ns, - (None, None) => self.parser_default_namespace, - }; - - Node::Open { ns, name: key } - } - XmlRecord::Element(text) => Node::Text { text }, - XmlRecord::Close(name) => Node::Close { name }, - })) + pub fn nested<'a>(&'a mut self, data: TagData<'xml>) -> Deserializer<'a, 'xml> + where + 'cx: 'a, + { + Deserializer::new(data, self.context) } - pub fn id(&self, item: &TagData<'xml>) -> Result, Error> { + pub fn for_attr<'a>(&'a mut self, attr: Attribute<'xml>) -> Deserializer<'a, 'xml> + where + 'cx: 'a, + { + self.context + .records + .push_front(XmlRecord::AttributeValue(attr.value)); + + Deserializer { + local: self.local, + prefix: self.prefix, + level: self.level, + done: self.done, + context: self.context, + } + } + + pub fn ignore(&mut self) -> Result<(), Error> { + loop { + match self.next() { + Some(Err(e)) => return Err(e), + Some(Ok(XmlRecord::Open(data))) => { + let mut nested = self.nested(data); + nested.ignore()?; + } + Some(_) => continue, + None => return Ok(()), + } + } + } + + #[inline] + pub fn element_id(&self, item: &TagData<'xml>) -> Result, Error> { + self.context.element_id(item) + } + + #[inline] + pub fn attribute_id(&self, attr: &Attribute<'xml>) -> Result, Error> { + self.context.attribute_id(attr) + } +} + +impl<'xml> Iterator for Deserializer<'_, 'xml> { + type Item = Result, Error>; + + fn next(&mut self) -> Option { + if self.done { + return None; + } + + let (prefix, local) = match self.context.next() { + Some(Ok(XmlRecord::Close { prefix, local })) => (prefix, local), + item => return item, + }; + + if self.context.stack.len() == self.level && local == self.local && prefix == self.prefix { + self.done = true; + return None; + } + + Some(Err(Error::UnexpectedState)) + } +} + +pub(crate) struct Context<'xml> { + parser: Tokenizer<'xml>, + stack: Vec>, + records: VecDeque>, +} + +impl<'xml> Context<'xml> { + pub(crate) fn new(input: &'xml str) -> Result<(Self, TagData<'xml>), Error> { + let mut new = Self { + parser: Tokenizer::from(input), + stack: Vec::new(), + records: VecDeque::new(), + }; + + let root = match new.next() { + Some(result) => match result? { + XmlRecord::Open(data) => data, + _ => return Err(Error::UnexpectedState), + }, + None => return Err(Error::UnexpectedEndOfStream), + }; + + Ok((new, root)) + } + + pub(crate) fn element_id(&self, item: &TagData<'xml>) -> Result, Error> { let ns = match (item.ns, item.prefix) { - (Some(_), Some(_)) => return Err(Error::WrongNamespace), - (Some(ns), None) => ns, - (None, Some(prefix)) => match self.parser_namespaces.get(prefix) { + (_, Some(prefix)) => match self.lookup(prefix) { Some(ns) => ns, None => return Err(Error::WrongNamespace), }, - (None, None) => "", + (Some(ns), None) => ns, + (None, None) => self.default_ns(), }; Ok(Id { @@ -71,276 +137,95 @@ impl<'xml> Deserializer<'xml> { }) } - pub fn peek_next_attribute(&self) -> Result>, Error> { - let attr = match self.tag_attributes.last() { - Some(attr) => attr, - None => return Ok(None), - }; - + fn attribute_id(&self, attr: &Attribute<'xml>) -> Result, Error> { let ns = match attr.prefix { - Some(key) => match self.parser_namespaces.get(key) { + Some(ns) => match self.lookup(ns) { Some(ns) => ns, None => return Err(Error::WrongNamespace), }, - None => self.parser_default_namespace, + None => self.default_ns(), }; - Ok(Some(AttributeNode { - id: Id { - ns, - name: attr.local, - }, - value: attr.value, - })) + Ok(Id { + ns, + name: &attr.local, + }) } - pub fn deserialize_struct( - &mut self, - name: &str, - def_default_namespace: &'xml str, - def_namespaces: &HashMap<&'xml str, &'xml str>, - ) -> Result - where - V: Visitor<'xml>, - { - // Saveing current defined default namespace - let def_default_namespace_to_revert = self.def_default_namespace; - self.def_default_namespace = def_default_namespace; - - // Adding struct defined namespaces - let new_def_namespaces = def_namespaces + fn default_ns(&self) -> &'xml str { + self.stack .iter() - .filter(|(k, v)| self.def_namespaces.insert(k, v).is_none()) - .collect::>(); + .rev() + .find_map(|level| level.default_ns) + .unwrap_or("") + } - // Process open tag - let tag_data = match self.parser.next() { - Some(Ok(XmlRecord::Open(item))) if item.key == name => item, - _ => return Err(Error::UnexpectedValue), - }; - - // Set current attributes - self.tag_attributes = tag_data.attributes; - - // Saveing current parser default namespace - let parser_default_namespace_to_revert = self.parser_default_namespace; - - // Set parser default namespace - match tag_data.ns { - Some(namespace) => { - self.parser_default_namespace = namespace; - } - None => { - // If there is no default namespace in the tag, check if parent default namespace equals the current one - if def_default_namespace_to_revert != self.def_default_namespace { - return Err(Error::WrongNamespace); - } - } - } - - // Compare parser namespace with defined one - if self.parser_default_namespace != self.def_default_namespace { - return Err(Error::WrongNamespace); - } - - // Adding parser namespaces - let new_parser_namespaces = tag_data - .prefixes + fn lookup(&self, prefix: &str) -> Option<&'xml str> { + self.stack .iter() - .filter(|(k, v)| self.parser_namespaces.insert(k, v).is_none()) - .collect::>(); - - let ret = V::visit_struct(self)?; - - // Process close tag - let item = match self.parser.next() { - Some(item) => item?, - None => return Err(Error::MissingTag), - }; - - match item { - XmlRecord::Close(v) if v == name => {} - _ => return Err(Error::UnexpectedTag), - } - - // Removing parser namespaces - let _ = new_parser_namespaces - .iter() - .map(|(k, _)| self.parser_namespaces.remove(*k)); - - // Removing struct defined namespaces - let _ = new_def_namespaces - .iter() - .map(|(k, _)| self.def_namespaces.remove(*k)); - - // Retriving old defined namespace - self.def_default_namespace = def_default_namespace_to_revert; - - // Retriving old parser namespace - self.parser_default_namespace = parser_default_namespace_to_revert; - Ok(ret) - } - - pub fn set_next_type_as_attribute(&mut self) -> Result<(), Error> { - if self.next_type == EntityType::Attribute { - return Err(Error::UnexpectedState); - } - - self.next_type = EntityType::Attribute; - Ok(()) - } - - pub fn consume_next_type(&mut self) -> EntityType { - let ret = self.next_type.clone(); - self.next_type = EntityType::Element; - ret - } - - pub(crate) fn deserialize_element(&mut self) -> Result - where - V: Visitor<'xml>, - { - // Process open tag - match self.parser.next() { - Some(Ok(XmlRecord::Open(_))) => {} - _ => return Err(Error::UnexpectedValue), - }; - - match self.parser.next() { - Some(Ok(XmlRecord::Element(v))) => { - let ret = V::visit_str(v); - self.parser.next(); - ret - } - _ => Err(Error::UnexpectedValue), - } - } - - pub(crate) fn deserialize_attribute>(&mut self) -> Result { - match self.tag_attributes.pop() { - Some(attr) => V::visit_str(attr.value), - None => Err(Error::UnexpectedEndOfStream), - } - } - - pub fn ignore(&mut self, id: Id<'xml>) -> Result<(), Error> { - let mut levels = 0; - while let Some(result) = self.parser.next() { - match result? { - XmlRecord::Open(item) => { - if self.id(&item)? == id { - levels += 1; - } - } - XmlRecord::Close(item) => { - if item == id.name { - levels -= 1; - if levels == 0 { - return Ok(()); - } - } - } - _ => {} - } - } - - Ok(()) + .rev() + .find_map(|level| level.prefixes.get(prefix).map(|ns| *ns)) } } -pub struct XmlParser<'xml> { - stack: Vec<&'xml str>, - iter: Peekable>, -} - -impl<'a> XmlParser<'a> { - pub fn new(input: &'a str) -> XmlParser<'a> { - XmlParser { - stack: Vec::new(), - iter: Tokenizer::from(input).peekable(), - } - } - - pub fn peek_next_tag(&mut self) -> Result>, Error> { - let item = match self.iter.peek() { - Some(v) => v, - None => return Ok(None), - }; - - match item { - Ok(Token::ElementStart { prefix, local, .. }) => { - let prefix = match prefix.is_empty() { - true => None, - false => Some(prefix.as_str()), - }; - - Ok(Some(XmlRecord::Open(TagData { - key: local.as_str(), - attributes: Vec::new(), - ns: Some(""), - prefixes: HashMap::new(), - prefix, - }))) - } - Ok(Token::ElementEnd { - end: ElementEnd::Close(..), - .. - }) => { - if self.stack.is_empty() { - return Err(Error::UnexpectedEndOfStream); - } - - return Ok(Some(XmlRecord::Close(self.stack.last().unwrap()))); - } - Ok(_) => Err(Error::UnexpectedToken), - Err(e) => Err(Error::Parse(*e)), - } - } -} - -impl<'xml> Iterator for XmlParser<'xml> { +impl<'xml> Iterator for Context<'xml> { type Item = Result, Error>; - #[inline] fn next(&mut self) -> Option { - let mut key: Option<&str> = None; - let mut prefix_ret: Option<&str> = None; - let mut default_namespace = None; - let mut namespaces = HashMap::new(); - let mut attributes = Vec::new(); + if let Some(record) = self.records.pop_front() { + return Some(Ok(record)); + } + let mut current = None; loop { - let token = match self.iter.next() { + let token = match self.parser.next() { Some(v) => v, None => return None, }; match token { Ok(Token::ElementStart { prefix, local, .. }) => { - key = Some(local.as_str()); - prefix_ret = match prefix.is_empty() { - true => None, - false => Some(prefix.as_str()), - }; + let prefix = prefix.as_str(); + current = Some(Level { + local: local.as_str(), + prefix: (!prefix.is_empty()).then_some(prefix), + default_ns: None, + prefixes: BTreeMap::new(), + }); } Ok(Token::ElementEnd { end, .. }) => match end { ElementEnd::Open => { - self.stack.push(key.unwrap()); + let level = match current { + Some(level) => level, + None => return Some(Err(Error::UnexpectedState)), + }; - return Some(Ok(XmlRecord::Open(TagData { - key: key.unwrap(), - attributes, - ns: default_namespace, - prefixes: namespaces, - prefix: prefix_ret, - }))); + let data = TagData { + key: level.local, + prefix: level.prefix, + ns: level.default_ns, + level, + }; + + return Some(Ok(XmlRecord::Open(data))); } - ElementEnd::Close(_, v) => match self.stack.pop() { - Some(last) if last == v.as_str() => { - return Some(Ok(XmlRecord::Close(last))); + ElementEnd::Close(prefix, v) => { + let level = match self.stack.pop() { + Some(level) => level, + None => return Some(Err(Error::UnexpectedState)), + }; + + let prefix = (!prefix.is_empty()).then_some(prefix.as_str()); + match v.as_str() == level.local && prefix == level.prefix { + true => { + return Some(Ok(XmlRecord::Close { + prefix, + local: level.local, + })) + } + false => return Some(Err(Error::UnexpectedState)), } - _ => return Some(Err(Error::UnexpectedValue)), - }, + } ElementEnd::Empty => { todo!(); } @@ -352,18 +237,24 @@ impl<'xml> Iterator for XmlParser<'xml> { .. }) => { if prefix.is_empty() && local.as_str() == "xmlns" { - // Default namespace - default_namespace = Some(value.as_str()); + match &mut current { + Some(level) => level.default_ns = Some(value.as_str()), + None => return Some(Err(Error::UnexpectedState)), + } } else if prefix.as_str() == "xmlns" { - // Namespaces - namespaces.insert(local.as_str(), value.as_str()); + match &mut current { + Some(level) => { + level.prefixes.insert(local.as_str(), value.as_str()); + } + None => return Some(Err(Error::UnexpectedState)), + } } else { let prefix = (!prefix.is_empty()).then_some(prefix.as_str()); - attributes.push(Attribute { + self.records.push_back(XmlRecord::Attribute(Attribute { prefix, local: local.as_str(), value: value.as_str(), - }); + })); } } Ok(Token::Text { text }) => { @@ -383,30 +274,39 @@ pub trait Visitor<'xml>: Sized { unimplemented!(); } - fn visit_struct(_deserializer: &mut Deserializer<'xml>) -> Result { + fn visit_struct<'cx>( + _deserializer: &'cx mut Deserializer<'cx, 'xml>, + ) -> Result { unimplemented!(); } } #[derive(Debug)] pub enum XmlRecord<'xml> { - Open(TagData<'xml>), + Attribute(Attribute<'xml>), + AttributeValue(&'xml str), + Close { + prefix: Option<&'xml str>, + local: &'xml str, + }, Element(&'xml str), - Close(&'xml str), + Open(TagData<'xml>), } #[derive(Debug)] pub struct TagData<'xml> { - pub key: &'xml str, - pub attributes: Vec>, - pub ns: Option<&'xml str>, - pub prefixes: HashMap<&'xml str, &'xml str>, - pub prefix: Option<&'xml str>, + key: &'xml str, + ns: Option<&'xml str>, + prefix: Option<&'xml str>, + level: Level<'xml>, } -pub struct AttributeNode<'xml> { - pub id: Id<'xml>, - pub value: &'xml str, +#[derive(Debug)] +struct Level<'xml> { + local: &'xml str, + prefix: Option<&'xml str>, + default_ns: Option<&'xml str>, + prefixes: BTreeMap<&'xml str, &'xml str>, } #[derive(Debug)] @@ -415,15 +315,3 @@ pub struct Attribute<'xml> { pub local: &'xml str, pub value: &'xml str, } - -pub enum Node<'xml> { - Open { ns: &'xml str, name: &'xml str }, - Close { name: &'xml str }, - Text { text: &'xml str }, -} - -#[derive(Clone, PartialEq, Eq)] -pub enum EntityType { - Element, - Attribute, -} diff --git a/instant-xml/src/impls.rs b/instant-xml/src/impls.rs index 208e852..9a1908d 100644 --- a/instant-xml/src/impls.rs +++ b/instant-xml/src/impls.rs @@ -3,7 +3,7 @@ use std::fmt; use std::marker::PhantomData; use std::str::FromStr; -use crate::de::{EntityType, Visitor}; +use crate::de::{Visitor, XmlRecord}; use crate::{Deserializer, Error, FieldAttribute, FromXml, Kind, Serializer, ToXml}; // Deserializer @@ -12,7 +12,7 @@ where T: FromStr, ::Err: std::fmt::Display; -impl<'xml, T> Visitor<'xml> for FromStrToVisitor +impl<'xml, T: 'xml> Visitor<'xml> for FromStrToVisitor where T: FromStr, ::Err: std::fmt::Display, @@ -40,11 +40,8 @@ impl<'xml> Visitor<'xml> for BoolVisitor { impl<'xml> FromXml<'xml> for bool { const KIND: Kind = Kind::Scalar; - fn deserialize(deserializer: &mut Deserializer) -> Result { - match deserializer.consume_next_type() { - EntityType::Element => deserializer.deserialize_element::(), - EntityType::Attribute => deserializer.deserialize_attribute::(), - } + fn deserialize(deserializer: &mut Deserializer<'_, 'xml>) -> Result { + deserialize_scalar::(deserializer) } } @@ -99,7 +96,7 @@ where marker: PhantomData, } -impl<'xml, T> Visitor<'xml> for NumberVisitor +impl<'xml, T: 'xml> Visitor<'xml> for NumberVisitor where T: FromStr, ::Err: std::fmt::Display, @@ -115,14 +112,7 @@ macro_rules! from_xml_for_number { ($typ:ty) => { impl<'xml> FromXml<'xml> for $typ { fn deserialize(deserializer: &mut Deserializer) -> Result { - match deserializer.consume_next_type() { - EntityType::Element => { - deserializer.deserialize_element::>() - } - EntityType::Attribute => { - deserializer.deserialize_attribute::>() - } - } + deserialize_scalar::>(deserializer) } const KIND: Kind = Kind::Scalar; @@ -157,11 +147,7 @@ impl<'xml> FromXml<'xml> for String { const KIND: Kind = Kind::Scalar; fn deserialize(deserializer: &mut Deserializer) -> Result { - //<&'xml str>::deserialize(deserializer); - match deserializer.consume_next_type() { - EntityType::Element => deserializer.deserialize_element::(), - EntityType::Attribute => deserializer.deserialize_attribute::(), - } + deserialize_scalar::(deserializer) } } @@ -182,10 +168,7 @@ impl<'xml> FromXml<'xml> for char { const KIND: Kind = Kind::Scalar; fn deserialize(deserializer: &mut Deserializer) -> Result { - match deserializer.consume_next_type() { - EntityType::Element => deserializer.deserialize_element::(), - EntityType::Attribute => deserializer.deserialize_attribute::(), - } + deserialize_scalar::(deserializer) } } @@ -205,11 +188,8 @@ impl<'a> Visitor<'a> for StrVisitor { impl<'xml> FromXml<'xml> for &'xml str { const KIND: Kind = Kind::Scalar; - fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result { - match deserializer.consume_next_type() { - EntityType::Element => deserializer.deserialize_element::(), - EntityType::Attribute => deserializer.deserialize_attribute::(), - } + fn deserialize(deserializer: &mut Deserializer<'_, 'xml>) -> Result { + deserialize_scalar::(deserializer) } } @@ -226,11 +206,8 @@ impl<'a> Visitor<'a> for CowStrVisitor { impl<'xml> FromXml<'xml> for Cow<'xml, str> { const KIND: Kind = Kind::Scalar; - fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result { - match deserializer.consume_next_type() { - EntityType::Element => deserializer.deserialize_element::(), - EntityType::Attribute => deserializer.deserialize_attribute::(), - } + fn deserialize(deserializer: &mut Deserializer<'_, 'xml>) -> Result { + deserialize_scalar::(deserializer) } } @@ -240,7 +217,7 @@ where { const KIND: Kind = ::KIND; - fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result { + fn deserialize<'cx>(deserializer: &'cx mut Deserializer<'cx, 'xml>) -> Result { match ::deserialize(deserializer) { Ok(v) => Ok(Some(v)), Err(e) => Err(e), @@ -384,6 +361,27 @@ impl ToXml for Option { } } +fn deserialize_scalar<'xml, V: Visitor<'xml>>( + deserializer: &mut Deserializer<'_, 'xml>, +) -> Result +where + V::Value: FromXml<'xml>, +{ + let value = match deserializer.next() { + Some(Ok(XmlRecord::AttributeValue(s))) => return V::visit_str(s), + Some(Ok(XmlRecord::Element(s))) => V::visit_str(s)?, + Some(Ok(_)) => return Err(Error::ExpectedScalar), + Some(Err(e)) => return Err(e), + None => return >::missing_value(), + }; + + match deserializer.next() { + Some(Ok(_)) => Err(Error::UnexpectedState), + Some(Err(e)) => Err(e), + None => Ok(value), + } +} + fn escape(input: &str) -> Result, Error> { let mut result = String::with_capacity(input.len()); let mut last_end = 0; diff --git a/instant-xml/src/lib.rs b/instant-xml/src/lib.rs index 44c3a52..99e8577 100644 --- a/instant-xml/src/lib.rs +++ b/instant-xml/src/lib.rs @@ -7,6 +7,7 @@ pub use macros::{FromXml, ToXml}; #[doc(hidden)] pub mod de; mod impls; +use de::Context; pub use de::Deserializer; #[doc(hidden)] pub mod ser; @@ -26,7 +27,7 @@ pub enum FieldAttribute<'xml> { } pub trait FromXml<'xml>: Sized { - fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result; + fn deserialize<'cx>(deserializer: &'cx mut Deserializer<'cx, 'xml>) -> Result; // If the missing field is of type `Option` then treat is as `None`, // otherwise it is an error. @@ -38,7 +39,18 @@ pub trait FromXml<'xml>: Sized { } pub fn from_str<'xml, T: FromXml<'xml>>(input: &'xml str) -> Result { - T::deserialize(&mut Deserializer::new(input)) + let (mut context, root) = Context::new(input)?; + let id = context.element_id(&root)?; + let expected = match T::KIND { + Kind::Scalar => return Err(Error::UnexpectedState), + Kind::Element(expected) => expected, + }; + + if id != expected { + return Err(Error::UnexpectedValue); + } + + T::deserialize(&mut Deserializer::new(root, &mut context)) } pub fn to_string(value: &(impl ToXml + ?Sized)) -> Result { @@ -100,6 +112,8 @@ pub enum Error { MissingdPrefix, #[error("unexpected state")] UnexpectedState, + #[error("expected scalar")] + ExpectedScalar, #[error("wrong namespace")] WrongNamespace, } diff --git a/instant-xml/tests/de-ns.rs b/instant-xml/tests/de-ns.rs index 7e39b6d..d0e931b 100644 --- a/instant-xml/tests/de-ns.rs +++ b/instant-xml/tests/de-ns.rs @@ -39,7 +39,7 @@ fn default_namespaces() { from_str( "true" ), - Err::(Error::WrongNamespace) + Err::(Error::UnexpectedValue) ); // Correct child namespace