[tor-commits] [stem/master] Replacing manual caching with @lru_cache

atagar at torproject.org
Sat Oct 12 23:28:41 UTC 2013


commit 242569af36ebb1f30a5d2629254b55320cfd5298
Author: Damian Johnson <atagar at torproject.org>
Date:   Sat Oct 12 13:11:11 2013 -0700

    Replacing manual caching with @lru_cache
    
    Replacing instances where we do...
    
    def get_stuff(self):
      if self._stuff is None:
        self._stuff = ... calculated stuff...
    
      return self._stuff
    
    ... with a @lru_cache().
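
    As an illustration, a minimal sketch of the resulting pattern (the Example
    class and its get_stuff() body are hypothetical; the try/except fallback
    mirrors the import this commit adds):

    try:
      # lru_cache was added to functools in python 3.2
      from functools import lru_cache
    except ImportError:
      from stem.util.lru_cache import lru_cache

    class Example(object):
      @lru_cache()
      def get_stuff(self):
        # computed on the first call, then served from the cache
        return sum(range(100))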
---
 stem/control.py                         |    2 +-
 stem/descriptor/extrainfo_descriptor.py |   19 +++---
 stem/descriptor/microdescriptor.py      |   27 ++++----
 stem/descriptor/server_descriptor.py    |  107 +++++++++++++++----------------
 stem/prereq.py                          |   67 +++++++++----------
 stem/response/getinfo.py                |    2 +-
 stem/util/proc.py                       |   84 +++++++++++-------------
 stem/version.py                         |    1 +
 8 files changed, 150 insertions(+), 159 deletions(-)

diff --git a/stem/control.py b/stem/control.py
index 64eecc4..16f8892 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -802,7 +802,7 @@ class Controller(BaseController):
     try:
       response = self.msg("GETINFO %s" % " ".join(params))
       stem.response.convert("GETINFO", response)
-      response.assert_matches(params)
+      response._assert_matches(params)
 
       # usually we want unicode values under python 3.x
 
diff --git a/stem/descriptor/extrainfo_descriptor.py b/stem/descriptor/extrainfo_descriptor.py
index 7bfbd5a..0d50f75 100644
--- a/stem/descriptor/extrainfo_descriptor.py
+++ b/stem/descriptor/extrainfo_descriptor.py
@@ -82,6 +82,12 @@ from stem.descriptor import (
   _get_descriptor_components,
 )
 
+try:
+  # added in python 3.2
+  from functools import lru_cache
+except ImportError:
+  from stem.util.lru_cache import lru_cache
+
 # known statuses for dirreq-v2-resp and dirreq-v3-resp...
 DirResponse = stem.util.enum.Enum(
   ("OK", "ok"),
@@ -856,18 +862,15 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor):
 
   def __init__(self, raw_contents, validate = True):
     self.signature = None
-    self._digest = None
 
     super(RelayExtraInfoDescriptor, self).__init__(raw_contents, validate)
 
+  @lru_cache()
   def digest(self):
-    if self._digest is None:
-      # our digest is calculated from everything except our signature
-      raw_content, ending = str(self), "\nrouter-signature\n"
-      raw_content = raw_content[:raw_content.find(ending) + len(ending)]
-      self._digest = hashlib.sha1(stem.util.str_tools._to_bytes(raw_content)).hexdigest().upper()
-
-    return self._digest
+    # our digest is calculated from everything except our signature
+    raw_content, ending = str(self), "\nrouter-signature\n"
+    raw_content = raw_content[:raw_content.find(ending) + len(ending)]
+    return hashlib.sha1(stem.util.str_tools._to_bytes(raw_content)).hexdigest().upper()
 
   def _parse(self, entries, validate):
     entries = dict(entries)  # shallow copy since we're destructive
diff --git a/stem/descriptor/microdescriptor.py b/stem/descriptor/microdescriptor.py
index afcd777..c48419a 100644
--- a/stem/descriptor/microdescriptor.py
+++ b/stem/descriptor/microdescriptor.py
@@ -75,6 +75,12 @@ from stem.descriptor import (
   _read_until_keywords,
 )
 
+try:
+  # added in python 3.2
+  from functools import lru_cache
+except ImportError:
+  from stem.util.lru_cache import lru_cache
+
 REQUIRED_FIELDS = (
   "onion-key",
 )
@@ -178,7 +184,6 @@ class Microdescriptor(Descriptor):
     self._unrecognized_lines = []
 
     self._annotation_lines = annotations if annotations else []
-    self._annotation_dict = None  # cached breakdown of key/value mappings
 
     entries = _get_descriptor_components(raw_contents, validate)
     self._parse(entries, validate)
@@ -189,6 +194,7 @@ class Microdescriptor(Descriptor):
   def get_unrecognized_lines(self):
     return list(self._unrecognized_lines)
 
+  @lru_cache()
   def get_annotations(self):
     """
     Provides content that appeared prior to the descriptor. If this comes from
@@ -201,19 +207,16 @@ class Microdescriptor(Descriptor):
     :returns: **dict** with the key/value pairs in our annotations
     """
 
-    if self._annotation_dict is None:
-      annotation_dict = {}
-
-      for line in self._annotation_lines:
-        if b" " in line:
-          key, value = line.split(b" ", 1)
-          annotation_dict[key] = value
-        else:
-          annotation_dict[line] = None
+    annotation_dict = {}
 
-      self._annotation_dict = annotation_dict
+    for line in self._annotation_lines:
+      if b" " in line:
+        key, value = line.split(b" ", 1)
+        annotation_dict[key] = value
+      else:
+        annotation_dict[line] = None
 
-    return self._annotation_dict
+    return annotation_dict
 
   def get_annotation_lines(self):
     """
diff --git a/stem/descriptor/server_descriptor.py b/stem/descriptor/server_descriptor.py
index 8d11801..51b382f 100644
--- a/stem/descriptor/server_descriptor.py
+++ b/stem/descriptor/server_descriptor.py
@@ -52,6 +52,12 @@ from stem.descriptor import (
   _read_until_keywords,
 )
 
+try:
+  # added in python 3.2
+  from functools import lru_cache
+except ImportError:
+  from stem.util.lru_cache import lru_cache
+
 # relay descriptors must have exactly one of the following
 REQUIRED_FIELDS = (
   "router",
@@ -273,7 +279,6 @@ class ServerDescriptor(Descriptor):
     self._unrecognized_lines = []
 
     self._annotation_lines = annotations if annotations else []
-    self._annotation_dict = None  # cached breakdown of key/value mappings
 
     # A descriptor contains a series of 'keyword lines' which are simply a
     # keyword followed by an optional value. Lines can also be followed by a
@@ -308,6 +313,7 @@ class ServerDescriptor(Descriptor):
   def get_unrecognized_lines(self):
     return list(self._unrecognized_lines)
 
+  @lru_cache()
   def get_annotations(self):
     """
     Provides content that appeared prior to the descriptor. If this comes from
@@ -321,19 +327,16 @@ class ServerDescriptor(Descriptor):
     :returns: **dict** with the key/value pairs in our annotations
     """
 
-    if self._annotation_dict is None:
-      annotation_dict = {}
-
-      for line in self._annotation_lines:
-        if b" " in line:
-          key, value = line.split(b" ", 1)
-          annotation_dict[key] = value
-        else:
-          annotation_dict[line] = None
+    annotation_dict = {}
 
-      self._annotation_dict = annotation_dict
+    for line in self._annotation_lines:
+      if b" " in line:
+        key, value = line.split(b" ", 1)
+        annotation_dict[key] = value
+      else:
+        annotation_dict[line] = None
 
-    return self._annotation_dict
+    return annotation_dict
 
   def get_annotation_lines(self):
     """
@@ -654,7 +657,6 @@ class RelayDescriptor(ServerDescriptor):
     self.ntor_onion_key = None
     self.signing_key = None
     self.signature = None
-    self._digest = None
 
     super(RelayDescriptor, self).__init__(raw_contents, validate, annotations)
 
@@ -662,6 +664,7 @@ class RelayDescriptor(ServerDescriptor):
     if validate:
       self._validate_content()
 
+  @lru_cache()
   def digest(self):
     """
     Provides the digest of our descriptor's content.
@@ -671,25 +674,22 @@ class RelayDescriptor(ServerDescriptor):
     :raises: ValueError if the digest cannot be calculated
     """
 
-    if self._digest is None:
-      # Digest is calculated from everything in the
-      # descriptor except the router-signature.
-
-      raw_descriptor = self.get_bytes()
-      start_token = b"router "
-      sig_token = b"\nrouter-signature\n"
-      start = raw_descriptor.find(start_token)
-      sig_start = raw_descriptor.find(sig_token)
-      end = sig_start + len(sig_token)
-
-      if start >= 0 and sig_start > 0 and end > start:
-        for_digest = raw_descriptor[start:end]
-        digest_hash = hashlib.sha1(stem.util.str_tools._to_bytes(for_digest))
-        self._digest = stem.util.str_tools._to_unicode(digest_hash.hexdigest().upper())
-      else:
-        raise ValueError("unable to calculate digest for descriptor")
+    # Digest is calculated from everything in the
+    # descriptor except the router-signature.
 
-    return self._digest
+    raw_descriptor = self.get_bytes()
+    start_token = b"router "
+    sig_token = b"\nrouter-signature\n"
+    start = raw_descriptor.find(start_token)
+    sig_start = raw_descriptor.find(sig_token)
+    end = sig_start + len(sig_token)
+
+    if start >= 0 and sig_start > 0 and end > start:
+      for_digest = raw_descriptor[start:end]
+      digest_hash = hashlib.sha1(stem.util.str_tools._to_bytes(for_digest))
+      return stem.util.str_tools._to_unicode(digest_hash.hexdigest().upper())
+    else:
+      raise ValueError("unable to calculate digest for descriptor")
 
   def _validate_content(self):
     """
@@ -852,7 +852,6 @@ class BridgeDescriptor(ServerDescriptor):
 
   def __init__(self, raw_contents, validate = True, annotations = None):
     self._digest = None
-    self._scrubbing_issues = None
 
     super(BridgeDescriptor, self).__init__(raw_contents, validate, annotations)
 
@@ -889,6 +888,7 @@ class BridgeDescriptor(ServerDescriptor):
 
     return self.get_scrubbing_issues() == []
 
+  @lru_cache()
   def get_scrubbing_issues(self):
     """
     Provides issues with our scrubbing.
@@ -897,34 +897,31 @@ class BridgeDescriptor(ServerDescriptor):
       scrubbing, this list is empty if we're properly scrubbed
     """
 
-    if self._scrubbing_issues is None:
-      issues = []
-
-      if not self.address.startswith("10."):
-        issues.append("Router line's address should be scrubbed to be '10.x.x.x': %s" % self.address)
+    issues = []
 
-      if self.contact and self.contact != "somebody":
-        issues.append("Contact line should be scrubbed to be 'somebody', but instead had '%s'" % self.contact)
+    if not self.address.startswith("10."):
+      issues.append("Router line's address should be scrubbed to be '10.x.x.x': %s" % self.address)
 
-      for address, _, is_ipv6 in self.or_addresses:
-        if not is_ipv6 and not address.startswith("10."):
-          issues.append("or-address line's address should be scrubbed to be '10.x.x.x': %s" % address)
-        elif is_ipv6 and not address.startswith("fd9f:2e19:3bcf::"):
-          # TODO: this check isn't quite right because we aren't checking that
-          # the next grouping of hex digits contains 1-2 digits
-          issues.append("or-address line's address should be scrubbed to be 'fd9f:2e19:3bcf::xx:xxxx': %s" % address)
+    if self.contact and self.contact != "somebody":
+      issues.append("Contact line should be scrubbed to be 'somebody', but instead had '%s'" % self.contact)
 
-      for line in self.get_unrecognized_lines():
-        if line.startswith("onion-key "):
-          issues.append("Bridge descriptors should have their onion-key scrubbed: %s" % line)
-        elif line.startswith("signing-key "):
-          issues.append("Bridge descriptors should have their signing-key scrubbed: %s" % line)
-        elif line.startswith("router-signature "):
-          issues.append("Bridge descriptors should have their signature scrubbed: %s" % line)
+    for address, _, is_ipv6 in self.or_addresses:
+      if not is_ipv6 and not address.startswith("10."):
+        issues.append("or-address line's address should be scrubbed to be '10.x.x.x': %s" % address)
+      elif is_ipv6 and not address.startswith("fd9f:2e19:3bcf::"):
+        # TODO: this check isn't quite right because we aren't checking that
+        # the next grouping of hex digits contains 1-2 digits
+        issues.append("or-address line's address should be scrubbed to be 'fd9f:2e19:3bcf::xx:xxxx': %s" % address)
 
-      self._scrubbing_issues = issues
+    for line in self.get_unrecognized_lines():
+      if line.startswith("onion-key "):
+        issues.append("Bridge descriptors should have their onion-key scrubbed: %s" % line)
+      elif line.startswith("signing-key "):
+        issues.append("Bridge descriptors should have their signing-key scrubbed: %s" % line)
+      elif line.startswith("router-signature "):
+        issues.append("Bridge descriptors should have their signature scrubbed: %s" % line)
 
-    return self._scrubbing_issues
+    return issues
 
   def _required_fields(self):
     # bridge required fields are the same as a relay descriptor, minus items
diff --git a/stem/prereq.py b/stem/prereq.py
index 097847f..d088480 100644
--- a/stem/prereq.py
+++ b/stem/prereq.py
@@ -22,8 +22,13 @@ Checks for stem dependencies. We require python 2.6 or greater (including the
 import inspect
 import sys
 
-IS_CRYPTO_AVAILABLE = None
-IS_MOCK_AVAILABLE = None
+try:
+  # added in python 3.2
+  from functools import lru_cache
+except ImportError:
+  from stem.util.lru_cache import lru_cache
+
+CRYPTO_UNAVAILABLE = "Unable to import the pycrypto module. Because of this we'll be unable to verify descriptor signature integrity. You can get pycrypto from: https://www.dlitz.net/software/pycrypto/"
 
 
 def check_requirements():
@@ -62,33 +67,28 @@ def is_python_3():
   return sys.version_info[0] == 3
 
 
+@lru_cache()
 def is_crypto_available():
   """
-  Checks if the pycrypto functions we use are available.
+  Checks if the pycrypto functions we use are available. This is used for
+  verifying relay descriptor signatures.
 
   :returns: **True** if we can use pycrypto and **False** otherwise
   """
 
-  global IS_CRYPTO_AVAILABLE
-
-  if IS_CRYPTO_AVAILABLE is None:
-    from stem.util import log
-
-    try:
-      from Crypto.PublicKey import RSA
-      from Crypto.Util import asn1
-      from Crypto.Util.number import long_to_bytes
-      IS_CRYPTO_AVAILABLE = True
-    except ImportError:
-      IS_CRYPTO_AVAILABLE = False
+  from stem.util import log
 
-      # the code that verifies relay descriptor signatures uses the python-crypto library
-      msg = "Unable to import the pycrypto module. Because of this we'll be unable to verify descriptor signature integrity. You can get pycrypto from: https://www.dlitz.net/software/pycrypto/"
-      log.log_once("stem.prereq.is_crypto_available", log.INFO, msg)
-
-  return IS_CRYPTO_AVAILABLE
+  try:
+    from Crypto.PublicKey import RSA
+    from Crypto.Util import asn1
+    from Crypto.Util.number import long_to_bytes
+    return True
+  except ImportError:
+    log.log_once("stem.prereq.is_crypto_available", log.INFO, CRYPTO_UNAVAILABLE)
+    return False
 
 
+@lru_cache()
 def is_mock_available():
   """
   Checks if the mock module is available.
@@ -96,24 +96,19 @@ def is_mock_available():
   :returns: **True** if the mock module is available and **False** otherwise
   """
 
-  global IS_MOCK_AVAILABLE
-
-  if IS_MOCK_AVAILABLE is None:
-    try:
-      import mock
-
-      # check for mock's patch.dict() which was introduced in version 0.7.0
+  try:
+    import mock
 
-      if not hasattr(mock.patch, 'dict'):
-        raise ImportError()
+    # check for mock's patch.dict() which was introduced in version 0.7.0
 
-      # check for mock's new_callable argument for patch() which was introduced in version 0.8.0
+    if not hasattr(mock.patch, 'dict'):
+      raise ImportError()
 
-      if not 'new_callable' in inspect.getargspec(mock.patch).args:
-        raise ImportError()
+    # check for mock's new_callable argument for patch() which was introduced in version 0.8.0
 
-      IS_MOCK_AVAILABLE = True
-    except ImportError:
-      IS_MOCK_AVAILABLE = False
+    if not 'new_callable' in inspect.getargspec(mock.patch).args:
+      raise ImportError()
 
-  return IS_MOCK_AVAILABLE
+    return True
+  except ImportError:
+    return False
diff --git a/stem/response/getinfo.py b/stem/response/getinfo.py
index 48c14dd..c613ded 100644
--- a/stem/response/getinfo.py
+++ b/stem/response/getinfo.py
@@ -59,7 +59,7 @@ class GetInfoResponse(stem.response.ControlMessage):
 
       self.entries[key] = value
 
-  def assert_matches(self, params):
+  def _assert_matches(self, params):
     """
     Checks if we match a given set of parameters, and raise a ProtocolError if not.
 
diff --git a/stem/util/proc.py b/stem/util/proc.py
index 401c319..fbff14c 100644
--- a/stem/util/proc.py
+++ b/stem/util/proc.py
@@ -53,15 +53,17 @@ import stem.util.enum
 
 from stem.util import log
 
-# cached system values
-IS_PROC_AVAILABLE, SYS_START_TIME, SYS_PHYSICAL_MEMORY = None, None, None
-CLOCK_TICKS = None
+try:
+  # added in python 3.2
+  from functools import lru_cache
+except ImportError:
+  from stem.util.lru_cache import lru_cache
 
 # os.sysconf is only defined on unix
 try:
   CLOCK_TICKS = os.sysconf(os.sysconf_names["SC_CLK_TCK"])
 except AttributeError:
-  pass
+  CLOCK_TICKS = None
 
 Stat = stem.util.enum.Enum(
   ("COMMAND", "command"), ("CPU_UTIME", "utime"),
@@ -69,6 +71,7 @@ Stat = stem.util.enum.Enum(
 )
 
 
+@lru_cache()
 def is_available():
   """
   Checks if proc information is available on this platform.
@@ -76,26 +79,20 @@ def is_available():
   :returns: **True** if proc contents exist on this platform, **False** otherwise
   """
 
-  global IS_PROC_AVAILABLE
-
-  if IS_PROC_AVAILABLE is None:
-    if platform.system() != "Linux":
-      IS_PROC_AVAILABLE = False
-    else:
-      # list of process independent proc paths we use
-      proc_paths = ("/proc/stat", "/proc/meminfo", "/proc/net/tcp", "/proc/net/udp")
-      proc_paths_exist = True
-
-      for path in proc_paths:
-        if not os.path.exists(path):
-          proc_paths_exist = False
-          break
+  if platform.system() != "Linux":
+    return False
+  else:
+    # list of process independent proc paths we use
+    proc_paths = ("/proc/stat", "/proc/meminfo", "/proc/net/tcp", "/proc/net/udp")
 
-      IS_PROC_AVAILABLE = proc_paths_exist
+    for path in proc_paths:
+      if not os.path.exists(path):
+        return False
 
-  return IS_PROC_AVAILABLE
+    return True
 
 
+@lru_cache()
 def get_system_start_time():
   """
   Provides the unix time (seconds since epoch) when the system started.
@@ -105,22 +102,20 @@ def get_system_start_time():
   :raises: **IOError** if it can't be determined
   """
 
-  global SYS_START_TIME
-  if not SYS_START_TIME:
-    start_time, parameter = time.time(), "system start time"
-    btime_line = _get_line("/proc/stat", "btime", parameter)
-
-    try:
-      SYS_START_TIME = float(btime_line.strip().split()[1])
-      _log_runtime(parameter, "/proc/stat[btime]", start_time)
-    except:
-      exc = IOError("unable to parse the /proc/stat btime entry: %s" % btime_line)
-      _log_failure(parameter, exc)
-      raise exc
+  start_time, parameter = time.time(), "system start time"
+  btime_line = _get_line("/proc/stat", "btime", parameter)
 
-  return SYS_START_TIME
+  try:
+    result = float(btime_line.strip().split()[1])
+    _log_runtime(parameter, "/proc/stat[btime]", start_time)
+    return result
+  except:
+    exc = IOError("unable to parse the /proc/stat btime entry: %s" % btime_line)
+    _log_failure(parameter, exc)
+    raise exc
 
 
+@lru_cache()
 def get_physical_memory():
   """
   Provides the total physical memory on the system in bytes.
@@ -130,20 +125,17 @@ def get_physical_memory():
   :raises: **IOError** if it can't be determined
   """
 
-  global SYS_PHYSICAL_MEMORY
-  if not SYS_PHYSICAL_MEMORY:
-    start_time, parameter = time.time(), "system physical memory"
-    mem_total_line = _get_line("/proc/meminfo", "MemTotal:", parameter)
+  start_time, parameter = time.time(), "system physical memory"
+  mem_total_line = _get_line("/proc/meminfo", "MemTotal:", parameter)
 
-    try:
-      SYS_PHYSICAL_MEMORY = int(mem_total_line.split()[1]) * 1024
-      _log_runtime(parameter, "/proc/meminfo[MemTotal]", start_time)
-    except:
-      exc = IOError("unable to parse the /proc/meminfo MemTotal entry: %s" % mem_total_line)
-      _log_failure(parameter, exc)
-      raise exc
-
-  return SYS_PHYSICAL_MEMORY
+  try:
+    result = int(mem_total_line.split()[1]) * 1024
+    _log_runtime(parameter, "/proc/meminfo[MemTotal]", start_time)
+    return result
+  except:
+    exc = IOError("unable to parse the /proc/meminfo MemTotal entry: %s" % mem_total_line)
+    _log_failure(parameter, exc)
+    raise exc
 
 
 def get_cwd(pid):
diff --git a/stem/version.py b/stem/version.py
index 68faa36..bf41898 100644
--- a/stem/version.py
+++ b/stem/version.py
@@ -247,6 +247,7 @@ class Version(object):
 
     return self._compare(other, lambda s, o: s >= o)
 
+  @lru_cache()
   def __hash__(self):
     my_hash = 0
 



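For reference, a brief usage sketch of the now-cached prereq checks (the
function names come from stem/prereq.py above; the comments just describe
standard @lru_cache() behaviour):

  import stem.prereq

  stem.prereq.is_crypto_available()  # first call performs the import checks
  stem.prereq.is_crypto_available()  # repeat calls return the cached result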