[tor-commits] [stem/master] Move attribute validation into Directory constructors

atagar at torproject.org atagar at torproject.org
Tue May 8 20:20:09 UTC 2018


commit cdffc5a9c81145877907ec0135c45d24e6676ae6
Author: Damian Johnson <atagar at torproject.org>
Date:   Tue May 8 11:31:11 2018 -0700

    Move attribute validation into Directory constructors
    
    Doing validation in the constructors not only makes sense, but lets us
    deduplicate this.
---
 stem/directory.py                | 90 ++++++++++++++++------------------------
 test/unit/directory/authority.py |  6 +--
 test/unit/directory/fallback.py  | 17 ++++----
 3 files changed, 47 insertions(+), 66 deletions(-)

diff --git a/stem/directory.py b/stem/directory.py
index c70cbecf..acd24a80 100644
--- a/stem/directory.py
+++ b/stem/directory.py
@@ -102,9 +102,22 @@ class Directory(object):
   """
 
   def __init__(self, address, or_port, dir_port, fingerprint, nickname):
+    identifier = '%s (%s)' % (fingerprint, nickname) if nickname else fingerprint
+
+    if not connection.is_valid_ipv4_address(address):
+      raise ValueError('%s has an invalid IPv4 address: %s' % (identifier, address))
+    elif not connection.is_valid_port(or_port):
+      raise ValueError('%s has an invalid ORPort: %s' % (identifier, or_port))
+    elif not connection.is_valid_port(dir_port):
+      raise ValueError('%s has an invalid DirPort: %s' % (identifier, dir_port))
+    elif not tor_tools.is_valid_fingerprint(fingerprint):
+      raise ValueError('%s has an invalid fingerprint: %s' % (identifier, fingerprint))
+    elif nickname and not tor_tools.is_valid_nickname(nickname):
+      raise ValueError('%s has an invalid nickname: %s' % (fingerprint, nickname))
+
     self.address = address
-    self.or_port = or_port
-    self.dir_port = dir_port
+    self.or_port = int(or_port)
+    self.dir_port = int(dir_port)
     self.fingerprint = fingerprint
     self.nickname = nickname
 
@@ -183,6 +196,11 @@ class Authority(Directory):
 
   def __init__(self, address = None, or_port = None, dir_port = None, fingerprint = None, nickname = None, v3ident = None, is_bandwidth_authority = False):
     super(Authority, self).__init__(address, or_port, dir_port, fingerprint, nickname)
+    identifier = '%s (%s)' % (fingerprint, nickname) if nickname else fingerprint
+
+    if v3ident and not tor_tools.is_valid_fingerprint(v3ident):
+      raise ValueError('%s has an invalid v3ident: %s' % (identifier, v3ident))
+
     self.v3ident = v3ident
     self.is_bandwidth_authority = is_bandwidth_authority
 
@@ -253,33 +271,14 @@ class Authority(Directory):
 
     nickname, or_port = matches.get(AUTHORITY_NAME)
     v3ident = matches.get(AUTHORITY_V3IDENT)
-    orport_v6 = matches.get(AUTHORITY_IPV6)  # TODO: add this to stem's data?
+    # orport_v6 = matches.get(AUTHORITY_IPV6)  # TODO: add this to stem's data?
     address, dir_port, fingerprint = matches.get(AUTHORITY_ADDR)
 
-    fingerprint = fingerprint.replace(' ', '')
-
-    if not connection.is_valid_ipv4_address(address):
-      raise ValueError('%s has an invalid IPv4 address: %s' % (nickname, address))
-    elif not connection.is_valid_port(or_port):
-      raise ValueError('%s has an invalid or_port: %s' % (nickname, or_port))
-    elif not connection.is_valid_port(dir_port):
-      raise ValueError('%s has an invalid dir_port: %s' % (nickname, dir_port))
-    elif not tor_tools.is_valid_fingerprint(fingerprint):
-      raise ValueError('%s has an invalid fingerprint: %s' % (nickname, fingerprint))
-    elif nickname and not tor_tools.is_valid_nickname(nickname):
-      raise ValueError('%s has an invalid nickname: %s' % (nickname, nickname))
-    elif orport_v6 and not connection.is_valid_ipv6_address(orport_v6[0]):
-      raise ValueError('%s has an invalid IPv6 address: %s' % (nickname, orport_v6[0]))
-    elif orport_v6 and not connection.is_valid_port(orport_v6[1]):
-      raise ValueError('%s has an invalid ORPort for its IPv6 endpoint: %s' % (nickname, orport_v6[1]))
-    elif v3ident and not tor_tools.is_valid_fingerprint(v3ident):
-      raise ValueError('%s has an invalid v3ident: %s' % (nickname, v3ident))
-
     return Authority(
       address = address,
-      or_port = int(or_port),
-      dir_port = int(dir_port),
-      fingerprint = fingerprint,
+      or_port = or_port,
+      dir_port = dir_port,
+      fingerprint = fingerprint.replace(' ', ''),
       nickname = nickname,
       v3ident = v3ident,
     )
@@ -353,9 +352,18 @@ class Fallback(Directory):
 
   def __init__(self, address = None, or_port = None, dir_port = None, fingerprint = None, nickname = None, has_extrainfo = False, orport_v6 = None, header = None):
     super(Fallback, self).__init__(address, or_port, dir_port, fingerprint, nickname)
+    identifier = '%s (%s)' % (fingerprint, nickname) if nickname else fingerprint
+
+    if orport_v6:
+      if not isinstance(orport_v6, tuple) or len(orport_v6) != 2:
+        raise ValueError('%s orport_v6 should be a two value tuple: %s' % (identifier, str(orport_v6)))
+      elif not connection.is_valid_ipv6_address(orport_v6[0]):
+        raise ValueError('%s has an invalid IPv6 address: %s' % (identifier, orport_v6[0]))
+      elif not connection.is_valid_port(orport_v6[1]):
+        raise ValueError('%s has an invalid IPv6 port: %s' % (identifier, orport_v6[1]))
 
     self.has_extrainfo = has_extrainfo
-    self.orport_v6 = orport_v6
+    self.orport_v6 = (orport_v6[0], int(orport_v6[1])) if orport_v6 else None
     self.header = header if header else OrderedDict()
 
   @staticmethod
@@ -379,19 +387,6 @@ class Fallback(Directory):
         if not attr[attr_name] and attr_name not in ('nickname', 'has_extrainfo', 'orport6_address', 'orport6_port'):
           raise IOError("'%s' is missing from %s" % (key, FALLBACK_CACHE_PATH))
 
-      if not connection.is_valid_ipv4_address(attr['address']):
-        raise IOError("'%s.address' was an invalid IPv4 address (%s)" % (fingerprint, attr['address']))
-      elif not connection.is_valid_port(attr['or_port']):
-        raise IOError("'%s.or_port' was an invalid port (%s)" % (fingerprint, attr['or_port']))
-      elif not connection.is_valid_port(attr['dir_port']):
-        raise IOError("'%s.dir_port' was an invalid port (%s)" % (fingerprint, attr['dir_port']))
-      elif attr['nickname'] and not tor_tools.is_valid_nickname(attr['nickname']):
-        raise IOError("'%s.nickname' was an invalid nickname (%s)" % (fingerprint, attr['nickname']))
-      elif attr['orport6_address'] and not connection.is_valid_ipv6_address(attr['orport6_address']):
-        raise IOError("'%s.orport6_address' was an invalid IPv6 address (%s)" % (fingerprint, attr['orport6_address']))
-      elif attr['orport6_port'] and not connection.is_valid_port(attr['orport6_port']):
-        raise IOError("'%s.orport6_port' was an invalid port (%s)" % (fingerprint, attr['orport6_port']))
-
       if attr['orport6_address'] and attr['orport6_port']:
         orport_v6 = (attr['orport6_address'], int(attr['orport6_port']))
       else:
@@ -496,21 +491,6 @@ class Fallback(Directory):
     has_extrainfo = matches.get(FALLBACK_EXTRAINFO) == '1'
     orport_v6 = matches.get(FALLBACK_IPV6)
 
-    if not connection.is_valid_ipv4_address(address):
-      raise ValueError('%s has an invalid IPv4 address: %s' % (fingerprint, address))
-    elif not connection.is_valid_port(or_port):
-      raise ValueError('%s has an invalid or_port: %s' % (fingerprint, or_port))
-    elif not connection.is_valid_port(dir_port):
-      raise ValueError('%s has an invalid dir_port: %s' % (fingerprint, dir_port))
-    elif not tor_tools.is_valid_fingerprint(fingerprint):
-      raise ValueError('%s has an invalid fingerprint: %s' % (fingerprint, fingerprint))
-    elif nickname and not tor_tools.is_valid_nickname(nickname):
-      raise ValueError('%s has an invalid nickname: %s' % (fingerprint, nickname))
-    elif orport_v6 and not connection.is_valid_ipv6_address(orport_v6[0]):
-      raise ValueError('%s has an invalid IPv6 address: %s' % (fingerprint, orport_v6[0]))
-    elif orport_v6 and not connection.is_valid_port(orport_v6[1]):
-      raise ValueError('%s has an invalid ORPort for its IPv6 endpoint: %s' % (fingerprint, orport_v6[1]))
-
     return Fallback(
       address = address,
       or_port = int(or_port),
@@ -518,7 +498,7 @@ class Fallback(Directory):
       fingerprint = fingerprint,
       nickname = nickname,
       has_extrainfo = has_extrainfo,
-      orport_v6 = (orport_v6[0], int(orport_v6[1])) if orport_v6 else None,
+      orport_v6 = (orport_v6[0], orport_v6[1]) if orport_v6 else None,
     )
 
   @staticmethod
diff --git a/test/unit/directory/authority.py b/test/unit/directory/authority.py
index 431cbf52..8fd70880 100644
--- a/test/unit/directory/authority.py
+++ b/test/unit/directory/authority.py
@@ -42,9 +42,9 @@ class TestAuthority(unittest.TestCase):
 
     for attr in authority_attr:
       for value in (None, 'something else'):
-        second_authority = dict(authority_attr)
-        second_authority[attr] = value
-        self.assertNotEqual(stem.directory.Authority(**authority_attr), stem.directory.Authority(**second_authority))
+        second_authority = stem.directory.Authority(**authority_attr)
+        setattr(second_authority, attr, value)
+        self.assertNotEqual(stem.directory.Authority(**authority_attr), second_authority)
 
   def test_from_cache(self):
     authorities = stem.directory.Authority.from_cache()
diff --git a/test/unit/directory/fallback.py b/test/unit/directory/fallback.py
index 1cbff9a3..c2d66173 100644
--- a/test/unit/directory/fallback.py
+++ b/test/unit/directory/fallback.py
@@ -3,6 +3,7 @@ Unit tests for stem.directory.Fallback.
 """
 
 import io
