[tor-commits] [tor/master] Wrap types in protover.rs.

nickm at torproject.org nickm at torproject.org
Thu Feb 8 22:45:35 UTC 2018


commit 124caf28e6db1e7bf8cdfef25c55760c81fb91b5
Author: Corey Farwell <coreyf at rwell.org>
Date:   Wed Dec 20 23:34:05 2017 -0500

    Wrap types in protover.rs.
    
    https://trac.torproject.org/projects/tor/ticket/24030
    
    Introduce new wrapper types:
    
    - `SupportedProtocols`
    - `Versions`
    
    Introduce a type alias:
    
    - `Version` (`u32`)
---
 src/rust/protover/protover.rs | 307 ++++++++++++++++++++++--------------------
 1 file changed, 164 insertions(+), 143 deletions(-)

diff --git a/src/rust/protover/protover.rs b/src/rust/protover/protover.rs
index 5e281a3e9..af0049a41 100644
--- a/src/rust/protover/protover.rs
+++ b/src/rust/protover/protover.rs
@@ -98,94 +98,120 @@ pub fn get_supported_protocols() -> String {
     SUPPORTED_PROTOCOLS.join(" ")
 }
 
-/// Translates a vector representation of a protocol list into a HashMap
-fn parse_protocols<P, S>(
-    protocols: P,
-) -> Result<HashMap<Proto, HashSet<u32>>, &'static str>
-where
-    P: Iterator<Item = S>,
-    S: AsRef<str>,
-{
-    let mut parsed = HashMap::new();
-
-    for subproto in protocols {
-        let (name, version) = get_proto_and_vers(subproto.as_ref())?;
-        parsed.insert(name, version);
+pub struct SupportedProtocols(HashMap<Proto, Versions>);
+
+impl SupportedProtocols {
+    /// # Examples
+    ///
+    /// ```
+    /// use protover::SupportedProtocols;
+    ///
+    /// let supported_protocols = SupportedProtocols::from_proto_entries_string(
+    ///     "HSDir=1-2 HSIntro=3-4"
+    /// );
+    /// ```
+    pub fn from_proto_entries_string(
+        proto_entries: &str,
+    ) -> Result<Self, &'static str> {
+        Self::from_proto_entries(proto_entries.split(" "))
     }
-    Ok(parsed)
-}
 
-/// Translates a string representation of a protocol list to a HashMap
-fn parse_protocols_from_string<'a>(
-    protocol_string: &'a str,
-) -> Result<HashMap<Proto, HashSet<u32>>, &'static str> {
-    parse_protocols(protocol_string.split(" "))
-}
-
-/// Translates supported tor versions from  a string into a HashMap, which is
-/// useful when looking up a specific subprotocol.
-///
-/// # Returns
-///
-/// A `Result` whose `Ok` value is a `HashMap<Proto, <u32>>` holding all
-/// subprotocols and versions currently supported by tor.
-///
-/// The returned `Result`'s `Err` value is an `&'static str` with a description
-/// of the error.
-///
-fn tor_supported() -> Result<HashMap<Proto, HashSet<u32>>, &'static str> {
-    parse_protocols(SUPPORTED_PROTOCOLS.iter())
-}
+    /// ```
+    /// use protover::SupportedProtocols;
+    ///
+    /// let supported_protocols = SupportedProtocols::from_proto_entries([
+    ///     "HSDir=1-2",
+    ///     "HSIntro=3-4",
+    /// ].iter());
+    /// ```
+    pub fn from_proto_entries<I, S>(protocol_strs: I) -> Result<Self, &'static str>
+    where
+        I: Iterator<Item = S>,
+        S: AsRef<str>,
+    {
+        let mut parsed = HashMap::new();
+        for subproto in protocol_strs {
+            let (name, version) = get_proto_and_vers(subproto.as_ref())?;
+            parsed.insert(name, version);
+        }
+        Ok(SupportedProtocols(parsed))
+    }
 
