[tor-commits] [stem/master] Bundle enums with the numbers they map to

atagar at torproject.org atagar at torproject.org
Wed Feb 7 19:44:51 UTC 2018


commit ab2dfc63394d0caa06a07ce14acc1fd0c784974e
Author: Damian Johnson <atagar at torproject.org>
Date:   Tue Jan 23 11:57:56 2018 -0800

    Bundle enums with the numbers they map to
    
    Putting this all in one place both makes usage nicer and makes adding new
    values less error prone.
---
 stem/client/__init__.py | 162 +++++++++++++++++++++++++-----------------------
 stem/client/cell.py     |  57 +----------------
 2 files changed, 86 insertions(+), 133 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index fcd0a238..a247c7ba 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -120,56 +120,93 @@ __all__ = [
   'cell',
 ]
 
-AddrType = stem.util.enum.UppercaseEnum(
-  'HOSTNAME',
-  'IPv4',
-  'IPv6',
-  'ERROR_TRANSIENT',
-  'ERROR_PERMANENT',
-  'UNKNOWN',
+
+class _IntegerEnum(stem.util.enum.Enum):
+  """
+  Integer backed enumeration. Enumerations of this type always have an implicit
+  **UNKNOWN** value for integer values that lack a mapping.
+  """
+
+  def __init__(self, *args):
+    self._enum_to_int = {}
+    self._int_to_enum = {}
+    parent_args = []
+
+    for entry in args:
+      if len(entry) == 2:
+        enum, int_val = entry
+        str_val = enum
+      elif len(entry) == 3:
+        enum, str_val, int_val = entry
+      else:
+        raise ValueError('IntegerEnums can only be constructed with two or three value tuples: %s' % repr(entry))
+
+      self._enum_to_int[str_val] = int_val  # key on the enum's value (what callers hold), not its name
+      self._int_to_enum[int_val] = str_val
+      parent_args.append((enum, str_val))
+
+    parent_args.append(('UNKNOWN', 'UNKNOWN'))
+    super(_IntegerEnum, self).__init__(*parent_args)
+
+  def get(self, val):
+    """
+    Provides the (enum, int_value) tuple for an enum value or integer, raising ValueError otherwise.
+    """
+
+    if isinstance(val, int):
+      return self._int_to_enum.get(val, self.UNKNOWN), val  # unmapped integers become UNKNOWN
+    elif val in self:
+      return val, self._enum_to_int.get(val, val)  # UNKNOWN has no integer so it maps to itself
+    else:
+      raise ValueError("Invalid enumeration '%s', options are %s" % (val, ', '.join(self)))
+
+
+AddrType = _IntegerEnum(
+  ('HOSTNAME', 0),
+  ('IPv4', 4),
+  ('IPv6', 6),
+  ('ERROR_TRANSIENT', 16),
+  ('ERROR_PERMANENT', 17),
 )
 
-RelayCommand = stem.util.enum.Enum(
-  ('BEGIN', 'RELAY_BEGIN'),
-  ('DATA', 'RELAY_DATA'),
-  ('END', 'RELAY_END'),
-  ('CONNECTED', 'RELAY_CONNECTED'),
-  ('SENDME', 'RELAY_SENDME'),
-  ('EXTEND', 'RELAY_EXTEND'),
-  ('EXTENDED', 'RELAY_EXTENDED'),
-  ('TRUNCATE', 'RELAY_TRUNCATE'),
-  ('TRUNCATED', 'RELAY_TRUNCATED'),
-  ('DROP', 'RELAY_DROP'),
-  ('RESOLVE', 'RELAY_RESOLVE'),
-  ('RESOLVED', 'RELAY_RESOLVED'),
-  ('BEGIN_DIR', 'RELAY_BEGIN_DIR'),
-  ('EXTEND2', 'RELAY_EXTEND2'),
-  ('EXTENDED2', 'RELAY_EXTENDED2'),
-  ('UNKNOWN', 'UNKNOWN'),
+RelayCommand = _IntegerEnum(
+  ('BEGIN', 'RELAY_BEGIN', 1),
+  ('DATA', 'RELAY_DATA', 2),
+  ('END', 'RELAY_END', 3),
+  ('CONNECTED', 'RELAY_CONNECTED', 4),
+  ('SENDME', 'RELAY_SENDME', 5),
+  ('EXTEND', 'RELAY_EXTEND', 6),
+  ('EXTENDED', 'RELAY_EXTENDED', 7),
+  ('TRUNCATE', 'RELAY_TRUNCATE', 8),
+  ('TRUNCATED', 'RELAY_TRUNCATED', 9),
+  ('DROP', 'RELAY_DROP', 10),
+  ('RESOLVE', 'RELAY_RESOLVE', 11),
+  ('RESOLVED', 'RELAY_RESOLVED', 12),
+  ('BEGIN_DIR', 'RELAY_BEGIN_DIR', 13),
+  ('EXTEND2', 'RELAY_EXTEND2', 14),
+  ('EXTENDED2', 'RELAY_EXTENDED2', 15),
 )
 
-CertType = stem.util.enum.UppercaseEnum(
-  'LINK',
-  'IDENTITY',
-  'AUTHENTICATE',
-  'UNKNOWN',
+CertType = _IntegerEnum(
+  ('LINK', 1),
+  ('IDENTITY', 2),
+  ('AUTHENTICATE', 3),
 )
 
-CloseReason = stem.util.enum.UppercaseEnum(
-  'NONE',
-  'PROTOCOL',
-  'INTERNAL',
-  'REQUESTED',
-  'HIBERNATING',
-  'RESOURCELIMIT',
-  'CONNECTFAILED',
-  'OR_IDENTITY',
-  'OR_CONN_CLOSED',
-  'FINISHED',
-  'TIMEOUT',
-  'DESTROYED',
-  'NOSUCHSERVICE',
-  'UNKNOWN',
+CloseReason = _IntegerEnum(
+  ('NONE', 0),
+  ('PROTOCOL', 1),
+  ('INTERNAL', 2),
+  ('REQUESTED', 3),
+  ('HIBERNATING', 4),
+  ('RESOURCELIMIT', 5),
+  ('CONNECTFAILED', 6),
+  ('OR_IDENTITY', 7),
+  ('OR_CONN_CLOSED', 8),
+  ('FINISHED', 9),
+  ('TIMEOUT', 10),
+  ('DESTROYED', 11),
+  ('NOSUCHSERVICE', 12),
 )
 
 
@@ -298,25 +335,8 @@ class Address(Field):
   :var bytes value_bin: encoded address value
   """
 
-  TYPE_FOR_INT = {
-    0: AddrType.HOSTNAME,
-    4: AddrType.IPv4,
-    6: AddrType.IPv6,
-    16: AddrType.ERROR_TRANSIENT,
-    17: AddrType.ERROR_PERMANENT,
-  }
-
-  INT_FOR_TYPE = dict((v, k) for k, v in TYPE_FOR_INT.items())
-
   def __init__(self, addr_type, value):
-    if isinstance(addr_type, int):
-      self.type = Address.TYPE_FOR_INT.get(addr_type, AddrType.UNKNOWN)
-      self.type_int = addr_type
-    elif addr_type in AddrType:
-      self.type = addr_type
-      self.type_int = Address.INT_FOR_TYPE.get(addr_type, -1)
-    else:
-      raise ValueError('Invalid address type: %s' % addr_type)
+    self.type, self.type_int = AddrType.get(addr_type)
 
     if self.type == AddrType.IPv4:
       if stem.util.connection.is_valid_ipv4_address(value):
@@ -383,24 +403,8 @@ class Certificate(Field):
   :var bytes value: certificate value
   """
 
-  TYPE_FOR_INT = {
-    1: CertType.LINK,
-    2: CertType.IDENTITY,
-    3: CertType.AUTHENTICATE,
-  }
-
-  INT_FOR_TYPE = dict((v, k) for k, v in TYPE_FOR_INT.items())
-
   def __init__(self, cert_type, value):
-    if isinstance(cert_type, int):
-      self.type = Certificate.TYPE_FOR_INT.get(cert_type, CertType.UNKNOWN)
-      self.type_int = cert_type
-    elif cert_type in CertType:
-      self.type = cert_type
-      self.type_int = Certificate.INT_FOR_TYPE.get(cert_type, -1)
-    else:
-      raise ValueError('Invalid certificate type: %s' % cert_type)
-
+    self.type, self.type_int = CertType.get(cert_type)
     self.value = value
 
   def pack(self):
diff --git a/stem/client/cell.py b/stem/client/cell.py
index 0f292c10..9813a85e 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -44,7 +44,7 @@ import random
 import sys
 
 from stem import UNDEFINED
-from stem.client import ZERO, Address, Certificate, CloseReason, RelayCommand, Size, split
+from stem.client import ZERO, Address, Certificate, CloseReason, Size, split
 from stem.util import _hash_attr, datetime_to_unix
 
 FIXED_PAYLOAD_LEN = 509
@@ -265,26 +265,6 @@ class RelayCell(CircuitCell):
   VALUE = 3
   IS_FIXED_SIZE = True
 
-  COMMAND_FOR_INT = {
-    1: RelayCommand.BEGIN,
-    2: RelayCommand.DATA,
-    3: RelayCommand.END,
-    4: RelayCommand.CONNECTED,
-    5: RelayCommand.SENDME,
-    6: RelayCommand.EXTEND,
-    7: RelayCommand.EXTENDED,
-    8: RelayCommand.TRUNCATE,
-    9: RelayCommand.TRUNCATED,
-    10: RelayCommand.DROP,
-    11: RelayCommand.RESOLVE,
-    12: RelayCommand.RESOLVED,
-    13: RelayCommand.BEGIN_DIR,
-    14: RelayCommand.EXTEND2,
-    15: RelayCommand.EXTENDED2,
-  }
-
-  INT_FOR_COMMANDS = dict((v, k) for k, v in COMMAND_FOR_INT.items())
-
 
 class DestroyCell(CircuitCell):
   """
@@ -298,35 +278,9 @@ class DestroyCell(CircuitCell):
   VALUE = 4
   IS_FIXED_SIZE = True
 
-  REASON_FOR_INT = {
-    0: CloseReason.NONE,
-    1: CloseReason.PROTOCOL,
-    2: CloseReason.INTERNAL,
-    3: CloseReason.REQUESTED,
-    4: CloseReason.HIBERNATING,
-    5: CloseReason.RESOURCELIMIT,
-    6: CloseReason.CONNECTFAILED,
-    7: CloseReason.OR_IDENTITY,
-    8: CloseReason.OR_CONN_CLOSED,
-    9: CloseReason.FINISHED,
-    10: CloseReason.TIMEOUT,
-    11: CloseReason.DESTROYED,
-    12: CloseReason.NOSUCHSERVICE,
-  }
-
-  INT_FOR_REASON = dict((v, k) for k, v in REASON_FOR_INT.items())
-
   def __init__(self, circ_id, reason):
     super(DestroyCell, self).__init__(circ_id)
-
-    if isinstance(reason, int):
-      self.reason = DestroyCell.REASON_FOR_INT.get(reason, CloseReason.UNKNOWN)
-      self.reason_int = reason
-    elif reason in CloseReason:
-      self.reason = reason
-      self.reason_int = DestroyCell.INT_FOR_REASON.get(reason, -1)
-    else:
-      raise ValueError('Invalid closure reason: %s' % reason)
+    self.reason, self.reason_int = CloseReason.get(reason)
 
   @classmethod
   def pack(cls, link_version, circ_id, reason = CloseReason.NONE):
@@ -340,12 +294,7 @@ class DestroyCell(CircuitCell):
     :returns: **bytes** to close the circuit
     """
 
-    reason = DestroyCell.INT_FOR_REASON.get(reason, reason)
-
-    if not isinstance(reason, int):
-      raise ValueError('Invalid closure reason: %s' % reason)
-
-    return cls._pack(link_version, Size.CHAR.pack(reason), circ_id)
+    return cls._pack(link_version, Size.CHAR.pack(CloseReason.get(reason)[1]), circ_id)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):





More information about the tor-commits mailing list