diff --git a/src/parse/response.rs b/src/parse/response.rs index 9a255f0..00a5a26 100644 --- a/src/parse/response.rs +++ b/src/parse/response.rs @@ -12,7 +12,7 @@ use nom::{ use crate::{ parse::{address::address_literal, number, Domain}, - types::{AuthMechanism, Capability, ReplyCode, Response}, + types::{AuthMechanism, Capability, ReplyCode, Response, TextString}, }; /// Greeting = ( "220 " (Domain / address-literal) [ SP textstring ] CRLF ) / @@ -79,14 +79,15 @@ pub fn Greeting(input: &[u8]) -> IResult<&[u8], Response> { /// HT, SP, Printable US-ASCII /// /// textstring = 1*(%d09 / %d32-126) -pub fn textstring(input: &[u8]) -> IResult<&[u8], &str> { - fn is_value(byte: u8) -> bool { - matches!(byte, 9 | 32..=126) - } +pub fn textstring(input: &[u8]) -> IResult<&[u8], TextString<'_>> { + let (remaining, parsed) = + map_res(take_while1(is_text_string_byte), std::str::from_utf8)(input)?; - let (remaining, parsed) = map_res(take_while1(is_value), std::str::from_utf8)(input)?; + Ok((remaining, TextString(parsed.into()))) +} - Ok((remaining, parsed)) +pub(crate) fn is_text_string_byte(byte: u8) -> bool { + matches!(byte, 9 | 32..=126) } // ------------------------------------------------------------------------------------------------- @@ -106,12 +107,12 @@ pub fn Reply_lines(input: &[u8]) -> IResult<&[u8], Response> { Vec::with_capacity(intermediate.len() + if text.is_some() { 1 } else { 0 }); for (_, _, text, _) in intermediate { if let Some(line) = text { - lines.push(line.to_owned()); + lines.push(line.into_owned()); } } if let Some((_, line)) = text { - lines.push(line.to_owned()); + lines.push(line.into_owned()); } Response::Other { code, lines } diff --git a/src/types.rs b/src/types.rs index 6dd7a70..ca685ee 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,9 +1,9 @@ -use std::io::Write; +use std::{borrow::Cow, fmt, io::Write, ops::Deref}; #[cfg(feature = "serdex")] use serde::{Deserialize, Serialize}; -use crate::utils::escape_quoted; +use crate::{parse::response::is_text_string_byte, utils::escape_quoted}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum Command { @@ -302,7 +302,7 @@ pub enum Response { }, Other { code: ReplyCode, - text: String, + lines: Vec>, }, } @@ -330,13 +330,13 @@ impl Response { } } - pub fn other(code: ReplyCode, text: T) -> Response + pub fn other(code: ReplyCode, text: TextString<'static>) -> Response where T: Into, { Response::Other { code, - text: text.into(), + lines: vec![text], } } @@ -387,19 +387,16 @@ impl Response { writer.write_all(format!("250 {}{}\r\n", domain, greet).as_bytes())?; } } - Response::Other { code, text } => { + Response::Other { code, lines } => { let code = u16::from(*code); - let lines = text.lines().collect::>(); - - if let Some((last, head)) = lines.split_last() { - for line in head { - write!(writer, "{}-{}\r\n", code, line)?; - } - - write!(writer, "{} {}\r\n", code, last)?; - } else { - write!(writer, "{}\r\n", code)?; + for line in lines.iter().take(lines.len().saturating_sub(1)) { + write!(writer, "{}-{}\r\n", code, line,)?; } + + match lines.last() { + Some(s) => write!(writer, "{} {}\r\n", code, s)?, + None => write!(writer, "{}\r\n", code)?, + }; } } @@ -785,9 +782,52 @@ impl AuthMechanism { } } +/// A string containing of tab, space and printable ASCII characters +#[cfg_attr(feature = "serdex", derive(Serialize, Deserialize))] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TextString<'a>(pub(crate) Cow<'a, str>); + +impl<'a> TextString<'a> { + pub fn new(s: &'a str) -> Result { + match s.as_bytes().iter().all(|&b| is_text_string_byte(b)) { + true => Ok(TextString(Cow::Borrowed(s))), + false => Err(InvalidTextString(())), + } + } + + pub fn into_owned(self) -> TextString<'static> { + TextString(self.0.into_owned().into()) + } +} + +impl Deref for TextString<'_> { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for TextString<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Debug)] +pub struct InvalidTextString(()); + +impl fmt::Display for InvalidTextString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "input contains invalid characters") + } +} + +impl std::error::Error for InvalidTextString {} + #[cfg(test)] mod tests { - use super::{Capability, ReplyCode, Response}; + use super::{Capability, ReplyCode, Response, TextString}; #[test] fn test_serialize_greeting() { @@ -879,21 +919,21 @@ mod tests { ( Response::Other { code: ReplyCode::StartMailInput, - text: String::new(), + lines: vec![], }, b"354\r\n".as_ref(), ), ( Response::Other { code: ReplyCode::StartMailInput, - text: "A".into(), + lines: vec![TextString::new("A").unwrap()], }, b"354 A\r\n".as_ref(), ), ( Response::Other { code: ReplyCode::StartMailInput, - text: "A\nB".into(), + lines: vec![TextString::new("A").unwrap(), TextString::new("B").unwrap()], }, b"354-A\r\n354 B\r\n".as_ref(), ), diff --git a/tests/trace.rs b/tests/trace.rs index 34384f7..d4a8416 100644 --- a/tests/trace.rs +++ b/tests/trace.rs @@ -2,7 +2,7 @@ use nom::FindSubstring; use smtp_codec::{ parse::{ command::command, - response::{ehlo_ok_rsp, Greeting, Reply_line}, + response::{ehlo_ok_rsp, Greeting, Reply_lines}, }, types::Command, }; @@ -28,7 +28,7 @@ fn parse_trace(mut trace: &[u8]) { trace = rem; } Command::Data { .. } => { - let (rem, rsp) = Reply_line(trace).unwrap(); + let (rem, rsp) = Reply_lines(trace).unwrap(); println!("S: {:?}", rsp); trace = rem; @@ -37,12 +37,12 @@ fn parse_trace(mut trace: &[u8]) { println!("C (data): <{}>", std::str::from_utf8(data).unwrap()); trace = rem; - let (rem, rsp) = Reply_line(trace).unwrap(); + let (rem, rsp) = Reply_lines(trace).unwrap(); println!("S: {:?}", rsp); trace = rem; } _ => { - let (rem, rsp) = Reply_line(trace).unwrap(); + let (rem, rsp) = Reply_lines(trace).unwrap(); println!("S: {:?}", rsp); trace = rem; }