[tor-commits] [stem/master] Reducing ExitPolicyRule's memory requirements

atagar at torproject.org atagar at torproject.org
Fri Dec 7 06:54:07 UTC 2012


commit 45403096a550186cfbecc035cd54c125e3b0237c
Author: Damian Johnson <atagar at torproject.org>
Date:   Thu Dec 6 22:44:17 2012 -0800

    Reducing ExitPolicyRule's memory requirements
    
    Frequently there are several ExitPolicyRule entries per exit policy, so when
    we pull the consensus we get quite a few instances of this class. Making the
    following changes to reduce the memory requirements...
    
    * Dropping the rule attribute. The string representation of the rule should be
      good enough, if not better for callers.
    
    * Replacing the address_type attribute with a method for getting it. This lets
      us store the address type as an integer within our class.
    
    * Replacing the mask attribute with a method for getting it. The ip mask
      representation is very rarely useful, so there's little reason to store it
      unless it's requested.
    
    * Lazily loading the integer representation of our address and mask, both
      speeding up our constructor and avoiding the conversion entirely if our
      caller never uses is_match()
    
    This lowers the memory requirement for loading the full consensus on my netbook
    from 5.5% to 5.1% (a 7% drop). This is a drop in the bucket compared to the
    prior commit, but between the faster constructor runtime and squeezing out a
    little more performance it's still worth it.
---
 stem/exit_policy.py             |  156 +++++++++++++++++++++++++++-----------
 stem/util/connection.py         |    4 +
 test/unit/exit_policy/policy.py |    9 +--
 test/unit/exit_policy/rule.py   |   14 ++--
 4 files changed, 124 insertions(+), 59 deletions(-)

diff --git a/stem/exit_policy.py b/stem/exit_policy.py
index a6ae028..f9f0eb8 100644
--- a/stem/exit_policy.py
+++ b/stem/exit_policy.py
@@ -33,7 +33,9 @@ exiting to a destination is permissible or not. For instance...
   ExitPolicyRule - Single rule of an exit policy chain
     |- is_address_wildcard - checks if we'll accept any address
     |- is_port_wildcard - checks if we'll accept any port
+    |- get_address_type - provides the protocol our ip address belongs to
     |- is_match - checks if we match a given destination
+    |- get_mask - provides the address representation of our mask
     +- __str__ - string representation for this rule
 
 .. data:: AddressType (enum)
@@ -309,6 +311,7 @@ class MicrodescriptorExitPolicy(ExitPolicy):
         raise ValueError(exc_msg)
     
     super(MicrodescriptorExitPolicy, self).__init__(*rules)
+    self.set_default_allowed(not self.is_accept)
   
   def __str__(self):
     return self._policy
@@ -334,12 +337,9 @@ class ExitPolicyRule(object):
   
   This should be treated as an immutable object.
   
-  :var str rule: rule that we were originally created from
   :var bool is_accept: indicates if exiting is allowed or disallowed
   
-  :var stem.exit_policy.AddressType address_type: type of address that we have
   :var str address: address that this rule is for
-  :var str mask: subnet mask for the address (ex. "255.255.255.0")
   :var int masked_bits: number of bits the subnet mask represents, **None** if
     the mask can't have a bit representation
   