-/// Get the unique version numbers supported by a subprotocol.
-///
-/// # Inputs
-///
-/// * `version_string`, a string comprised of "[0-9,-]"
-///
-/// # Returns
-///
-/// A `Result` whose `Ok` value is a `HashSet<u32>` holding all of the unique
-/// version numbers.  If there were ranges in the `version_string`, then these
-/// are expanded, i.e. `"1-3"` would expand to `HashSet<u32>::new([1, 2, 3])`.
-/// The returned HashSet is *unordered*.
-///
-/// The returned `Result`'s `Err` value is an `&'static str` with a description
-/// of the error.
-///
-/// # Errors
-///
-/// This function will error if:
-///
-/// * the `version_string` is empty or contains an equals (`"="`) sign,
-/// * the expansion of a version range produces an error (see
-///  `expand_version_range`),
-/// * any single version number is not parseable as an `u32` in radix 10, or
-/// * there are greater than 2^16 version numbers to expand.
-///
-fn get_versions(version_string: &str) -> Result<HashSet<u32>, &'static str> {
-    if version_string.is_empty() {
-        return Err("version string is empty");
+    /// Parses the subprotocols tor currently supports (`SUPPORTED_PROTOCOLS`)
+    /// into a `SupportedProtocols`, for looking up a specific subprotocol.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` whose `Ok` value is a `SupportedProtocols` holding all
+    /// subprotocols and versions currently supported by tor.
+    ///
+    /// The returned `Result`'s `Err` value is an `&'static str` with a
+    /// description of the error.
+    ///
+    fn tor_supported() -> Result<Self, &'static str> {
+        Self::from_proto_entries(SUPPORTED_PROTOCOLS.iter().map(|n| *n))
     }
+}
 
-    let mut versions = HashSet::<u32>::new();
+type Version = u32;
+
+/// Set of versions for a protocol.
+#[derive(Debug, PartialEq, Eq)]
+pub struct Versions(HashSet<Version>);
+
+impl Versions {
+    /// Get the unique version numbers supported by a subprotocol.
+    ///
+    /// # Inputs
+    ///
+    /// * `version_string`, a string comprised of "[0-9,-]"
+    ///
+    /// # Returns
+    ///
+    /// A `Result` whose `Ok` value is a `Versions` holding all of the unique
+    /// version numbers.  If there were ranges in the `version_string`, then these
+    /// are expanded, i.e. `"1-3"` would expand to the set `{1, 2, 3}`.
+    /// The underlying `HashSet` is *unordered*.
+    ///
+    /// The returned `Result`'s `Err` value is an `&'static str` with a description
+    /// of the error.
+    ///
+    /// # Errors
+    ///
+    /// This function will error if:
+    ///
+    /// * the `version_string` is empty or contains an equals (`"="`) sign,
+    /// * the expansion of a version range produces an error (see
+    ///  `expand_version_range`),
+    /// * any single version number is not parseable as a `u32` in radix 10, or
+    /// * there are greater than 2^16 version numbers to expand.
+    ///
+    fn from_version_string(
+        version_string: &str,
+    ) -> Result<Self, &'static str> {
+        if version_string.is_empty() {
+            return Err("version string is empty");
+        }
 
-    for piece in version_string.split(",") {
-        if piece.contains("-") {
-            for p in expand_version_range(piece)? {
-                versions.insert(p);
+        let mut versions = HashSet::<Version>::new();
+
+        for piece in version_string.split(",") {
+            if piece.contains("-") {
+                for p in expand_version_range(piece)? {
+                    versions.insert(p);
+                }
+            } else {
+                versions.insert(u32::from_str(piece).or(
+                    Err("invalid protocol entry"),
+                )?);
             }
-        } else {
-            versions.insert(u32::from_str(piece).or(
-                Err("invalid protocol entry"),
-            )?);
-        }
 
-        if versions.len() > MAX_PROTOCOLS_TO_EXPAND as usize {
-            return Err("Too many versions to expand");
+            if versions.len() > MAX_PROTOCOLS_TO_EXPAND as usize {
+                return Err("Too many versions to expand");
+            }
         }
+        Ok(Versions(versions))
     }
-    Ok(versions)
 }
 
 
@@ -205,7 +231,7 @@ fn get_versions(version_string: &str) -> Result<HashSet<u32>, &'static str> {
 ///
 fn get_proto_and_vers<'a>(
     protocol_entry: &'a str,
-) -> Result<(Proto, HashSet<u32>), &'static str> {
+) -> Result<(Proto, Versions), &'static str> {
     let mut parts = protocol_entry.splitn(2, "=");
 
     let proto = match parts.next() {
@@ -218,7 +244,7 @@ fn get_proto_and_vers<'a>(
         None => return Err("invalid protover entry"),
     };
 