+import re
 import tempfile
 import unittest
 
@@ -91,9 +92,9 @@ class TestFallback(unittest.TestCase):
 
     for attr in fallback_attr:
       for value in (None, 'something else'):
-        second_fallback = dict(fallback_attr)
-        second_fallback[attr] = value
-        self.assertNotEqual(stem.directory.Fallback(**fallback_attr), stem.directory.Fallback(**second_fallback))
+        second_fallback = stem.directory.Fallback(**fallback_attr)
+        setattr(second_fallback, attr, value)
+        self.assertNotEqual(stem.directory.Fallback(**fallback_attr), second_fallback)
 
   def test_from_cache(self):
     fallbacks = stem.directory.Fallback.from_cache()
@@ -155,15 +156,15 @@ class TestFallback(unittest.TestCase):
   def test_from_str_malformed(self):
     test_values = {
       FALLBACK_ENTRY.replace(b'id=0756B7CD4DFC8182BE23143FAC0642F515182CEB', b''): 'Malformed fallback address line:',
-      FALLBACK_ENTRY.replace(b'5.9.110.236', b'5.9.110'): '0756B7CD4DFC8182BE23143FAC0642F515182CEB has an invalid IPv4 address: 5.9.110',
-      FALLBACK_ENTRY.replace(b':9030', b':7814713228'): '0756B7CD4DFC8182BE23143FAC0642F515182CEB has an invalid dir_port: 7814713228',
-      FALLBACK_ENTRY.replace(b'orport=9001', b'orport=7814713228'): '0756B7CD4DFC8182BE23143FAC0642F515182CEB has an invalid or_port: 7814713228',
-      FALLBACK_ENTRY.replace(b'ipv6=[2a01', b'ipv6=[:::'): '0756B7CD4DFC8182BE23143FAC0642F515182CEB has an invalid IPv6 address: ::::4f8:162:51e2::2',
+      FALLBACK_ENTRY.replace(b'5.9.110.236', b'5.9.110'): '0756B7CD4DFC8182BE23143FAC0642F515182CEB (rueckgrat) has an invalid IPv4 address: 5.9.110',
+      FALLBACK_ENTRY.replace(b':9030', b':7814713228'): '0756B7CD4DFC8182BE23143FAC0642F515182CEB (rueckgrat) has an invalid DirPort: 7814713228',
+      FALLBACK_ENTRY.replace(b'orport=9001', b'orport=7814713228'): '0756B7CD4DFC8182BE23143FAC0642F515182CEB (rueckgrat) has an invalid ORPort: 7814713228',
+      FALLBACK_ENTRY.replace(b'ipv6=[2a01', b'ipv6=[:::'): '0756B7CD4DFC8182BE23143FAC0642F515182CEB (rueckgrat) has an invalid IPv6 address: ::::4f8:162:51e2::2',
       FALLBACK_ENTRY.replace(b'nickname=rueckgrat', b'nickname=invalid~nickname'): '0756B7CD4DFC8182BE23143FAC0642F515182CEB has an invalid nickname: invalid~nickname',
     }
 
     for entry, expected in test_values.items():
-      self.assertRaisesRegexp(ValueError, expected, stem.directory.Fallback._from_str, entry)
+      self.assertRaisesRegexp(ValueError, re.escape(expected), stem.directory.Fallback._from_str, entry)
 
   def test_persistence(self):
     expected = {





More information about the tor-commits mailing list