@@ -352,8 +352,6 @@ class ExitPolicyRule(object):
   """
   
   def __init__(self, rule):
-    self.rule = rule
-    
     # policy ::= "accept" exitpattern | "reject" exitpattern
     # exitpattern ::= addrspec ":" portspec
     
@@ -375,24 +373,27 @@ class ExitPolicyRule(object):
       raise ValueError("An exitpattern must be of the form 'addrspec:portspec': %s" % rule)
     
     self.address = None
-    self.address_type = None
-    self.mask = self.masked_bits = None
+    self._address_type = None
+    self.masked_bits = None
     self.min_port = self.max_port = None
     
+    # Our mask in ip notation (ex. "255.255.255.0"). This is only set if we
+    # either have a custom mask that can't be represented by a number of bits,
+    # or the user has called get_mask(), lazily loading this.
+    
+    self._mask = None
+    
     addrspec, portspec = exitpattern.rsplit(":", 1)
-    self._apply_addrspec(addrspec)
-    self._apply_portspec(portspec)
+    self._apply_addrspec(rule, addrspec)
+    self._apply_portspec(rule, portspec)
     
-    # Pre-calculating the integer representation of our mask and masked
-    # address. These are used by our is_match() method to compare ourselves to
+    # The integer representation of our mask and masked address. These are
+    # lazily loaded and used by our is_match() method to compare ourselves to
     # other addresses.
     
-    if self.is_address_wildcard():
-      # is_match() will short circuit so these are unused
-      self._mask_bin = self._addr_bin = None
-    else:
-      self._mask_bin = int(stem.util.connection.get_address_binary(self.mask), 2)
-      self._addr_bin = int(stem.util.connection.get_address_binary(self.address), 2) & self._mask_bin
+    self._mask_bin = self._addr_bin = None
+    
+    # Lazily loaded string representation of our policy.
     
     self._str_representation = None
   
@@ -406,7 +407,7 @@ class ExitPolicyRule(object):
     :returns: **bool** for if our address matching is a wildcard
     """
     
-    return self.address_type == AddressType.WILDCARD
+    return self._address_type == _address_type_to_int(AddressType.WILDCARD)
   
   def is_port_wildcard(self):
     """
@@ -432,10 +433,12 @@ class ExitPolicyRule(object):
     
     # validate our input and check if the argument doesn't match our address type
     if address is not None:
+      address_type = self.get_address_type()
+      
       if stem.util.connection.is_valid_ip_address(address):
-        if self.address_type == AddressType.IPv6: return False
+        if address_type == AddressType.IPv6: return False
       elif stem.util.connection.is_valid_ipv6_address(address, allow_brackets = True):
-        if self.address_type == AddressType.IPv4: return False
+        if address_type == AddressType.IPv4: return False
         
         address = address.lstrip("[").rstrip("]")
       else:
@@ -453,8 +456,8 @@ class ExitPolicyRule(object):
         return False
       else:
         comparison_addr_bin = int(stem.util.connection.get_address_binary(address), 2)
-        comparison_addr_bin &= self._mask_bin
-        if self._addr_bin != comparison_addr_bin: return False
+        comparison_addr_bin &= self._get_mask_bin()
+        if self._get_address_bin() != comparison_addr_bin: return False
     
     if not self.is_port_wildcard():
       if port is None:
@@ -464,6 +467,43 @@ class ExitPolicyRule(object):
     
     return True
   
+  def get_address_type(self):
+    """
+    Provides the :data:`~stem.exit_policy.AddressType` for our policy.
+    
+    :returns: :data:`~stem.exit_policy.AddressType` for the type of address that we have
+    """
+    
+    return _int_to_address_type(self._address_type)
+  
+  def get_mask(self, cache = True):
+    """
+    Provides the address represented by our mask. This is **None** if our
+    address type is a wildcard.
+    
+    :param bool cache: caches the result if **True**
+    
+    :returns: str of our subnet mask for the address (ex. "255.255.255.0")
+    """
+    
+    # Lazy loading our mask because it is very infrequently requested. There's
+    # usually no reason to use memory for it.
+    
+    address_type = self.get_address_type()
+    
+    if not self._mask:
+      if address_type == AddressType.WILDCARD:
+        mask = None
+      elif address_type == AddressType.IPv4:
+        mask = stem.util.connection.get_mask(self.masked_bits)
+      elif address_type == AddressType.IPv6:
+        mask = stem.util.connection.get_mask_ipv6(self.masked_bits)
+      
+      if not cache: return mask
+      self._mask = mask
+    
+    return self._mask
+  
   def __str__(self):
     """
     Provides the string representation of our policy. This does not
@@ -479,7 +519,9 @@ class ExitPolicyRule(object):
       if self.is_address_wildcard():
         label += "*:"
       else:
-        if self.address_type == AddressType.IPv4:
+        address_type = self.get_address_type()
+        
+        if address_type == AddressType.IPv4:
           label += self.address
         else:
           label += "[%s]" % self.address
