[tor-commits] [stem/master] stem/util/connection.py: Optimize validators.

atagar at torproject.org atagar at torproject.org
Sun Nov 23 00:13:20 UTC 2014


commit 25771febe94b0a3b3885b74222e2291618acb751
Author: Ossi Herrala <oherrala at gmail.com>
Date:   Fri Nov 21 02:07:18 2014 +0200

    stem/util/connection.py: Optimize validators.
    
     * is_valid_ipv4_address() would work without round-tripping the
       address (string -> packed -> string), but dropping that check
       breaks one unit test (IP '0.0.00.0').
     * Add one missing test case for is_valid_port().
---
 stem/util/connection.py      |   69 ++++++++++++++----------------------------
 test/unit/util/connection.py |    2 +-
 2 files changed, 24 insertions(+), 47 deletions(-)

diff --git a/stem/util/connection.py b/stem/util/connection.py
index 09acdeb..d1fb5c2 100644
--- a/stem/util/connection.py
+++ b/stem/util/connection.py
@@ -46,6 +46,7 @@ import hmac
 import os
 import platform
 import re
+import socket
 
 import stem.util.proc
 import stem.util.system
@@ -333,23 +334,12 @@ def is_valid_ipv4_address(address):
   :returns: **True** if input is a valid IPv4 address, **False** otherwise
   """
 
-  if not isinstance(address, (bytes, unicode)):
-    return False
-
-  # checks if theres four period separated values
-
-  if address.count('.') != 3:
+  try:
+    packed = socket.inet_pton(socket.AF_INET, address)
+    return socket.inet_ntop(socket.AF_INET, packed) == address
+  except socket.error:
     return False
 
-  # checks that each value in the octet are decimal values between 0-255
-  for entry in address.split('.'):
-    if not entry.isdigit() or int(entry) < 0 or int(entry) > 255:
-      return False
-    elif entry[0] == '0' and len(entry) > 1:
-      return False  # leading zeros, for instance in '1.2.3.001'
-
-  return True
-
 
 def is_valid_ipv6_address(address, allow_brackets = False):
   """
@@ -365,24 +355,11 @@ def is_valid_ipv6_address(address, allow_brackets = False):
     if address.startswith('[') and address.endswith(']'):
       address = address[1:-1]
 
-  # addresses are made up of eight colon separated groups of four hex digits
-  # with leading zeros being optional
-  # https://en.wikipedia.org/wiki/IPv6#Address_format
-
-  colon_count = address.count(':')
-
-  if colon_count > 7:
-    return False  # too many groups
-  elif colon_count != 7 and '::' not in address:
-    return False  # not enough groups and none are collapsed
-  elif address.count('::') > 1 or ':::' in address:
-    return False  # multiple groupings of zeros can't be collapsed
-
-  for entry in address.split(':'):
-    if not re.match('^[0-9a-fA-f]{0,4}$', entry):
-      return False
-
-  return True
+  try:
+    socket.inet_pton(socket.AF_INET6, address)
+    return True
+  except socket.error:
+    return False
 
 
 def is_valid_port(entry, allow_zero = False):
@@ -395,24 +372,24 @@ def is_valid_port(entry, allow_zero = False):
   :returns: **True** if input is an integer and within the valid port range, **False** otherwise
   """
 
-  if isinstance(entry, list):
+  try:
+    value = int(entry)
+    if str(value) != str(entry):
+      return False  # invalid leading char, e.g. space or zero
+    if allow_zero:
+      return value >= 0 and value < 65536
+    else:
+      return value > 0 and value < 65536
+
+  except TypeError:
+    # Maybe entry is list to validate?
     for port in entry:
       if not is_valid_port(port, allow_zero):
         return False
-
     return True
-  elif isinstance(entry, (bytes, unicode)):
-    if not entry.isdigit():
-      return False
-    elif entry[0] == '0' and len(entry) > 1:
-      return False  # leading zeros, ex '001'
 
-    entry = int(entry)
-
-  if allow_zero and entry == 0:
-    return True
-
-  return entry > 0 and entry < 65536
+  except ValueError:
+    return False
 
 
 def is_private_address(address):
diff --git a/test/unit/util/connection.py b/test/unit/util/connection.py
index 0ae6f9b..0766c32 100644
--- a/test/unit/util/connection.py
+++ b/test/unit/util/connection.py
@@ -359,7 +359,7 @@ class TestConnection(unittest.TestCase):
     Checks the is_valid_port function.
     """
 
-    valid_ports = (1, '1', 1234, '1234', 65535, '65535')
+    valid_ports = (1, '1', 1234, '1234', 65535, '65535', [1, '2'])
     invalid_ports = (0, '0', 65536, '65536', 'abc', '*', ' 15', '01')
 
     for port in valid_ports:





More information about the tor-commits mailing list