-    let versions = get_versions(vers)?;
+    let versions = Versions::from_version_string(vers)?;
     let proto_name = proto.parse()?;
 
     Ok((proto_name, versions))
@@ -245,19 +271,18 @@ fn contains_only_supported_protocols(proto_entry: &str) -> bool {
         Err(_) => return false,
     };
 
-    let currently_supported: HashMap<Proto, HashSet<u32>> =
-        match tor_supported() {
-            Ok(n) => n,
-            Err(_) => return false,
-        };
+    let currently_supported = match SupportedProtocols::tor_supported() {
+        Ok(n) => n.0,
+        Err(_) => return false,
+    };
 
     let supported_versions = match currently_supported.get(&name) {
         Some(n) => n,
         None => return false,
     };
 
-    vers.retain(|x| !supported_versions.contains(x));
-    vers.is_empty()
+    vers.0.retain(|x| !supported_versions.0.contains(x));
+    vers.0.is_empty()
 }
 
 /// Determine if we support every protocol a client supports, and if not,
@@ -303,7 +328,7 @@ pub fn all_supported(protocols: &str) -> (bool, String) {
 ///
 /// * `list`, a string representation of a list of protocol entries.
 /// * `proto`, a `Proto` to test support for
-/// * `vers`, a `u32` version which we will go on to determine whether the
+/// * `vers`, a `Version` which we will go on to determine whether the
 /// specified protocol supports.
 ///
 /// # Examples
@@ -321,21 +346,19 @@ pub fn all_supported(protocols: &str) -> (bool, String) {
 pub fn protover_string_supports_protocol(
     list: &str,
     proto: Proto,
-    vers: u32,
+    vers: Version,
 ) -> bool {
-    let supported: HashMap<Proto, HashSet<u32>>;
-
-    match parse_protocols_from_string(list) {
-        Ok(result) => supported = result,
+    let supported = match SupportedProtocols::from_proto_entries_string(list) {
+        Ok(result) => result.0,
         Err(_) => return false,
-    }
+    };
 
     let supported_versions = match supported.get(&proto) {
         Some(n) => n,
         None => return false,
     };
 
-    supported_versions.contains(&vers)
+    supported_versions.0.contains(&vers)
 }
 
 /// As protover_string_supports_protocol(), but also returns True if
@@ -365,23 +388,21 @@ pub fn protover_string_supports_protocol_or_later(
     proto: Proto,
     vers: u32,
 ) -> bool {
-    let supported: HashMap<Proto, HashSet<u32>>;
-
-    match parse_protocols_from_string(list) {
-        Ok(result) => supported = result,
+    let supported = match SupportedProtocols::from_proto_entries_string(list) {
+        Ok(result) => result.0,
         Err(_) => return false,
-    }
+    };
 
     let supported_versions = match supported.get(&proto) {
         Some(n) => n,
         None => return false,
     };
 
-    supported_versions.iter().any(|v| v >= &vers)
+    supported_versions.0.iter().any(|v| v >= &vers)
 }
 
 /// Fully expand a version range. For example, 1-3 expands to 1,2,3
-/// Helper for get_versions
+/// Helper for Versions::from_version_string
 ///
 /// # Inputs
 ///
@@ -483,10 +504,9 @@ fn find_range(list: &Vec<u32>) -> (bool, u32) {
 ///
 /// A `String` representation of this set in ascending order.
 ///
-fn contract_protocol_list<'a>(supported_set: &'a HashSet<u32>) -> String {
-    let mut supported: Vec<u32> = supported_set.iter()
-                                               .map(|x| *x)
-                                               .collect();
+fn contract_protocol_list<'a>(supported_set: &'a HashSet<Version>) -> String {
+    let mut supported: Vec<Version> =
+        supported_set.iter().map(|x| *x).collect();
     supported.sort();
 
     let mut final_output: Vec<String> = Vec::new();
@@ -522,8 +542,8 @@ fn contract_protocol_list<'a>(supported_set: &'a HashSet<u32>) -> String {
 ///
 /// # Returns
 ///
-/// A `Result` whose `Ok` value is a `HashSet<u32>` holding all of the unique
-/// version numbers.
+/// A `Result` whose `Ok` value is a `HashMap<String, Versions>` mapping each
+/// protocol name to its unique version numbers.
 ///
 /// The returned `Result`'s `Err` value is an `&'static str` with a description
 /// of the error.
@@ -534,12 +554,12 @@ fn contract_protocol_list<'a>(supported_set: &'a HashSet<u32>) -> String {
 ///
 /// * The protocol string does not follow the "protocol_name=version_list"
 /// expected format
-/// * If the version string is malformed. See `get_versions`.
+/// * If the version string is malformed. See `Versions::from_version_string`.
 ///
 fn parse_protocols_from_string_with_no_validation<'a>(
     protocol_string: &'a str,
-) -> Result<HashMap<String, HashSet<u32>>, &'static str> {
-    let mut parsed: HashMap<String, HashSet<u32>> = HashMap::new();
+) -> Result<HashMap<String, Versions>, &'static str> {
+    let mut parsed: HashMap<String, Versions> = HashMap::new();
 
     for subproto in protocol_string.split(" ") {
         let mut parts = subproto.splitn(2, "=");
@@ -554,7 +574,7 @@ fn parse_protocols_from_string_with_no_validation<'a>(
             None => return Err("invalid protover entry"),
         };
 
-        let versions = get_versions(vers)?;
+        let versions = Versions::from_version_string(vers)?;
 
         parsed.insert(String::from(name), versions);
     }