@@ -489,12 +531,13 @@ class ExitPolicyRule(object):
         # - use our masked bit count if we can
         # - use the mask itself otherwise
         
-        if self.mask in (stem.util.connection.FULL_IPv4_MASK, stem.util.connection.FULL_IPv6_MASK):
+        if (address_type == AddressType.IPv4 and self.masked_bits == 32) or \
+           (address_type == AddressType.IPv6 and self.masked_bits == 128):
           label += ":"
-        elif not self.masked_bits is None:
+        elif self.masked_bits is not None:
           label += "/%i:" % self.masked_bits
         else:
-          label += "/%s:" % self.mask
+          label += "/%s:" % self.get_mask()
       
       if self.is_port_wildcard():
         label += "*"
@@ -507,7 +550,23 @@ class ExitPolicyRule(object):
     
     return self._str_representation
   
-  def _apply_addrspec(self, addrspec):
+  def _get_mask_bin(self):
+    # provides an integer representation of our mask
+    
+    if self._mask_bin is None:
+      self._mask_bin = int(stem.util.connection.get_address_binary(self.get_mask(False)), 2)
+    
+    return self._mask_bin
+  
+  def _get_address_bin(self):
+    # provides an integer representation of our address
+    
+    if self._addr_bin is None:
+      self._addr_bin = int(stem.util.connection.get_address_binary(self.address), 2) & self._mask_bin
+    
+    return self._addr_bin
+  
+  def _apply_addrspec(self, rule, addrspec):
     # Parses the addrspec...
     # addrspec ::= "*" | ip4spec | ip6spec
     
@@ -517,34 +576,34 @@ class ExitPolicyRule(object):
       self.address, addr_extra = addrspec, None
     
     if addrspec == "*":
-      self.address_type = AddressType.WILDCARD
-      self.address = self.mask = self.masked_bits = None
+      self._address_type = _address_type_to_int(AddressType.WILDCARD)
+      self.address = self.masked_bits = None
     elif stem.util.connection.is_valid_ip_address(self.address):
       # ipv4spec ::= ip4 | ip4 "/" num_ip4_bits | ip4 "/" ip4mask
       # ip4 ::= an IPv4 address in dotted-quad format
       # ip4mask ::= an IPv4 mask in dotted-quad format
       # num_ip4_bits ::= an integer between 0 and 32
       
-      self.address_type = AddressType.IPv4
+      self._address_type = _address_type_to_int(AddressType.IPv4)
       
       if addr_extra is None:
-        self.mask = stem.util.connection.FULL_IPv4_MASK
         self.masked_bits = 32
       elif stem.util.connection.is_valid_ip_address(addr_extra):
         # provided with an ip4mask
-        self.mask = addr_extra
-        
         try:
           self.masked_bits = stem.util.connection.get_masked_bits(addr_extra)
         except ValueError:
           # mask can't be represented as a number of bits (ex. "255.255.0.255")
+          self._mask = addr_extra
           self.masked_bits = None
       elif addr_extra.isdigit():
         # provided with a num_ip4_bits
-        self.mask = stem.util.connection.get_mask(int(addr_extra))
         self.masked_bits = int(addr_extra)
+        
+        if self.masked_bits < 0 or self.masked_bits > 32:
+          raise ValueError("IPv4 masks must be in the range of 0-32 bits")
       else:
-        raise ValueError("The '%s' isn't a mask nor number of bits: %s" % (addr_extra, self.rule))
+        raise ValueError("The '%s' isn't a mask nor number of bits: %s" % (addr_extra, rule))
     elif self.address.startswith("[") and self.address.endswith("]") and \
       stem.util.connection.is_valid_ipv6_address(self.address[1:-1]):
       # ip6spec ::= ip6 | ip6 "/" num_ip6_bits
@@ -552,21 +611,22 @@ class ExitPolicyRule(object):
       # num_ip6_bits ::= an integer between 0 and 128
       
       self.address = stem.util.connection.expand_ipv6_address(self.address[1:-1].upper())
-      self.address_type = AddressType.IPv6
+      self._address_type = _address_type_to_int(AddressType.IPv6)
       
       if addr_extra is None:
