diff --git a/instant-xml-macros/src/de.rs b/instant-xml-macros/src/de.rs index a9feeaa..e247578 100644 --- a/instant-xml-macros/src/de.rs +++ b/instant-xml-macros/src/de.rs @@ -1,5 +1,5 @@ use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; +use quote::{quote, ToTokens}; use crate::{get_namespaces, retrieve_field_attribute, FieldAttribute}; @@ -34,6 +34,25 @@ impl quote::ToTokens for Deserializer { impl Deserializer { pub fn new(input: &syn::DeriveInput) -> Deserializer { let ident = &input.ident; + let generics = (&input.generics).into_token_stream(); + let lifetimes = (&input.generics.params).into_token_stream(); + + let mut lifetime_xml = TokenStream::new(); + let mut lifetime_visitor = TokenStream::new(); + let iter = &mut input.generics.params.iter(); + if let Some(it) = iter.next() { + lifetime_xml = quote!(:); + lifetime_xml.extend(it.into_token_stream()); + while let Some(it) = iter.by_ref().next() { + lifetime_xml.extend(quote!(+)); + lifetime_xml.extend(it.into_token_stream()); + } + lifetime_xml.extend(quote!(,)); + lifetime_xml.extend(lifetimes.clone()); + lifetime_visitor.extend(quote!(,)); + lifetime_visitor.extend(lifetimes); + } + let name = ident.to_string(); let mut out = TokenStream::new(); @@ -113,7 +132,7 @@ impl Deserializer { let attr_type_match = attributes_tokens.match_; out.extend(quote!( - fn deserialize(deserializer: &mut ::instant_xml::Deserializer) -> Result { + fn deserialize(deserializer: &mut ::instant_xml::Deserializer<'xml>) -> Result { use ::instant_xml::parse::XmlRecord; use ::instant_xml::{Error, Deserializer, Visitor} ; @@ -143,12 +162,18 @@ impl Deserializer { } } - struct StructVisitor; - impl<'xml> Visitor<'xml> for StructVisitor { - type Value = #ident; + struct StructVisitor<'xml #lifetime_xml> { + marker: std::marker::PhantomData<#ident #generics>, + lifetime: std::marker::PhantomData<&'xml ()>, + } - fn visit_struct<'a>(&self, deserializer: &mut ::instant_xml::Deserializer) -> Result - { + impl<'xml #lifetime_xml> Visitor<'xml> for StructVisitor<'xml #lifetime_visitor> { + type Value = #ident #generics; + + fn visit_struct( + &self, + deserializer: &mut ::instant_xml::Deserializer<'xml> + ) -> Result { #declare_values while let Some(( key, _ )) = deserializer.peek_next_attribute() { match get_attribute(&key) { @@ -181,7 +206,15 @@ impl Deserializer { } #namespaces_map; - deserializer.deserialize_struct(StructVisitor{}, #name, #default_namespace, &namespaces_map) + deserializer.deserialize_struct( + StructVisitor{ + marker: std::marker::PhantomData, + lifetime: std::marker::PhantomData + }, + #name, + #default_namespace, + &namespaces_map + ) } )); @@ -189,6 +222,12 @@ impl Deserializer { const TAG_NAME: ::instant_xml::TagName<'xml> = ::instant_xml::TagName::Custom(#name); )); + out = quote!( + impl<'xml #lifetime_xml> FromXml<'xml> for #ident #generics { + #out + } + ); + Deserializer { out } } @@ -206,16 +245,14 @@ impl Deserializer { let field_var = field.ident.as_ref().unwrap(); let field_var_str = field_var.to_string(); let const_field_var_str = Ident::new(&field_var_str.to_uppercase(), Span::call_site()); - let field_type = match &field.ty { - syn::Type::Path(v) => v.path.get_ident(), - _ => panic!("Wrong field attribute format"), - }; + let mut no_lifetime_type = field.ty.clone(); + discard_lifetimes(&mut no_lifetime_type); let enum_name = Ident::new(&format!("__Value{index}"), Span::call_site()); tokens.enum_.extend(quote!(#enum_name,)); tokens.consts.extend(quote!( - const #const_field_var_str: &str = match #field_type::TAG_NAME { + const #const_field_var_str: &str = match <#no_lifetime_type>::TAG_NAME { ::instant_xml::TagName::FieldName => #field_var_str, ::instant_xml::TagName::Custom(v) => v, }; @@ -232,7 +269,7 @@ impl Deserializer { } declare_values.extend(quote!( - let mut #enum_name: Option<#field_type> = None; + let mut #enum_name: Option<#no_lifetime_type> = None; )); let def_prefix = match def_prefix { @@ -283,7 +320,7 @@ impl Deserializer { } #field_namespace deserializer.set_next_def_namespace(field_namespace)?; - #enum_name = Some(#field_type::deserialize(deserializer)?); + #enum_name = Some(<#no_lifetime_type>::deserialize(deserializer)?); }, )); } else { @@ -294,13 +331,53 @@ impl Deserializer { } deserializer.set_next_type_as_attribute()?; - #enum_name = Some(#field_type::deserialize(deserializer)?); + #enum_name = Some(<#no_lifetime_type>::deserialize(deserializer)?); }, )); } return_val.extend(quote!( - #field_var: #enum_name.expect("Expected some value"), + #field_var: match #enum_name { + Some(v) => v, + None => <#no_lifetime_type>::missing_value()?, + }, )); } } + +fn discard_lifetimes(ty: &mut syn::Type) { + match ty { + syn::Type::Path(ty) => discard_path_lifetimes(ty), + syn::Type::Reference(ty) => { + ty.lifetime = None; + discard_lifetimes(&mut ty.elem); + } + _ => {} + } +} + +fn discard_path_lifetimes(path: &mut syn::TypePath) { + if let Some(q) = &mut path.qself { + discard_lifetimes(&mut q.ty); + } + + for segment in &mut path.path.segments { + match &mut segment.arguments { + syn::PathArguments::None => {} + syn::PathArguments::AngleBracketed(args) => { + args.args.iter_mut().for_each(|arg| match arg { + syn::GenericArgument::Lifetime(lt) => { + *lt = syn::Lifetime::new("'_", Span::call_site()) + } + syn::GenericArgument::Type(ty) => discard_lifetimes(ty), + syn::GenericArgument::Binding(_) + | syn::GenericArgument::Constraint(_) + | syn::GenericArgument::Const(_) => {} + }) + } + syn::PathArguments::Parenthesized(args) => { + args.inputs.iter_mut().for_each(discard_lifetimes) + } + } + } +} diff --git a/instant-xml-macros/src/lib.rs b/instant-xml-macros/src/lib.rs index 6971722..3520d03 100644 --- a/instant-xml-macros/src/lib.rs +++ b/instant-xml-macros/src/lib.rs @@ -168,12 +168,9 @@ pub fn to_xml(input: proc_macro::TokenStream) -> proc_macro::TokenStream { #[proc_macro_derive(FromXml, attributes(xml))] pub fn from_xml(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - let ident = &ast.ident; - let deserializer = de::Deserializer::new(&ast); + proc_macro::TokenStream::from(quote!( - impl<'xml> FromXml<'xml> for #ident { - #deserializer - } + #deserializer )) } diff --git a/instant-xml/src/impls.rs b/instant-xml/src/impls.rs index 433e15a..48cb40d 100644 --- a/instant-xml/src/impls.rs +++ b/instant-xml/src/impls.rs @@ -1,17 +1,25 @@ use std::borrow::Cow; use std::fmt; +use std::marker::PhantomData; use std::str::FromStr; use crate::{ Deserializer, EntityType, Error, FieldAttribute, FromXml, Serializer, TagName, ToXml, Visitor, }; -struct BoolVisitor; +// Deserializer +struct FromStrToVisitor(PhantomData) +where + T: FromStr, + ::Err: std::fmt::Display; -impl<'de> Visitor<'de> for BoolVisitor { - type Value = bool; - - fn visit_str<'a>(self, value: &str) -> Result { +impl<'xml, T> Visitor<'xml> for FromStrToVisitor +where + T: FromStr, + ::Err: std::fmt::Display, +{ + type Value = T; + fn visit_str(self, value: &str) -> Result { match FromStr::from_str(value) { Ok(v) => Ok(v), Err(e) => Err(Error::Other(e.to_string())), @@ -19,12 +27,22 @@ impl<'de> Visitor<'de> for BoolVisitor { } } +struct BoolVisitor; + +impl<'xml> Visitor<'xml> for BoolVisitor { + type Value = bool; + + fn visit_str(self, value: &str) -> Result { + FromStrToVisitor(PhantomData::).visit_str(value) + } +} + impl<'xml> FromXml<'xml> for bool { const TAG_NAME: TagName<'xml> = TagName::FieldName; fn deserialize(deserializer: &mut Deserializer) -> Result { match deserializer.consume_next_type() { - EntityType::Element => deserializer.deserialize_bool(BoolVisitor), + EntityType::Element => deserializer.deserialize_element(BoolVisitor), EntityType::Attribute => deserializer.deserialize_attribute(BoolVisitor), } } @@ -73,6 +91,223 @@ macro_rules! to_xml_for_number { }; } +struct NumberVisitor +where + T: FromStr, + ::Err: std::fmt::Display, +{ + marker: PhantomData, +} + +impl<'xml, T> Visitor<'xml> for NumberVisitor +where + T: FromStr, + ::Err: std::fmt::Display, +{ + type Value = T; + + fn visit_str(self, value: &str) -> Result { + FromStrToVisitor(PhantomData::).visit_str(value) + } +} + +macro_rules! from_xml_for_number { + ($typ:ty) => { + impl<'xml> FromXml<'xml> for $typ { + const TAG_NAME: TagName<'xml> = TagName::FieldName; + + fn deserialize(deserializer: &mut Deserializer) -> Result { + match deserializer.consume_next_type() { + EntityType::Element => deserializer.deserialize_element(NumberVisitor { + marker: PhantomData, + }), + EntityType::Attribute => deserializer.deserialize_attribute(NumberVisitor { + marker: PhantomData, + }), + } + } + } + }; +} + +from_xml_for_number!(i8); +from_xml_for_number!(i16); +from_xml_for_number!(i32); +from_xml_for_number!(i64); +from_xml_for_number!(isize); +from_xml_for_number!(u8); +from_xml_for_number!(u16); +from_xml_for_number!(u32); +from_xml_for_number!(u64); +from_xml_for_number!(usize); +from_xml_for_number!(f32); +from_xml_for_number!(f64); + +struct StringVisitor; + +impl<'xml> Visitor<'xml> for StringVisitor { + type Value = String; + + fn visit_str(self, value: &str) -> Result { + Ok(escape_back(value).into_owned()) + } +} + +impl<'xml> FromXml<'xml> for String { + const TAG_NAME: TagName<'xml> = TagName::FieldName; + + fn deserialize(deserializer: &mut Deserializer) -> Result { + //<&'xml str>::deserialize(deserializer); + match deserializer.consume_next_type() { + EntityType::Element => deserializer.deserialize_element(StringVisitor), + EntityType::Attribute => deserializer.deserialize_attribute(StringVisitor), + } + } +} + +struct CharVisitor; + +impl<'xml> Visitor<'xml> for CharVisitor { + type Value = char; + + fn visit_str(self, value: &str) -> Result { + match value.len() { + 1 => Ok(value.chars().next().expect("char type")), + _ => Err(Error::Other("Expected char type".to_string())), + } + } +} + +impl<'xml> FromXml<'xml> for char { + const TAG_NAME: TagName<'xml> = TagName::FieldName; + + fn deserialize(deserializer: &mut Deserializer) -> Result { + match deserializer.consume_next_type() { + EntityType::Element => deserializer.deserialize_element(CharVisitor), + EntityType::Attribute => deserializer.deserialize_attribute(CharVisitor), + } + } +} + +struct StrVisitor; + +impl<'a> Visitor<'a> for StrVisitor { + type Value = &'a str; + + fn visit_str(self, value: &'a str) -> Result { + match escape_back(value) { + Cow::Owned(v) => Err(Error::Other(format!("Unsupported char: {}", v))), + Cow::Borrowed(v) => Ok(v), + } + } +} + +impl<'xml> FromXml<'xml> for &'xml str { + const TAG_NAME: TagName<'xml> = TagName::FieldName; + + fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result { + match deserializer.consume_next_type() { + EntityType::Element => deserializer.deserialize_element(StrVisitor), + EntityType::Attribute => deserializer.deserialize_attribute(StrVisitor), + } + } +} + +struct CowStrVisitor; + +impl<'a> Visitor<'a> for CowStrVisitor { + type Value = Cow<'a, str>; + + fn visit_str(self, value: &'a str) -> Result { + Ok(escape_back(value)) + } +} + +impl<'xml> FromXml<'xml> for Cow<'xml, str> { + const TAG_NAME: TagName<'xml> = <&str>::TAG_NAME; + + fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result { + match deserializer.consume_next_type() { + EntityType::Element => deserializer.deserialize_element(CowStrVisitor), + EntityType::Attribute => deserializer.deserialize_attribute(CowStrVisitor), + } + } +} + +impl<'xml, T> FromXml<'xml> for Option +where + T: FromXml<'xml>, +{ + const TAG_NAME: TagName<'xml> = ::TAG_NAME; + + fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result { + match ::deserialize(deserializer) { + Ok(v) => Ok(Some(v)), + Err(e) => Err(e), + } + } + + fn missing_value() -> Result { + Ok(None) + } +} + +fn escape_back(input: &str) -> Cow<'_, str> { + let mut result = String::with_capacity(input.len()); + let input_len = input.len(); + + let mut last_end = 0; + while input_len - last_end >= 4 { + match &input[last_end..(last_end + 4)] { + "<" => { + result.push('<'); + last_end += 4; + continue; + } + ">" => { + result.push('>'); + last_end += 4; + continue; + } + _ => (), + }; + + if input_len - last_end >= 5 { + if &input[last_end..(last_end + 5)] == "&" { + result.push('&'); + last_end += 5; + continue; + } + + if input_len - last_end >= 6 { + match &input[last_end..(last_end + 6)] { + "'" => { + result.push('\''); + last_end += 6; + continue; + } + """ => { + result.push('"'); + last_end += 6; + continue; + } + _ => (), + }; + } + } + + result.push_str(input.get(last_end..last_end + 1).unwrap()); + last_end += 1; + } + + result.push_str(input.get(last_end..).unwrap()); + if result.len() == input.len() { + return Cow::Borrowed(input); + } + + Cow::Owned(result) +} + to_xml_for_number!(i8); to_xml_for_number!(i16); to_xml_for_number!(i32); diff --git a/instant-xml/src/lib.rs b/instant-xml/src/lib.rs index 83b2cc3..26df04b 100644 --- a/instant-xml/src/lib.rs +++ b/instant-xml/src/lib.rs @@ -198,22 +198,28 @@ pub enum TagName<'xml> { pub trait FromXml<'xml>: Sized { const TAG_NAME: TagName<'xml>; - fn from_xml(input: &str) -> Result { + fn from_xml(input: &'xml str) -> Result { let mut deserializer = Deserializer::new(input); Self::deserialize(&mut deserializer) } - fn deserialize(deserializer: &mut Deserializer) -> Result; + fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result; + + // If the missing field is of type `Option` then treat is as `None`, + // otherwise it is an error. + fn missing_value() -> Result { + Err(Error::MissingValue) + } } pub trait Visitor<'xml>: Sized { type Value; - fn visit_str(self, _value: &str) -> Result { + fn visit_str(self, _value: &'xml str) -> Result { unimplemented!(); } - fn visit_struct<'a>(&self, _deserializer: &'a mut Deserializer) -> Result { + fn visit_struct(&self, _deserializer: &mut Deserializer<'xml>) -> Result { unimplemented!(); } } @@ -373,7 +379,7 @@ impl<'xml> Deserializer<'xml> { ret } - fn deserialize_bool(&mut self, visitor: V) -> Result + fn deserialize_element(&mut self, visitor: V) -> Result where V: Visitor<'xml>, { @@ -422,11 +428,6 @@ impl<'xml> Deserializer<'xml> { pub trait FromXmlOwned: for<'xml> FromXml<'xml> {} -#[allow(dead_code)] -struct State<'a> { - prefix: HashMap<&'a str, &'a str>, -} - #[derive(Debug, Error, PartialEq, Eq)] pub enum Error { #[error("format: {0}")] diff --git a/instant-xml/tests/all.rs b/instant-xml/tests/all.rs index fb35ced..bd72de9 100644 --- a/instant-xml/tests/all.rs +++ b/instant-xml/tests/all.rs @@ -2,8 +2,6 @@ use std::borrow::Cow; use instant_xml::{Error, FromXml, ToXml}; -//TODO: Add compile time errors check? - #[derive(Debug, Eq, PartialEq, ToXml)] struct Unit; @@ -381,7 +379,14 @@ fn direct_namespaces() { ); } -#[derive(Debug, PartialEq, ToXml)] +#[derive(Debug, PartialEq, Eq, FromXml, ToXml)] +#[xml(namespace("URI"))] +struct NestedLifetimes<'a> { + flag: bool, + str_type_a: &'a str, +} + +#[derive(Debug, PartialEq, FromXml, ToXml)] #[xml(namespace("URI"))] struct StructDeserializerScalars<'a, 'b> { bool_type: bool, @@ -392,14 +397,18 @@ struct StructDeserializerScalars<'a, 'b> { str_type_b: &'b str, char_type: char, f32_type: f32, + nested: NestedLifetimes<'a>, cow: Cow<'a, str>, option: Option<&'a str>, } #[test] fn scalars() { - // Option some assert_eq!( + StructDeserializerScalars::from_xml( + "true142stringlifetime alifetime bc1.20trueasd123" + ) + .unwrap(), StructDeserializerScalars{ bool_type: true, i8_type: 1, @@ -409,16 +418,20 @@ fn scalars() { str_type_b: "lifetime b", char_type: 'c', f32_type: 1.20, + nested: NestedLifetimes { + flag: true, + str_type_a: "asd" + }, cow: Cow::from("123"), - option: Some("asd"), + option: None, } - .to_xml() - .unwrap(), - "true142stringlifetime alifetime bc1.2123" ); // Option none assert_eq!( + StructDeserializerScalars::from_xml( + "true142stringlifetime alifetime bc1.2trueasd123" + ).unwrap(), StructDeserializerScalars{ bool_type: true, i8_type: 1, @@ -428,33 +441,78 @@ fn scalars() { str_type_b: "lifetime b", char_type: 'c', f32_type: 1.20, + nested: NestedLifetimes { + flag: true, + str_type_a: "asd" + }, cow: Cow::from("123"), - option: None, + option: Some("asd"), } - .to_xml() - .unwrap(), - "true142stringlifetime alifetime bc1.2123" ); } -#[derive(Debug, PartialEq, Eq, ToXml)] +#[derive(Debug, PartialEq, Eq, FromXml, ToXml)] #[xml(namespace("URI"))] struct StructSpecialEntities<'a> { - string_type: String, - str_type_a: &'a str, + string: String, + str: &'a str, cow: Cow<'a, str>, } +#[test] +fn escape_back() { + assert_eq!( + StructSpecialEntities::from_xml( + "<>&"'adsad"strstr&" + ) + .unwrap(), + StructSpecialEntities { + string: String::from("<>&\"'adsad\""), + str: "str", + cow: Cow::Owned("str&".to_string()), + } + ); + + // Wrong str char + assert_eq!( + StructSpecialEntities::from_xml( + "<>&"'adsad"str&" + ) + .unwrap_err(), + Error::Other("Unsupported char: str&".to_string()) + ); + + // Borrowed + let escape_back = StructSpecialEntities::from_xml( + "<>&"'adsad"strstr" + ) + .unwrap(); + + if let Cow::Owned(_) = escape_back.cow { + panic!("Should be Borrowed") + } + + // Owned + let escape_back = StructSpecialEntities::from_xml( + "<>&"'adsad"strstr&" + ) + .unwrap(); + + if let Cow::Borrowed(_) = escape_back.cow { + panic!("Should be Owned") + } +} + #[test] fn special_entities() { assert_eq!( StructSpecialEntities{ - string_type: "&\"<>\'aa".to_string(), - str_type_a: "&\"<>\'bb", + string: "&\"<>\'aa".to_string(), + str: "&\"<>\'bb", cow: Cow::from("&\"<>\'cc"), } .to_xml() .unwrap(), - "&"<>'aa&"<>'bb&"<>'cc" + "&"<>'aa&"<>'bb&"<>'cc" ); }