Deserialize scalars (#14)

Co-authored-by: Dirkjan Ochtman <dirkjan@ochtman.nl>
This commit is contained in:
choinskib 2022-09-01 13:28:40 +02:00 committed by GitHub
parent c553b22310
commit ebd913f603
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 424 additions and 56 deletions

View File

@ -1,5 +1,5 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use quote::{quote, ToTokens};
use crate::{get_namespaces, retrieve_field_attribute, FieldAttribute};
@ -34,6 +34,25 @@ impl quote::ToTokens for Deserializer {
impl Deserializer {
pub fn new(input: &syn::DeriveInput) -> Deserializer {
let ident = &input.ident;
let generics = (&input.generics).into_token_stream();
let lifetimes = (&input.generics.params).into_token_stream();
let mut lifetime_xml = TokenStream::new();
let mut lifetime_visitor = TokenStream::new();
let iter = &mut input.generics.params.iter();
if let Some(it) = iter.next() {
lifetime_xml = quote!(:);
lifetime_xml.extend(it.into_token_stream());
while let Some(it) = iter.by_ref().next() {
lifetime_xml.extend(quote!(+));
lifetime_xml.extend(it.into_token_stream());
}
lifetime_xml.extend(quote!(,));
lifetime_xml.extend(lifetimes.clone());
lifetime_visitor.extend(quote!(,));
lifetime_visitor.extend(lifetimes);
}
let name = ident.to_string();
let mut out = TokenStream::new();
@ -113,7 +132,7 @@ impl Deserializer {
let attr_type_match = attributes_tokens.match_;
out.extend(quote!(
fn deserialize(deserializer: &mut ::instant_xml::Deserializer) -> Result<Self, ::instant_xml::Error> {
fn deserialize(deserializer: &mut ::instant_xml::Deserializer<'xml>) -> Result<Self, ::instant_xml::Error> {
use ::instant_xml::parse::XmlRecord;
use ::instant_xml::{Error, Deserializer, Visitor} ;
@ -143,12 +162,18 @@ impl Deserializer {
}
}
struct StructVisitor;
impl<'xml> Visitor<'xml> for StructVisitor {
type Value = #ident;
struct StructVisitor<'xml #lifetime_xml> {
marker: std::marker::PhantomData<#ident #generics>,
lifetime: std::marker::PhantomData<&'xml ()>,
}
fn visit_struct<'a>(&self, deserializer: &mut ::instant_xml::Deserializer) -> Result<Self::Value, ::instant_xml::Error>
{
impl<'xml #lifetime_xml> Visitor<'xml> for StructVisitor<'xml #lifetime_visitor> {
type Value = #ident #generics;
fn visit_struct(
&self,
deserializer: &mut ::instant_xml::Deserializer<'xml>
) -> Result<Self::Value, ::instant_xml::Error> {
#declare_values
while let Some(( key, _ )) = deserializer.peek_next_attribute() {
match get_attribute(&key) {
@ -181,7 +206,15 @@ impl Deserializer {
}
#namespaces_map;
deserializer.deserialize_struct(StructVisitor{}, #name, #default_namespace, &namespaces_map)
deserializer.deserialize_struct(
StructVisitor{
marker: std::marker::PhantomData,
lifetime: std::marker::PhantomData
},
#name,
#default_namespace,
&namespaces_map
)
}
));
@ -189,6 +222,12 @@ impl Deserializer {
const TAG_NAME: ::instant_xml::TagName<'xml> = ::instant_xml::TagName::Custom(#name);
));
out = quote!(
impl<'xml #lifetime_xml> FromXml<'xml> for #ident #generics {
#out
}
);
Deserializer { out }
}
@ -206,16 +245,14 @@ impl Deserializer {
let field_var = field.ident.as_ref().unwrap();
let field_var_str = field_var.to_string();
let const_field_var_str = Ident::new(&field_var_str.to_uppercase(), Span::call_site());
let field_type = match &field.ty {
syn::Type::Path(v) => v.path.get_ident(),
_ => panic!("Wrong field attribute format"),
};
let mut no_lifetime_type = field.ty.clone();
discard_lifetimes(&mut no_lifetime_type);
let enum_name = Ident::new(&format!("__Value{index}"), Span::call_site());
tokens.enum_.extend(quote!(#enum_name,));
tokens.consts.extend(quote!(
const #const_field_var_str: &str = match #field_type::TAG_NAME {
const #const_field_var_str: &str = match <#no_lifetime_type>::TAG_NAME {
::instant_xml::TagName::FieldName => #field_var_str,
::instant_xml::TagName::Custom(v) => v,
};
@ -232,7 +269,7 @@ impl Deserializer {
}
declare_values.extend(quote!(
let mut #enum_name: Option<#field_type> = None;
let mut #enum_name: Option<#no_lifetime_type> = None;
));
let def_prefix = match def_prefix {
@ -283,7 +320,7 @@ impl Deserializer {
}
#field_namespace
deserializer.set_next_def_namespace(field_namespace)?;
#enum_name = Some(#field_type::deserialize(deserializer)?);
#enum_name = Some(<#no_lifetime_type>::deserialize(deserializer)?);
},
));
} else {
@ -294,13 +331,53 @@ impl Deserializer {
}
deserializer.set_next_type_as_attribute()?;
#enum_name = Some(#field_type::deserialize(deserializer)?);
#enum_name = Some(<#no_lifetime_type>::deserialize(deserializer)?);
},
));
}
return_val.extend(quote!(
#field_var: #enum_name.expect("Expected some value"),
#field_var: match #enum_name {
Some(v) => v,
None => <#no_lifetime_type>::missing_value()?,
},
));
}
}
/// Strips explicit lifetimes from `ty` in place.
///
/// A reference type loses its lifetime annotation and its referent is
/// processed recursively; a path type is handed to `discard_path_lifetimes`.
/// Every other kind of type is left untouched.
fn discard_lifetimes(ty: &mut syn::Type) {
    if let syn::Type::Reference(reference) = ty {
        reference.lifetime = None;
        discard_lifetimes(&mut reference.elem);
    } else if let syn::Type::Path(type_path) = ty {
        discard_path_lifetimes(type_path);
    }
}
/// Rewrites the lifetimes mentioned inside a path type, in place.
///
/// * The qualified-self type (`<T as Trait>::…`), when present, is processed
///   with `discard_lifetimes`.
/// * In angle-bracketed arguments every lifetime argument is replaced by `'_`
///   and every type argument is processed recursively; bindings, constraints
///   and const arguments are left alone.
/// * In parenthesized arguments each input type is processed recursively.
fn discard_path_lifetimes(path: &mut syn::TypePath) {
    if let Some(qualified_self) = path.qself.as_mut() {
        discard_lifetimes(&mut qualified_self.ty);
    }

    for segment in path.path.segments.iter_mut() {
        match &mut segment.arguments {
            syn::PathArguments::AngleBracketed(generic) => {
                for argument in generic.args.iter_mut() {
                    match argument {
                        syn::GenericArgument::Type(inner) => discard_lifetimes(inner),
                        syn::GenericArgument::Lifetime(lifetime) => {
                            *lifetime = syn::Lifetime::new("'_", Span::call_site());
                        }
                        syn::GenericArgument::Binding(_)
                        | syn::GenericArgument::Constraint(_)
                        | syn::GenericArgument::Const(_) => {}
                    }
                }
            }
            syn::PathArguments::Parenthesized(call) => {
                for input in call.inputs.iter_mut() {
                    discard_lifetimes(input);
                }
            }
            syn::PathArguments::None => {}
        }
    }
}

View File

@ -168,12 +168,9 @@ pub fn to_xml(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
#[proc_macro_derive(FromXml, attributes(xml))]
pub fn from_xml(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
let ident = &ast.ident;
let deserializer = de::Deserializer::new(&ast);
proc_macro::TokenStream::from(quote!(
impl<'xml> FromXml<'xml> for #ident {
#deserializer
}
))
}

View File

@ -1,17 +1,25 @@
use std::borrow::Cow;
use std::fmt;
use std::marker::PhantomData;
use std::str::FromStr;
use crate::{
Deserializer, EntityType, Error, FieldAttribute, FromXml, Serializer, TagName, ToXml, Visitor,
};
struct BoolVisitor;
// Deserializer
/// Zero-sized visitor that turns text into a `T` through `T`'s [`FromStr`]
/// implementation.
///
/// The parse error must implement `Display` so that a failure can be rendered
/// as text when it is reported.
struct FromStrToVisitor<T>(PhantomData<T>)
where
    T: FromStr,
    <T as FromStr>::Err: std::fmt::Display;
impl<'de> Visitor<'de> for BoolVisitor {
type Value = bool;
fn visit_str<'a>(self, value: &str) -> Result<Self::Value, Error> {
impl<'xml, T> Visitor<'xml> for FromStrToVisitor<T>
where
T: FromStr,
<T as FromStr>::Err: std::fmt::Display,
{
type Value = T;
fn visit_str(self, value: &str) -> Result<Self::Value, Error> {
match FromStr::from_str(value) {
Ok(v) => Ok(v),
Err(e) => Err(Error::Other(e.to_string())),
@ -19,12 +27,22 @@ impl<'de> Visitor<'de> for BoolVisitor {
}
}
struct BoolVisitor;
impl<'xml> Visitor<'xml> for BoolVisitor {
type Value = bool;
fn visit_str(self, value: &str) -> Result<Self::Value, Error> {
FromStrToVisitor(PhantomData::<Self::Value>).visit_str(value)
}
}
impl<'xml> FromXml<'xml> for bool {
const TAG_NAME: TagName<'xml> = TagName::FieldName;
fn deserialize(deserializer: &mut Deserializer) -> Result<Self, Error> {
match deserializer.consume_next_type() {
EntityType::Element => deserializer.deserialize_bool(BoolVisitor),
EntityType::Element => deserializer.deserialize_element(BoolVisitor),
EntityType::Attribute => deserializer.deserialize_attribute(BoolVisitor),
}
}
@ -73,6 +91,223 @@ macro_rules! to_xml_for_number {
};
}
struct NumberVisitor<T>
where
T: FromStr,
<T as FromStr>::Err: std::fmt::Display,
{
marker: PhantomData<T>,
}
impl<'xml, T> Visitor<'xml> for NumberVisitor<T>
where
T: FromStr,
<T as FromStr>::Err: std::fmt::Display,
{
type Value = T;
fn visit_str(self, value: &str) -> Result<Self::Value, Error> {
FromStrToVisitor(PhantomData::<Self::Value>).visit_str(value)
}
}
/// Implements `FromXml` for `$typ` by routing both element and attribute input
/// through a `NumberVisitor<$typ>`.
macro_rules! from_xml_for_number {
    ($typ:ty) => {
        impl<'xml> FromXml<'xml> for $typ {
            const TAG_NAME: TagName<'xml> = TagName::FieldName;

            fn deserialize(deserializer: &mut Deserializer) -> Result<Self, Error> {
                // The visitor is a zero-sized marker, so it can be built up front
                // and handed to whichever branch applies.
                let visitor = NumberVisitor::<$typ> {
                    marker: PhantomData,
                };
                match deserializer.consume_next_type() {
                    EntityType::Attribute => deserializer.deserialize_attribute(visitor),
                    EntityType::Element => deserializer.deserialize_element(visitor),
                }
            }
        }
    };
}
from_xml_for_number!(i8);
from_xml_for_number!(i16);
from_xml_for_number!(i32);
from_xml_for_number!(i64);
from_xml_for_number!(isize);
from_xml_for_number!(u8);
from_xml_for_number!(u16);
from_xml_for_number!(u32);
from_xml_for_number!(u64);
from_xml_for_number!(usize);
from_xml_for_number!(f32);
from_xml_for_number!(f64);
/// Visitor yielding an owned `String`, with the text passed through
/// `escape_back` first.
struct StringVisitor;

impl<'xml> Visitor<'xml> for StringVisitor {
    type Value = String;

    fn visit_str(self, value: &str) -> Result<Self::Value, Error> {
        let unescaped = escape_back(value);
        Ok(unescaped.into_owned())
    }
}
impl<'xml> FromXml<'xml> for String {
    const TAG_NAME: TagName<'xml> = TagName::FieldName;

    /// Reads the next entity as a `String`, dispatching on whether the
    /// deserializer reports it as an element or an attribute.
    fn deserialize(deserializer: &mut Deserializer) -> Result<Self, Error> {
        let entity = deserializer.consume_next_type();
        match entity {
            EntityType::Attribute => deserializer.deserialize_attribute(StringVisitor),
            EntityType::Element => deserializer.deserialize_element(StringVisitor),
        }
    }
}
/// Visitor producing a single `char`.
struct CharVisitor;

impl<'xml> Visitor<'xml> for CharVisitor {
    type Value = char;

    /// Accepts text that consists of exactly one character.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other("Expected char type")` when `value` is empty or
    /// contains more than one character.
    fn visit_str(self, value: &str) -> Result<Self::Value, Error> {
        // Count characters, not bytes: `str::len` is the UTF-8 byte length and
        // would reject every non-ASCII character such as 'é' or '€'.
        let mut chars = value.chars();
        match (chars.next(), chars.next()) {
            (Some(only), None) => Ok(only),
            _ => Err(Error::Other("Expected char type".to_string())),
        }
    }
}
impl<'xml> FromXml<'xml> for char {
    const TAG_NAME: TagName<'xml> = TagName::FieldName;

    /// Reads the next entity as a `char`, dispatching on whether the
    /// deserializer reports it as an element or an attribute.
    fn deserialize(deserializer: &mut Deserializer) -> Result<Self, Error> {
        let entity = deserializer.consume_next_type();
        match entity {
            EntityType::Attribute => deserializer.deserialize_attribute(CharVisitor),
            EntityType::Element => deserializer.deserialize_element(CharVisitor),
        }
    }
}
struct StrVisitor;
impl<'a> Visitor<'a> for StrVisitor {
type Value = &'a str;
fn visit_str(self, value: &'a str) -> Result<Self::Value, Error> {
match escape_back(value) {
Cow::Owned(v) => Err(Error::Other(format!("Unsupported char: {}", v))),
Cow::Borrowed(v) => Ok(v),
}
}
}
impl<'xml> FromXml<'xml> for &'xml str {
    const TAG_NAME: TagName<'xml> = TagName::FieldName;

    /// Reads the next entity as a string slice borrowed from the input,
    /// dispatching on whether it is an element or an attribute.
    fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result<Self, Error> {
        let entity = deserializer.consume_next_type();
        match entity {
            EntityType::Attribute => deserializer.deserialize_attribute(StrVisitor),
            EntityType::Element => deserializer.deserialize_element(StrVisitor),
        }
    }
}
/// Visitor yielding a `Cow<str>`: whatever `escape_back` returns for the text,
/// borrowed or owned.
struct CowStrVisitor;

impl<'a> Visitor<'a> for CowStrVisitor {
    type Value = Cow<'a, str>;

    fn visit_str(self, value: &'a str) -> Result<Self::Value, Error> {
        let text = escape_back(value);
        Ok(text)
    }
}
impl<'xml> FromXml<'xml> for Cow<'xml, str> {
    // Shares the tag-name rule of `&str`.
    const TAG_NAME: TagName<'xml> = <&str>::TAG_NAME;

    /// Reads the next entity as a `Cow<str>`, dispatching on whether the
    /// deserializer reports it as an element or an attribute.
    fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result<Self, Error> {
        let entity = deserializer.consume_next_type();
        match entity {
            EntityType::Attribute => deserializer.deserialize_attribute(CowStrVisitor),
            EntityType::Element => deserializer.deserialize_element(CowStrVisitor),
        }
    }
}
impl<'xml, T> FromXml<'xml> for Option<T>
where
    T: FromXml<'xml>,
{
    // An optional field uses the same tag name as the wrapped type.
    const TAG_NAME: TagName<'xml> = T::TAG_NAME;

    /// Deserializes the wrapped type and wraps a success in `Some`; errors
    /// from `T` are passed through unchanged.
    fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result<Self, Error> {
        T::deserialize(deserializer).map(Some)
    }

    /// A missing optional value is not an error: it is `None`.
    fn missing_value() -> Result<Self, Error> {
        Ok(None)
    }
}
/// Decodes the five predefined XML entities (`&lt;`, `&gt;`, `&amp;`,
/// `&apos;`, `&quot;`) in `input`.
///
/// Returns `Cow::Borrowed(input)` when nothing had to be decoded, so the
/// common case performs no allocation; otherwise returns the decoded text as
/// `Cow::Owned`. An `&` that does not start one of the five entities is copied
/// through verbatim. Decoding is single-pass: the output of a replacement is
/// never re-scanned, so `&amp;lt;` becomes `&lt;`.
fn escape_back(input: &str) -> Cow<'_, str> {
    const ENTITIES: [(&str, char); 5] = [
        ("&lt;", '<'),
        ("&gt;", '>'),
        ("&amp;", '&'),
        ("&apos;", '\''),
        ("&quot;", '"'),
    ];

    // Allocated lazily, on the first entity actually found.
    let mut result: Option<String> = None;
    // Everything in `input[..copied]` has already been emitted into `result`.
    let mut copied = 0;
    // Position from which the next `&` is searched.
    let mut search = 0;

    // Searching for `&` (a single ASCII byte) guarantees that every slice
    // boundary used below is a valid UTF-8 char boundary, so multi-byte text
    // around the entities can never cause a slicing panic.
    while let Some(offset) = input[search..].find('&') {
        let start = search + offset;
        let tail = &input[start..];

        match ENTITIES.iter().find(|(entity, _)| tail.starts_with(*entity)) {
            Some(&(entity, decoded)) => {
                let out = result.get_or_insert_with(|| String::with_capacity(input.len()));
                out.push_str(&input[copied..start]);
                out.push(decoded);
                search = start + entity.len();
                copied = search;
            }
            // Not a known entity: keep the `&` as-is and continue after it.
            None => search = start + 1,
        }
    }

    match result {
        Some(mut out) => {
            out.push_str(&input[copied..]);
            Cow::Owned(out)
        }
        None => Cow::Borrowed(input),
    }
}
to_xml_for_number!(i8);
to_xml_for_number!(i16);
to_xml_for_number!(i32);

View File

@ -198,22 +198,28 @@ pub enum TagName<'xml> {
pub trait FromXml<'xml>: Sized {
const TAG_NAME: TagName<'xml>;
fn from_xml(input: &str) -> Result<Self, Error> {
fn from_xml(input: &'xml str) -> Result<Self, Error> {
let mut deserializer = Deserializer::new(input);
Self::deserialize(&mut deserializer)
}
fn deserialize(deserializer: &mut Deserializer) -> Result<Self, Error>;
fn deserialize(deserializer: &mut Deserializer<'xml>) -> Result<Self, Error>;
// If the missing field is of type `Option<T>` then treat it as `None`,
// otherwise it is an error.
fn missing_value() -> Result<Self, Error> {
Err(Error::MissingValue)
}
}
pub trait Visitor<'xml>: Sized {
type Value;
fn visit_str(self, _value: &str) -> Result<Self::Value, Error> {
fn visit_str(self, _value: &'xml str) -> Result<Self::Value, Error> {
unimplemented!();
}
fn visit_struct<'a>(&self, _deserializer: &'a mut Deserializer) -> Result<Self::Value, Error> {
fn visit_struct(&self, _deserializer: &mut Deserializer<'xml>) -> Result<Self::Value, Error> {
unimplemented!();
}
}
@ -373,7 +379,7 @@ impl<'xml> Deserializer<'xml> {
ret
}
fn deserialize_bool<V>(&mut self, visitor: V) -> Result<V::Value, Error>
fn deserialize_element<V>(&mut self, visitor: V) -> Result<V::Value, Error>
where
V: Visitor<'xml>,
{
@ -422,11 +428,6 @@ impl<'xml> Deserializer<'xml> {
pub trait FromXmlOwned: for<'xml> FromXml<'xml> {}
#[allow(dead_code)]
struct State<'a> {
prefix: HashMap<&'a str, &'a str>,
}
#[derive(Debug, Error, PartialEq, Eq)]
pub enum Error {
#[error("format: {0}")]

View File

@ -2,8 +2,6 @@ use std::borrow::Cow;
use instant_xml::{Error, FromXml, ToXml};
//TODO: Add compile time errors check?
#[derive(Debug, Eq, PartialEq, ToXml)]
struct Unit;
@ -381,7 +379,14 @@ fn direct_namespaces() {
);
}
#[derive(Debug, PartialEq, ToXml)]
#[derive(Debug, PartialEq, Eq, FromXml, ToXml)]
#[xml(namespace("URI"))]
struct NestedLifetimes<'a> {
flag: bool,
str_type_a: &'a str,
}
#[derive(Debug, PartialEq, FromXml, ToXml)]
#[xml(namespace("URI"))]
struct StructDeserializerScalars<'a, 'b> {
bool_type: bool,
@ -392,14 +397,18 @@ struct StructDeserializerScalars<'a, 'b> {
str_type_b: &'b str,
char_type: char,
f32_type: f32,
nested: NestedLifetimes<'a>,
cow: Cow<'a, str>,
option: Option<&'a str>,
}
#[test]
fn scalars() {
// Option some
assert_eq!(
StructDeserializerScalars::from_xml(
"<StructDeserializerScalars xmlns=\"URI\"><bool_type>true</bool_type><i8_type>1</i8_type><u32_type>42</u32_type><string_type>string</string_type><str_type_a>lifetime a</str_type_a><str_type_b>lifetime b</str_type_b><char_type>c</char_type><f32_type>1.20</f32_type><NestedLifetimes><flag>true</flag><str_type_a>asd</str_type_a></NestedLifetimes><cow>123</cow></StructDeserializerScalars>"
)
.unwrap(),
StructDeserializerScalars{
bool_type: true,
i8_type: 1,
@ -409,16 +418,20 @@ fn scalars() {
str_type_b: "lifetime b",
char_type: 'c',
f32_type: 1.20,
nested: NestedLifetimes {
flag: true,
str_type_a: "asd"
},
cow: Cow::from("123"),
option: Some("asd"),
option: None,
}
.to_xml()
.unwrap(),
"<StructDeserializerScalars xmlns=\"URI\"><bool_type>true</bool_type><i8_type>1</i8_type><u32_type>42</u32_type><string_type>string</string_type><str_type_a>lifetime a</str_type_a><str_type_b>lifetime b</str_type_b><char_type>c</char_type><f32_type>1.2</f32_type><cow>123</cow><option>asd</option></StructDeserializerScalars>"
);
// Option none
assert_eq!(
StructDeserializerScalars::from_xml(
"<StructDeserializerScalars xmlns=\"URI\"><bool_type>true</bool_type><i8_type>1</i8_type><u32_type>42</u32_type><string_type>string</string_type><str_type_a>lifetime a</str_type_a><str_type_b>lifetime b</str_type_b><char_type>c</char_type><f32_type>1.2</f32_type><NestedLifetimes><flag>true</flag><str_type_a>asd</str_type_a></NestedLifetimes><cow>123</cow><option>asd</option></StructDeserializerScalars>"
).unwrap(),
StructDeserializerScalars{
bool_type: true,
i8_type: 1,
@ -428,33 +441,78 @@ fn scalars() {
str_type_b: "lifetime b",
char_type: 'c',
f32_type: 1.20,
nested: NestedLifetimes {
flag: true,
str_type_a: "asd"
},
cow: Cow::from("123"),
option: None,
option: Some("asd"),
}
.to_xml()
.unwrap(),
"<StructDeserializerScalars xmlns=\"URI\"><bool_type>true</bool_type><i8_type>1</i8_type><u32_type>42</u32_type><string_type>string</string_type><str_type_a>lifetime a</str_type_a><str_type_b>lifetime b</str_type_b><char_type>c</char_type><f32_type>1.2</f32_type><cow>123</cow></StructDeserializerScalars>"
);
}
#[derive(Debug, PartialEq, Eq, ToXml)]
#[derive(Debug, PartialEq, Eq, FromXml, ToXml)]
#[xml(namespace("URI"))]
struct StructSpecialEntities<'a> {
string_type: String,
str_type_a: &'a str,
string: String,
str: &'a str,
cow: Cow<'a, str>,
}
/// Exercises entity decoding during deserialization for `String`, `&str` and
/// `Cow<str>` fields.
#[test]
fn escape_back() {
    // Every predefined entity is decoded for `String`; the `Cow` field is decoded too.
    let decoded = StructSpecialEntities::from_xml(
        "<StructSpecialEntities xmlns=\"URI\"><string>&lt;&gt;&amp;&quot;&apos;adsad&quot;</string><str>str</str><cow>str&amp;</cow></StructSpecialEntities>"
    )
    .unwrap();
    let expected = StructSpecialEntities {
        string: String::from("<>&\"'adsad\""),
        str: "str",
        cow: Cow::Owned("str&".to_string()),
    };
    assert_eq!(decoded, expected);

    // Wrong str char
    let failure = StructSpecialEntities::from_xml(
        "<StructSpecialEntities xmlns=\"URI\"><string>&lt;&gt;&amp;&quot;&apos;adsad&quot;</string><str>str&amp;</str></StructSpecialEntities>"
    )
    .unwrap_err();
    assert_eq!(failure, Error::Other("Unsupported char: str&".to_string()));

    // Borrowed
    let untouched = StructSpecialEntities::from_xml(
        "<StructSpecialEntities xmlns=\"URI\"><string>&lt;&gt;&amp;&quot;&apos;adsad&quot;</string><str>str</str><cow>str</cow></StructSpecialEntities>"
    )
    .unwrap();
    assert!(!matches!(untouched.cow, Cow::Owned(_)), "Should be Borrowed");

    // Owned
    let rewritten = StructSpecialEntities::from_xml(
        "<StructSpecialEntities xmlns=\"URI\"><string>&lt;&gt;&amp;&quot;&apos;adsad&quot;</string><str>str</str><cow>str&amp;</cow></StructSpecialEntities>"
    )
    .unwrap();
    assert!(!matches!(rewritten.cow, Cow::Borrowed(_)), "Should be Owned");
}
#[test]
fn special_entities() {
assert_eq!(
StructSpecialEntities{
string_type: "&\"<>\'aa".to_string(),
str_type_a: "&\"<>\'bb",
string: "&\"<>\'aa".to_string(),
str: "&\"<>\'bb",
cow: Cow::from("&\"<>\'cc"),
}
.to_xml()
.unwrap(),
"<StructSpecialEntities xmlns=\"URI\"><string_type>&amp;&quot;&lt;&gt;&apos;aa</string_type><str_type_a>&amp;&quot;&lt;&gt;&apos;bb</str_type_a><cow>&amp;&quot;&lt;&gt;&apos;cc</cow></StructSpecialEntities>"
"<StructSpecialEntities xmlns=\"URI\"><string>&amp;&quot;&lt;&gt;&apos;aa</string><str>&amp;&quot;&lt;&gt;&apos;bb</str><cow>&amp;&quot;&lt;&gt;&apos;cc</cow></StructSpecialEntities>"
);
}