-        self.mask = stem.util.connection.FULL_IPv6_MASK
         self.masked_bits = 128
       elif addr_extra.isdigit():
         # provided with a num_ip6_bits
-        self.mask = stem.util.connection.get_mask_ipv6(int(addr_extra))
         self.masked_bits = int(addr_extra)
+        
+        if self.masked_bits < 0 or self.masked_bits > 128:
+          raise ValueError("IPv6 masks must be in the range of 0-128 bits")
       else:
-        raise ValueError("The '%s' isn't a number of bits: %s" % (addr_extra, self.rule))
+        raise ValueError("The '%s' isn't a number of bits: %s" % (addr_extra, rule))
     else:
-      raise ValueError("Address isn't a wildcard, IPv4, or IPv6 address: %s" % self.rule)
+      raise ValueError("Address isn't a wildcard, IPv4, or IPv6 address: %s" % rule)
   
-  def _apply_portspec(self, portspec):
+  def _apply_portspec(self, rule, portspec):
     # Parses the portspec...
     # portspec ::= "*" | port | port "-" port
     # port ::= an integer between 1 and 65535, inclusive.
@@ -581,7 +641,7 @@ class ExitPolicyRule(object):
       if stem.util.connection.is_valid_port(portspec, allow_zero = True):
         self.min_port = self.max_port = int(portspec)
       else:
-        raise ValueError("'%s' isn't within a valid port range: %s" % (portspec, self.rule))
+        raise ValueError("'%s' isn't within a valid port range: %s" % (portspec, rule))
     elif "-" in portspec:
       # provided with a port range
       port_comp = portspec.split("-", 1)
@@ -591,11 +651,11 @@ class ExitPolicyRule(object):
         self.max_port = int(port_comp[1])
         
         if self.min_port > self.max_port:
-          raise ValueError("Port range has a lower bound that's greater than its upper bound: %s" % self.rule)
+          raise ValueError("Port range has a lower bound that's greater than its upper bound: %s" % rule)
       else:
-        raise ValueError("Malformed port range: %s" % self.rule)
+        raise ValueError("Malformed port range: %s" % rule)
     else:
-      raise ValueError("Port value isn't a wildcard, integer, or range: %s" % self.rule)
+      raise ValueError("Port value isn't a wildcard, integer, or range: %s" % rule)
   
   def __eq__(self, other):
     if isinstance(other, ExitPolicyRule):
@@ -608,3 +668,9 @@ class ExitPolicyRule(object):
     else:
       return False
 
+def _address_type_to_int(address_type):
+  return AddressType.index_of(address_type)
+
+def _int_to_address_type(address_type_int):
+  return AddressType[AddressType.keys()[address_type_int]]
+
diff --git a/stem/util/connection.py b/stem/util/connection.py
index a21cde8..0417c87 100644
--- a/stem/util/connection.py
+++ b/stem/util/connection.py
@@ -165,6 +165,8 @@ def get_mask(bits):
   
   if bits > 32 or bits < 0:
     raise ValueError("A mask can only be 0-32 bits, got %i" % bits)
+  elif bits == 32:
+    return FULL_IPv4_MASK
   
   # get the binary representation of the mask
   mask_bin = get_binary(2 ** bits - 1, 32)[::-1]
@@ -213,6 +215,8 @@ def get_mask_ipv6(bits):
   
   if bits > 128 or bits < 0:
     raise ValueError("A mask can only be 0-128 bits, got %i" % bits)
+  elif bits == 128:
+    return FULL_IPv6_MASK
   
   # get the binary representation of the mask
   mask_bin = get_binary(2 ** bits - 1, 128)[::-1]
diff --git a/test/unit/exit_policy/policy.py b/test/unit/exit_policy/policy.py
index 0714ef9..67fbf68 100644
--- a/test/unit/exit_policy/policy.py
+++ b/test/unit/exit_policy/policy.py
@@ -154,22 +154,19 @@ class TestExitPolicy(unittest.TestCase):
         if expect_success: self.fail()
   
   def test_microdescriptor_attributes(self):
