commit 124caf28e6db1e7bf8cdfef25c55760c81fb91b5
Author: Corey Farwell <coreyf@rwell.org>
Date:   Wed Dec 20 23:34:05 2017 -0500
Wrap types in protover.rs.
https://trac.torproject.org/projects/tor/ticket/24030
Introduce new wrapper types:
- `SupportedProtocols`
- `Versions`
Introduce a type alias:
- `Version` (`u32`)
---
 src/rust/protover/protover.rs | 307 ++++++++++++++++++++++-------------------
 1 file changed, 164 insertions(+), 143 deletions(-)
diff --git a/src/rust/protover/protover.rs b/src/rust/protover/protover.rs index 5e281a3e9..af0049a41 100644 --- a/src/rust/protover/protover.rs +++ b/src/rust/protover/protover.rs @@ -98,94 +98,120 @@ pub fn get_supported_protocols() -> String { SUPPORTED_PROTOCOLS.join(" ") }
-/// Translates a vector representation of a protocol list into a HashMap -fn parse_protocols<P, S>( - protocols: P, -) -> Result<HashMap<Proto, HashSet<u32>>, &'static str> -where - P: Iterator<Item = S>, - S: AsRef<str>, -{ - let mut parsed = HashMap::new(); - - for subproto in protocols { - let (name, version) = get_proto_and_vers(subproto.as_ref())?; - parsed.insert(name, version); +pub struct SupportedProtocols(HashMap<Proto, Versions>); + +impl SupportedProtocols { + /// # Examples + /// + /// ``` + /// use protover::SupportedProtocols; + /// + /// let supported_protocols = SupportedProtocols::from_proto_entries_string( + /// "HSDir=1-2 HSIntro=3-4" + /// ); + /// ``` + pub fn from_proto_entries_string( + proto_entries: &str, + ) -> Result<Self, &'static str> { + Self::from_proto_entries(proto_entries.split(" ")) } - Ok(parsed) -}
-/// Translates a string representation of a protocol list to a HashMap -fn parse_protocols_from_string<'a>( - protocol_string: &'a str, -) -> Result<HashMap<Proto, HashSet<u32>>, &'static str> { - parse_protocols(protocol_string.split(" ")) -} - -/// Translates supported tor versions from a string into a HashMap, which is -/// useful when looking up a specific subprotocol. -/// -/// # Returns -/// -/// A `Result` whose `Ok` value is a `HashMap<Proto, <u32>>` holding all -/// subprotocols and versions currently supported by tor. -/// -/// The returned `Result`'s `Err` value is an `&'static str` with a description -/// of the error. -/// -fn tor_supported() -> Result<HashMap<Proto, HashSet<u32>>, &'static str> { - parse_protocols(SUPPORTED_PROTOCOLS.iter()) -} + /// ``` + /// use protover::SupportedProtocols; + /// + /// let supported_protocols = SupportedProtocols::from_proto_entries([ + /// "HSDir=1-2", + /// "HSIntro=3-4", + /// ].iter()); + /// ``` + pub fn from_proto_entries<I, S>(protocol_strs: I) -> Result<Self, &'static str> + where + I: Iterator<Item = S>, + S: AsRef<str>, + { + let mut parsed = HashMap::new(); + for subproto in protocol_strs { + let (name, version) = get_proto_and_vers(subproto.as_ref())?; + parsed.insert(name, version); + } + Ok(SupportedProtocols(parsed)) + }
-/// Get the unique version numbers supported by a subprotocol. -/// -/// # Inputs -/// -/// * `version_string`, a string comprised of "[0-9,-]" -/// -/// # Returns -/// -/// A `Result` whose `Ok` value is a `HashSet<u32>` holding all of the unique -/// version numbers. If there were ranges in the `version_string`, then these -/// are expanded, i.e. `"1-3"` would expand to `HashSet<u32>::new([1, 2, 3])`. -/// The returned HashSet is *unordered*. -/// -/// The returned `Result`'s `Err` value is an `&'static str` with a description -/// of the error. -/// -/// # Errors -/// -/// This function will error if: -/// -/// * the `version_string` is empty or contains an equals (`"="`) sign, -/// * the expansion of a version range produces an error (see -/// `expand_version_range`), -/// * any single version number is not parseable as an `u32` in radix 10, or -/// * there are greater than 2^16 version numbers to expand. -/// -fn get_versions(version_string: &str) -> Result<HashSet<u32>, &'static str> { - if version_string.is_empty() { - return Err("version string is empty"); + /// Translates supported tor versions from a string into a HashMap, which + /// is useful when looking up a specific subprotocol. + /// + /// # Returns + /// + /// A `Result` whose `Ok` value is a `HashMap<Proto, <Version>>` holding all + /// subprotocols and versions currently supported by tor. + /// + /// The returned `Result`'s `Err` value is an `&'static str` with a + /// description of the error. + /// + fn tor_supported() -> Result<Self, &'static str> { + Self::from_proto_entries(SUPPORTED_PROTOCOLS.iter().map(|n| *n)) } +}
- let mut versions = HashSet::<u32>::new(); +type Version = u32; + +/// Set of versions for a protocol. +#[derive(Debug, PartialEq, Eq)] +pub struct Versions(HashSet<Version>); + +impl Versions { + /// Get the unique version numbers supported by a subprotocol. + /// + /// # Inputs + /// + /// * `version_string`, a string comprised of "[0-9,-]" + /// + /// # Returns + /// + /// A `Result` whose `Ok` value is a `HashSet<u32>` holding all of the unique + /// version numbers. If there were ranges in the `version_string`, then these + /// are expanded, i.e. `"1-3"` would expand to `HashSet<u32>::new([1, 2, 3])`. + /// The returned HashSet is *unordered*. + /// + /// The returned `Result`'s `Err` value is an `&'static str` with a description + /// of the error. + /// + /// # Errors + /// + /// This function will error if: + /// + /// * the `version_string` is empty or contains an equals (`"="`) sign, + /// * the expansion of a version range produces an error (see + /// `expand_version_range`), + /// * any single version number is not parseable as an `u32` in radix 10, or + /// * there are greater than 2^16 version numbers to expand. + /// + fn from_version_string( + version_string: &str, + ) -> Result<Self, &'static str> { + if version_string.is_empty() { + return Err("version string is empty"); + }
- for piece in version_string.split(",") { - if piece.contains("-") { - for p in expand_version_range(piece)? { - versions.insert(p); + let mut versions = HashSet::<Version>::new(); + + for piece in version_string.split(",") { + if piece.contains("-") { + for p in expand_version_range(piece)? { + versions.insert(p); + } + } else { + versions.insert(u32::from_str(piece).or( + Err("invalid protocol entry"), + )?); } - } else { - versions.insert(u32::from_str(piece).or( - Err("invalid protocol entry"), - )?); - }
- if versions.len() > MAX_PROTOCOLS_TO_EXPAND as usize { - return Err("Too many versions to expand"); + if versions.len() > MAX_PROTOCOLS_TO_EXPAND as usize { + return Err("Too many versions to expand"); + } } + Ok(Versions(versions)) } - Ok(versions) }
@@ -205,7 +231,7 @@ fn get_versions(version_string: &str) -> Result<HashSet<u32>, &'static str> { /// fn get_proto_and_vers<'a>( protocol_entry: &'a str, -) -> Result<(Proto, HashSet<u32>), &'static str> { +) -> Result<(Proto, Versions), &'static str> { let mut parts = protocol_entry.splitn(2, "=");
let proto = match parts.next() { @@ -218,7 +244,7 @@ fn get_proto_and_vers<'a>( None => return Err("invalid protover entry"), };
- let versions = get_versions(vers)?; + let versions = Versions::from_version_string(vers)?; let proto_name = proto.parse()?;
Ok((proto_name, versions)) @@ -245,19 +271,18 @@ fn contains_only_supported_protocols(proto_entry: &str) -> bool { Err(_) => return false, };
- let currently_supported: HashMap<Proto, HashSet<u32>> = - match tor_supported() { - Ok(n) => n, - Err(_) => return false, - }; + let currently_supported = match SupportedProtocols::tor_supported() { + Ok(n) => n.0, + Err(_) => return false, + };
let supported_versions = match currently_supported.get(&name) { Some(n) => n, None => return false, };
- vers.retain(|x| !supported_versions.contains(x)); - vers.is_empty() + vers.0.retain(|x| !supported_versions.0.contains(x)); + vers.0.is_empty() }
/// Determine if we support every protocol a client supports, and if not, @@ -303,7 +328,7 @@ pub fn all_supported(protocols: &str) -> (bool, String) { /// /// * `list`, a string representation of a list of protocol entries. /// * `proto`, a `Proto` to test support for -/// * `vers`, a `u32` version which we will go on to determine whether the +/// * `vers`, a `Version` version which we will go on to determine whether the /// specified protocol supports. /// /// # Examples @@ -321,21 +346,19 @@ pub fn all_supported(protocols: &str) -> (bool, String) { pub fn protover_string_supports_protocol( list: &str, proto: Proto, - vers: u32, + vers: Version, ) -> bool { - let supported: HashMap<Proto, HashSet<u32>>; - - match parse_protocols_from_string(list) { - Ok(result) => supported = result, + let supported = match SupportedProtocols::from_proto_entries_string(list) { + Ok(result) => result.0, Err(_) => return false, - } + };
let supported_versions = match supported.get(&proto) { Some(n) => n, None => return false, };
- supported_versions.contains(&vers) + supported_versions.0.contains(&vers) }
/// As protover_string_supports_protocol(), but also returns True if @@ -365,23 +388,21 @@ pub fn protover_string_supports_protocol_or_later( proto: Proto, vers: u32, ) -> bool { - let supported: HashMap<Proto, HashSet<u32>>; - - match parse_protocols_from_string(list) { - Ok(result) => supported = result, + let supported = match SupportedProtocols::from_proto_entries_string(list) { + Ok(result) => result.0, Err(_) => return false, - } + };
let supported_versions = match supported.get(&proto) { Some(n) => n, None => return false, };
- supported_versions.iter().any(|v| v >= &vers) + supported_versions.0.iter().any(|v| v >= &vers) }
/// Fully expand a version range. For example, 1-3 expands to 1,2,3 -/// Helper for get_versions +/// Helper for Versions::from_version_string /// /// # Inputs /// @@ -483,10 +504,9 @@ fn find_range(list: &Vec<u32>) -> (bool, u32) { /// /// A `String` representation of this set in ascending order. /// -fn contract_protocol_list<'a>(supported_set: &'a HashSet<u32>) -> String { - let mut supported: Vec<u32> = supported_set.iter() - .map(|x| *x) - .collect(); +fn contract_protocol_list<'a>(supported_set: &'a HashSet<Version>) -> String { + let mut supported: Vec<Version> = + supported_set.iter().map(|x| *x).collect(); supported.sort();
let mut final_output: Vec<String> = Vec::new(); @@ -522,8 +542,8 @@ fn contract_protocol_list<'a>(supported_set: &'a HashSet<u32>) -> String { /// /// # Returns /// -/// A `Result` whose `Ok` value is a `HashSet<u32>` holding all of the unique -/// version numbers. +/// A `Result` whose `Ok` value is a `HashSet<Version>` holding all of the +/// unique version numbers. /// /// The returned `Result`'s `Err` value is an `&'static str` with a description /// of the error. @@ -534,12 +554,12 @@ fn contract_protocol_list<'a>(supported_set: &'a HashSet<u32>) -> String { /// /// * The protocol string does not follow the "protocol_name=version_list" /// expected format -/// * If the version string is malformed. See `get_versions`. +/// * If the version string is malformed. See `Versions::from_version_string`. /// fn parse_protocols_from_string_with_no_validation<'a>( protocol_string: &'a str, -) -> Result<HashMap<String, HashSet<u32>>, &'static str> { - let mut parsed: HashMap<String, HashSet<u32>> = HashMap::new(); +) -> Result<HashMap<String, Versions>, &'static str> { + let mut parsed: HashMap<String, Versions> = HashMap::new();
for subproto in protocol_string.split(" ") { let mut parts = subproto.splitn(2, "="); @@ -554,7 +574,7 @@ fn parse_protocols_from_string_with_no_validation<'a>( None => return Err("invalid protover entry"), };
- let versions = get_versions(vers)?; + let versions = Versions::from_version_string(vers)?;
parsed.insert(String::from(name), versions); } @@ -602,21 +622,22 @@ pub fn compute_vote( // } // means that FirstSupportedProtocol has three votes which support version // 1, and one vote that supports version 2 - let mut all_count: HashMap<String, HashMap<u32, usize>> = HashMap::new(); + let mut all_count: HashMap<String, HashMap<Version, usize>> = + HashMap::new();
// parse and collect all of the protos and their versions and collect them for vote in list_of_proto_strings { - let this_vote: HashMap<String, HashSet<u32>> = + let this_vote: HashMap<String, Versions> = match parse_protocols_from_string_with_no_validation(&vote) { Ok(result) => result, Err(_) => continue, };
for (protocol, versions) in this_vote { - let supported_vers: &mut HashMap<u32, usize> = + let supported_vers: &mut HashMap<Version, usize> = all_count.entry(protocol).or_insert(HashMap::new());
- for version in versions { + for version in versions.0 { let counter: &mut usize = supported_vers.entry(version).or_insert(0); *counter += 1; @@ -690,20 +711,18 @@ fn write_vote_to_string(vote: &HashMap<String, String>) -> String { /// let is_supported = is_supported_here(Proto::Link, 1); /// assert_eq!(true, is_supported); /// ``` -pub fn is_supported_here(proto: Proto, vers: u32) -> bool { - let currently_supported: HashMap<Proto, HashSet<u32>>; - - match tor_supported() { - Ok(result) => currently_supported = result, +pub fn is_supported_here(proto: Proto, vers: Version) -> bool { + let currently_supported = match SupportedProtocols::tor_supported() { + Ok(result) => result.0, Err(_) => return false, - } + };
let supported_versions = match currently_supported.get(&proto) { Some(n) => n, None => return false, };
- supported_versions.contains(&vers) + supported_versions.0.contains(&vers) }
/// Older versions of Tor cannot infer their own subprotocols @@ -755,48 +774,50 @@ pub fn compute_for_old_tor(version: &str) -> String {
#[cfg(test)] mod test { + use super::Version; + #[test] - fn test_get_versions() { + fn test_versions_from_version_string() { use std::collections::HashSet;
- use super::get_versions; + use super::Versions;
- assert_eq!(Err("version string is empty"), get_versions("")); - assert_eq!(Err("invalid protocol entry"), get_versions("a,b")); - assert_eq!(Err("invalid protocol entry"), get_versions("1,!")); + assert_eq!(Err("version string is empty"), Versions::from_version_string("")); + assert_eq!(Err("invalid protocol entry"), Versions::from_version_string("a,b")); + assert_eq!(Err("invalid protocol entry"), Versions::from_version_string("1,!"));
{ - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); - assert_eq!(Ok(versions), get_versions("1")); + assert_eq!(versions, Versions::from_version_string("1").unwrap().0); } { - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); versions.insert(2); - assert_eq!(Ok(versions), get_versions("1,2")); + assert_eq!(versions, Versions::from_version_string("1,2").unwrap().0); } { - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); versions.insert(2); versions.insert(3); - assert_eq!(Ok(versions), get_versions("1-3")); + assert_eq!(versions, Versions::from_version_string("1-3").unwrap().0); } { - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); versions.insert(2); versions.insert(5); - assert_eq!(Ok(versions), get_versions("1-2,5")); + assert_eq!(versions, Versions::from_version_string("1-2,5").unwrap().0); } { - let mut versions: HashSet<u32> = HashSet::new(); + let mut versions: HashSet<Version> = HashSet::new(); versions.insert(1); versions.insert(3); versions.insert(4); versions.insert(5); - assert_eq!(Ok(versions), get_versions("1,3-5")); + assert_eq!(versions, Versions::from_version_string("1,3-5").unwrap().0); } }
@@ -852,7 +873,7 @@ mod test { use super::contract_protocol_list;
{ - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); assert_eq!(String::from(""), contract_protocol_list(&versions));
versions.insert(1); @@ -863,14 +884,14 @@ mod test { }
{ - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); versions.insert(1); versions.insert(3); assert_eq!(String::from("1,3"), contract_protocol_list(&versions)); }
{ - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); versions.insert(1); versions.insert(2); versions.insert(3); @@ -879,7 +900,7 @@ mod test { }
{ - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); versions.insert(1); versions.insert(3); versions.insert(5); @@ -892,7 +913,7 @@ mod test { }
{ - let mut versions = HashSet::<u32>::new(); + let mut versions = HashSet::<Version>::new(); versions.insert(1); versions.insert(2); versions.insert(3);