@@ -602,21 +622,22 @@ pub fn compute_vote(
     // }
     // means that FirstSupportedProtocol has three votes which support version
     // 1, and one vote that supports version 2
-    let mut all_count: HashMap<String, HashMap<u32, usize>> = HashMap::new();
+    let mut all_count: HashMap<String, HashMap<Version, usize>> =
+        HashMap::new();
 
     // parse and collect all of the protos and their versions and collect them
     for vote in list_of_proto_strings {
-        let this_vote: HashMap<String, HashSet<u32>> =
+        let this_vote: HashMap<String, Versions> =
             match parse_protocols_from_string_with_no_validation(&vote) {
                 Ok(result) => result,
                 Err(_) => continue,
             };
 
         for (protocol, versions) in this_vote {
-            let supported_vers: &mut HashMap<u32, usize> =
+            let supported_vers: &mut HashMap<Version, usize> =
                 all_count.entry(protocol).or_insert(HashMap::new());
 
-            for version in versions {
+            for version in versions.0 {
                 let counter: &mut usize =
                     supported_vers.entry(version).or_insert(0);
                 *counter += 1;
@@ -690,20 +711,18 @@ fn write_vote_to_string(vote: &HashMap<String, String>) -> String {
 /// let is_supported = is_supported_here(Proto::Link, 1);
 /// assert_eq!(true, is_supported);
 /// ```
-pub fn is_supported_here(proto: Proto, vers: u32) -> bool {
-    let currently_supported: HashMap<Proto, HashSet<u32>>;
-
-    match tor_supported() {
-        Ok(result) => currently_supported = result,
+pub fn is_supported_here(proto: Proto, vers: Version) -> bool {
+    let currently_supported = match SupportedProtocols::tor_supported() {
+        Ok(result) => result.0,
         Err(_) => return false,
-    }
+    };
 
     let supported_versions = match currently_supported.get(&proto) {
         Some(n) => n,
         None => return false,
     };
 
-    supported_versions.contains(&vers)
+    supported_versions.0.contains(&vers)
 }
 
 /// Older versions of Tor cannot infer their own subprotocols
@@ -755,48 +774,50 @@ pub fn compute_for_old_tor(version: &str) -> String {
 
 #[cfg(test)]
 mod test {
+    use super::Version;
+
     #[test]
-    fn test_get_versions() {
+    fn test_versions_from_version_string() {
         use std::collections::HashSet;
 
-        use super::get_versions;
+        use super::Versions;
 
-        assert_eq!(Err("version string is empty"), get_versions(""));
-        assert_eq!(Err("invalid protocol entry"), get_versions("a,b"));
-        assert_eq!(Err("invalid protocol entry"), get_versions("1,!"));
+        assert_eq!(Err("version string is empty"), Versions::from_version_string(""));
+        assert_eq!(Err("invalid protocol entry"), Versions::from_version_string("a,b"));
+        assert_eq!(Err("invalid protocol entry"), Versions::from_version_string("1,!"));
 
         {
-            let mut versions: HashSet<u32> = HashSet::new();
+            let mut versions: HashSet<Version> = HashSet::new();
             versions.insert(1);
-            assert_eq!(Ok(versions), get_versions("1"));
+            assert_eq!(versions, Versions::from_version_string("1").unwrap().0);
         }
         {
-            let mut versions: HashSet<u32> = HashSet::new();
+            let mut versions: HashSet<Version> = HashSet::new();
             versions.insert(1);
             versions.insert(2);
-            assert_eq!(Ok(versions), get_versions("1,2"));
+            assert_eq!(versions, Versions::from_version_string("1,2").unwrap().0);
         }
         {
-            let mut versions: HashSet<u32> = HashSet::new();
+            let mut versions: HashSet<Version> = HashSet::new();
             versions.insert(1);
             versions.insert(2);
             versions.insert(3);
-            assert_eq!(Ok(versions), get_versions("1-3"));
+            assert_eq!(versions, Versions::from_version_string("1-3").unwrap().0);
         }
         {
-            let mut versions: HashSet<u32> = HashSet::new();
+            let mut versions: HashSet<Version> = HashSet::new();
             versions.insert(1);
             versions.insert(2);
             versions.insert(5);
-            assert_eq!(Ok(versions), get_versions("1-2,5"));
+            assert_eq!(versions, Versions::from_version_string("1-2,5").unwrap().0);
         }
         {
-            let mut versions: HashSet<u32> = HashSet::new();
+            let mut versions: HashSet<Version> = HashSet::new();
             versions.insert(1);
             versions.insert(3);
             versions.insert(4);
             versions.insert(5);
-            assert_eq!(Ok(versions), get_versions("1,3-5"));
+            assert_eq!(versions, Versions::from_version_string("1,3-5").unwrap().0);
         }
     }
 
@@ -852,7 +873,7 @@ mod test {
         use super::contract_protocol_list;
 
         {
-            let mut versions = HashSet::<u32>::new();
+            let mut versions = HashSet::<Version>::new();
             assert_eq!(String::from(""), contract_protocol_list(&versions));
 
             versions.insert(1);
@@ -863,14 +884,14 @@ mod test {
         }
 
         {
-            let mut versions = HashSet::<u32>::new();
+            let mut versions = HashSet::<Version>::new();
             versions.insert(1);
             versions.insert(3);
             assert_eq!(String::from("1,3"), contract_protocol_list(&versions));
         }
 
         {
-            let mut versions = HashSet::<u32>::new();
+            let mut versions = HashSet::<Version>::new();
             versions.insert(1);
             versions.insert(2);
             versions.insert(3);
@@ -879,7 +900,7 @@ mod test {
         }
 
         {
-            let mut versions = HashSet::<u32>::new();
+            let mut versions = HashSet::<Version>::new();
             versions.insert(1);
             versions.insert(3);
             versions.insert(5);
@@ -892,7 +913,7 @@ mod test {
         }
 
         {
-            let mut versions = HashSet::<u32>::new();
+            let mut versions = HashSet::<Version>::new();
             versions.insert(1);
             versions.insert(2);
             versions.insert(3);





More information about the tor-commits mailing list