-    # checks that its is_accept and ports attributes are properly set
+    # checks that its is_accept attribute is properly set
     
     # single port
     policy = MicrodescriptorExitPolicy('accept 443')
     self.assertTrue(policy.is_accept)
-    self.assertEquals(set([443]), policy.ports)
     
     # multiple ports
     policy = MicrodescriptorExitPolicy('accept 80,443')
     self.assertTrue(policy.is_accept)
-    self.assertEquals(set([80, 443]), policy.ports)
     
     # port range
     policy = MicrodescriptorExitPolicy('reject 1-1024')
     self.assertFalse(policy.is_accept)
-    self.assertEquals(set(range(1, 1025)), policy.ports)
   
   def test_microdescriptor_can_exit_to(self):
     test_inputs = {
@@ -188,6 +185,6 @@ class TestExitPolicy(unittest.TestCase):
     # address argument should be ignored
     policy = MicrodescriptorExitPolicy('accept 80,443')
     
-    self.assertFalse(policy.can_exit_to('blah', 79))
-    self.assertTrue(policy.can_exit_to('blah', 80))
+    self.assertFalse(policy.can_exit_to('127.0.0.1', 79))
+    self.assertTrue(policy.can_exit_to('127.0.0.1', 80))
 
diff --git a/test/unit/exit_policy/rule.py b/test/unit/exit_policy/rule.py
index 55e0c76..994b899 100644
--- a/test/unit/exit_policy/rule.py
+++ b/test/unit/exit_policy/rule.py
@@ -46,7 +46,6 @@ class TestExitPolicyRule(unittest.TestCase):
     
     for rule_arg in test_inputs:
       rule = ExitPolicyRule(rule_arg)
-      self.assertEquals(rule_arg, rule.rule)
       self.assertEquals(rule_arg, str(rule))
   
   def test_str_changed(self):
@@ -60,7 +59,6 @@ class TestExitPolicyRule(unittest.TestCase):
     
     for rule_arg, expected_str in test_inputs.items():
       rule = ExitPolicyRule(rule_arg)
-      self.assertEquals(rule_arg, rule.rule)
       self.assertEquals(expected_str, str(rule))
   
   def test_valid_wildcard(self):
@@ -103,9 +101,9 @@ class TestExitPolicyRule(unittest.TestCase):
   
   def test_wildcard_attributes(self):
     rule = ExitPolicyRule("reject *:*")
-    self.assertEquals(AddressType.WILDCARD, rule.address_type)
+    self.assertEquals(AddressType.WILDCARD, rule.get_address_type())
     self.assertEquals(None, rule.address)
-    self.assertEquals(None, rule.mask)
+    self.assertEquals(None, rule.get_mask())
     self.assertEquals(None, rule.masked_bits)
     self.assertEquals(1, rule.min_port)
     self.assertEquals(65535, rule.max_port)
@@ -122,9 +120,9 @@ class TestExitPolicyRule(unittest.TestCase):
       address, mask, masked_bits = attr
       
       rule = ExitPolicyRule("accept %s:*" % rule_addr)
-      self.assertEquals(AddressType.IPv4, rule.address_type)
+      self.assertEquals(AddressType.IPv4, rule.get_address_type())
       self.assertEquals(address, rule.address)
-      self.assertEquals(mask, rule.mask)
+      self.assertEquals(mask, rule.get_mask())
       self.assertEquals(masked_bits, rule.masked_bits)
   
   def test_invalid_ipv4_addresses(self):
@@ -161,9 +159,9 @@ class TestExitPolicyRule(unittest.TestCase):
       address, mask, masked_bits = attr
       
       rule = ExitPolicyRule("accept %s:*" % rule_addr)
-      self.assertEquals(AddressType.IPv6, rule.address_type)
+      self.assertEquals(AddressType.IPv6, rule.get_address_type())
       self.assertEquals(address, rule.address)
-      self.assertEquals(mask, rule.mask)
+      self.assertEquals(mask, rule.get_mask())
       self.assertEquals(masked_bits, rule.masked_bits)
   
   def test_invalid_ipv6_addresses(self):



More information about the tor-commits mailing list