[tor-commits] [stem/master] Flexible Address constructor

atagar at torproject.org atagar at torproject.org
Sun Jan 21 02:04:04 UTC 2018


commit 13a819231b0aab7aa715d14059ab92ded7d79e1a
Author: Damian Johnson <atagar at torproject.org>
Date:   Fri Jan 19 02:53:47 2018 -0800

    Flexible Address constructor
    
    Callers will need to construct addresses themselves, so we're making our
    Address constructor more flexible. It now handles type and value conversion,
    so you can construct Address instances from either packed or unpacked data.
---
 stem/client/__init__.py      | 75 +++++++++++++++++++++++++++++---------------
 test/unit/client/__init__.py |  3 +-
 test/unit/client/address.py  | 24 ++++++++++++--
 test/unit/client/cell.py     | 10 +++---
 4 files changed, 78 insertions(+), 34 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 223956f5..217e32af 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -54,8 +54,11 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
 import collections
 import struct
 
+import stem.util.connection
 import stem.util.enum
 
+from stem.util import _hash_attr
+
 ZERO = '\x00'
 
 __all__ = [
@@ -78,14 +81,6 @@ CertType = stem.util.enum.UppercaseEnum(
   'UNKNOWN',
 )
 
-ADDR_INT = {
-  0: AddrType.HOSTNAME,
-  4: AddrType.IPv4,
-  6: AddrType.IPv6,
-  16: AddrType.ERROR_TRANSIENT,
-  17: AddrType.ERROR_PERMANENT,
-}
-
 
 def split(content, size):
   """
@@ -149,6 +144,12 @@ class Field(object):
 
     raise NotImplementedError('Not yet available')
 
+  def __eq__(self, other):
+    return hash(self) == hash(other) if isinstance(other, Field) else False
+
+  def __ne__(self, other):
+    return not self == other
+
 
 class Size(Field):
   """
@@ -214,7 +215,7 @@ class Certificate(collections.namedtuple('Certificate', ['type', 'value'])):
   """
 
 
-class Address(collections.namedtuple('Address', ['type', 'type_int', 'value', 'value_bin'])):
+class Address(Field):
   """
   Relay address.
 
@@ -224,13 +225,40 @@ class Address(collections.namedtuple('Address', ['type', 'type_int', 'value', 'v
   :var bytes value_bin: encoded address value
   """
 
-  @staticmethod
-  def pack(addr):
-    """
-    Bytes payload for an address.
-    """
-
-    raise NotImplementedError('Not yet available')
+  TYPE_FOR_INT = {
+    0: AddrType.HOSTNAME,
+    4: AddrType.IPv4,
+    6: AddrType.IPv6,
+    16: AddrType.ERROR_TRANSIENT,
+    17: AddrType.ERROR_PERMANENT,
+  }
+
+  INT_FOR_TYPE = dict((v, k) for k, v in TYPE_FOR_INT.items())
+
+  def __init__(self, addr_type, value):
+    if isinstance(addr_type, int):
+      self.type = Address.TYPE_FOR_INT.get(addr_type, AddrType.UNKNOWN)
+      self.type_int = addr_type
+    elif addr_type in AddrType:
+      self.type = addr_type
+      self.type_int = Address.INT_FOR_TYPE.get(addr_type, -1)
+    else:
+      raise ValueError('Invalid address type: %s' % addr_type)
+
+    if self.type == AddrType.IPv4:
+      if stem.util.connection.is_valid_ipv4_address(value):
+        self.value = value
+        self.value_bin = ''.join([Size.CHAR.pack(int(v)) for v in value.split('.')])
+      else:
+        if len(value) != 4:
+          raise ValueError('Packed IPv4 addresses should be four bytes, but was: %s' % repr(value))
+
+        self.value = '.'.join([str(Size.CHAR.unpack(value[i])) for i in range(4)])
+        self.value_bin = value
+    elif self.type == AddrType.IPv6:
+      self.value, self.value_bin = None, None  # TODO: implement
+    else:
+      self.value, self.value_bin = None, None  # TODO: implement
 
   @staticmethod
   def pop(content):
@@ -239,22 +267,19 @@ class Address(collections.namedtuple('Address', ['type', 'type_int', 'value', 'v
     elif len(content) < 2:
       raise ValueError('Insuffient data for address headers')
 
-    addr_type_int, content = Size.CHAR.pop(content)
-    addr_type = ADDR_INT.get(addr_type_int, AddrType.UNKNOWN)
+    addr_type, content = Size.CHAR.pop(content)
     addr_length, content = Size.CHAR.pop(content)
 
     if len(content) < addr_length:
       raise ValueError('Address specified a payload of %i bytes, but only had %i' % (addr_length, len(content)))
 
-    # TODO: add support for other address types
-
-    address_bin, content = split(content, addr_length)
-    address = None
+    addr_value, content = split(content, addr_length)
 
-    if addr_type == AddrType.IPv4 and len(address_bin) == 4:
-      address = '.'.join([str(Size.CHAR.unpack(address_bin[i])) for i in range(4)])
+    return Address(addr_type, addr_value), content
 
-    return Address(addr_type, addr_type_int, address, address_bin), content
+  def __hash__(self):
+    # no need to include value or type since they're derived from these
+    return _hash_attr(self, 'type_int', 'value_bin')
 
 
 setattr(Size, 'CHAR', Size('CHAR', 1, '!B'))
diff --git a/test/unit/client/__init__.py b/test/unit/client/__init__.py
index fdc7a0c6..b36b72af 100644
--- a/test/unit/client/__init__.py
+++ b/test/unit/client/__init__.py
@@ -3,8 +3,9 @@ Unit tests for stem.client.* contents.
 """
 
 __all__ = [
+  'address',
   'cell',
-  'types',
+  'size',
 ]
 
 import os
diff --git a/test/unit/client/address.py b/test/unit/client/address.py
index bb844264..9e6430ff 100644
--- a/test/unit/client/address.py
+++ b/test/unit/client/address.py
@@ -2,17 +2,35 @@
 Unit tests for stem.client.Address.
 """
 
+import collections
 import unittest
 
-from stem.client import Address
+from stem.client import AddrType, Address
+
+ExpectedAddress = collections.namedtuple('ExpectedAddress', ['type', 'type_int', 'value', 'value_bin'])
 
 
 class TestAddress(unittest.TestCase):
-  def test_ipv4(self):
+  def test_constructor(self):
+    test_data = (
+      ((4, '\x7f\x00\x00\x01'), ExpectedAddress(AddrType.IPv4, 4, '127.0.0.1', '\x7f\x00\x00\x01')),
+      ((4, 'aq\x0f\x02'), ExpectedAddress(AddrType.IPv4, 4, '97.113.15.2', 'aq\x0f\x02')),
+      ((AddrType.IPv4, '127.0.0.1'), ExpectedAddress(AddrType.IPv4, 4, '127.0.0.1', '\x7f\x00\x00\x01')),
+      ((AddrType.IPv4, '97.113.15.2'), ExpectedAddress(AddrType.IPv4, 4, '97.113.15.2', 'aq\x0f\x02')),
+    )
+
+    for (addr_type, addr_value), expected in test_data:
+      addr = Address(addr_type, addr_value)
+      self.assertEqual(expected.type, addr.type)
+      self.assertEqual(expected.type_int, addr.type_int)
+      self.assertEqual(expected.value, addr.value)
+      self.assertEqual(expected.value_bin, addr.value_bin)
+
+  def test_pop_ipv4(self):
     addr, content = Address.pop('\x04\x04\x7f\x00\x00\x01\x01\x04\x04aq\x0f\x02\x00\x00\x00\x00')
     self.assertEqual('\x01\x04\x04aq\x0f\x02\x00\x00\x00\x00', content)
 
-    self.assertEqual('IPv4', addr.type)
+    self.assertEqual(AddrType.IPv4, addr.type)
     self.assertEqual(4, addr.type_int)
     self.assertEqual('127.0.0.1', addr.value)
     self.assertEqual('\x7f\x00\x00\x01', addr.value_bin)
diff --git a/test/unit/client/cell.py b/test/unit/client/cell.py
index a0eab2e5..c9aea5bc 100644
--- a/test/unit/client/cell.py
+++ b/test/unit/client/cell.py
@@ -6,7 +6,7 @@ import datetime
 import os
 import unittest
 
-from stem.client import ZERO, Address, Certificate
+from stem.client import ZERO, AddrType, Address, Certificate
 from test.unit.client import test_data
 
 from stem.client.cell import (
@@ -105,8 +105,8 @@ class TestCell(unittest.TestCase):
     netinfo_cell, content = Cell.unpack(content, 2)
     self.assertEqual(NetinfoCell, type(netinfo_cell))
     self.assertEqual(datetime.datetime(2018, 1, 14, 1, 46, 56), netinfo_cell.timestamp)
-    self.assertEqual(Address(type='IPv4', type_int=4, value='127.0.0.1', value_bin='\x7f\x00\x00\x01'), netinfo_cell.receiver_address)
-    self.assertEqual([Address(type='IPv4', type_int=4, value='97.113.15.2', value_bin='aq\x0f\x02')], netinfo_cell.sender_addresses)
+    self.assertEqual(Address(AddrType.IPv4, '127.0.0.1'), netinfo_cell.receiver_address)
+    self.assertEqual([Address(AddrType.IPv4, '97.113.15.2')], netinfo_cell.sender_addresses)
 
     self.assertEqual('', content)  # check that we've consumed all of the bytes
 
@@ -124,8 +124,8 @@ class TestCell(unittest.TestCase):
     cell = Cell.unpack(NETINFO_CELL, 2)[0]
 
     self.assertEqual(datetime.datetime(2018, 1, 14, 1, 46, 56), cell.timestamp)
-    self.assertEqual(Address(type='IPv4', type_int=4, value='127.0.0.1', value_bin='\x7f\x00\x00\x01'), cell.receiver_address)
-    self.assertEqual([Address(type='IPv4', type_int=4, value='97.113.15.2', value_bin='aq\x0f\x02')], cell.sender_addresses)
+    self.assertEqual(Address(AddrType.IPv4, '127.0.0.1'), cell.receiver_address)
+    self.assertEqual([Address(AddrType.IPv4, '97.113.15.2')], cell.sender_addresses)
 
   def test_vpadding_packing(self):
     for cell_bytes, payload in VPADDING_CELLS.items():





More information about the tor-commits mailing list