[tor-commits] [stem/master] Change Certificate to a Field subclass

atagar at torproject.org atagar at torproject.org
Sun Jan 21 02:04:04 UTC 2018


commit 7981a765e15c0e4b7b4154c5c236974b8a3cdfea
Author: Damian Johnson <atagar at torproject.org>
Date:   Sat Jan 20 15:26:32 2018 -0800

    Change Certificate to a Field subclass
    
    Moving packing and unpacking of certificates into the Certificate type
    itself, so it follows the same Field pattern as Size and Address.
---
 stem/client/__init__.py         | 71 +++++++++++++++++++++++++++++------------
 stem/client/cell.py             | 20 ++----------
 test/settings.cfg               |  1 +
 test/unit/client/__init__.py    |  1 +
 test/unit/client/cell.py        | 21 ++++++------
 test/unit/client/certificate.py | 39 ++++++++++++++++++++++
 6 files changed, 106 insertions(+), 47 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 13952a00..1fd05f9d 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -51,7 +51,6 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
   ===================== ===========
 """
 
-import collections
 import io
 import struct
 
@@ -198,24 +197,6 @@ class Size(Field):
     return self.unpack(packed[:self.size]), packed[self.size:]
 
 
-class Certificate(collections.namedtuple('Certificate', ['type', 'value'])):
-  """
-  Relay certificate as defined in tor-spec section 4.2. Certificate types
-  are...
-
-  ====================  ===========
-  Type Value            Description
-  ====================  ===========
-  1                     Link key certificate certified by RSA1024 identity
-  2                     RSA1024 Identity certificate
-  3                     RSA1024 AUTHENTICATE cell link certificate
-  ====================  ===========
-
-  :var int type: certificate type
-  :var bytes value: certificate value
-  """
-
-
 class Address(Field):
   """
   Relay address.
@@ -299,10 +280,60 @@ class Address(Field):
     return Address(addr_type, addr_value), content
 
   def __hash__(self):
-    # no need to include value or type since they're derived from these
     return _hash_attr(self, 'type_int', 'value_bin')
 
 
+class Certificate(Field):
+  """
+  Relay certificate as defined in tor-spec section 4.2. Its packed form is a one byte type, two byte value length, then the value.
+
+  :var stem.client.CertType type: certificate type, **CertType.UNKNOWN** if the integer type isn't recognized
+  :var int type_int: integer value of the certificate type (-1 if constructed from a CertType without one)
+  :var bytes value: certificate value
+  """
+
+  TYPE_FOR_INT = {
+    1: CertType.LINK,
+    2: CertType.IDENTITY,
+    3: CertType.AUTHENTICATE,
+  }
+  # reverse lookup, used for certificates constructed from a CertType
+  INT_FOR_TYPE = dict((v, k) for k, v in TYPE_FOR_INT.items())
+
+  def __init__(self, cert_type, value):
+    if isinstance(cert_type, int):
+      self.type = Certificate.TYPE_FOR_INT.get(cert_type, CertType.UNKNOWN)
+      self.type_int = cert_type
+    elif cert_type in CertType:
+      self.type = cert_type
+      self.type_int = Certificate.INT_FOR_TYPE.get(cert_type, -1)
+    else:
+      raise ValueError('Invalid certificate type: %s' % cert_type)
+    # NOTE(review): a type_int of -1 presumably can't be packed as a CHAR by pack() -- confirm
+    self.value = value
+
+  def pack(self):
+    cell = io.BytesIO()
+    cell.write(Size.CHAR.pack(self.type_int))
+    cell.write(Size.SHORT.pack(len(self.value)))
+    cell.write(self.value)
+    return cell.getvalue()
+
+  @staticmethod
+  def pop(content):
+    cert_type, content = Size.CHAR.pop(content)
+    cert_size, content = Size.SHORT.pop(content)
+    # the declared length must fit within the remaining content
+    if cert_size > len(content):
+      raise ValueError('CERTS cell should have a certificate with %i bytes, but only had %i remaining' % (cert_size, len(content)))
+    # take the certificate's bytes, handing anything after them back to the caller
+    cert_bytes, content = split(content, cert_size)
+    return Certificate(cert_type, cert_bytes), content
+
+  def __hash__(self):
+    return _hash_attr(self, 'type_int', 'value')
+
+
 setattr(Size, 'CHAR', Size('CHAR', 1, '!B'))
 setattr(Size, 'SHORT', Size('SHORT', 2, '!H'))
 setattr(Size, 'LONG', Size('LONG', 4, '!L'))
diff --git a/stem/client/cell.py b/stem/client/cell.py
index 332e3f16..074d2079 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -489,15 +489,7 @@ class CertsCell(Cell):
     :returns: **bytes** with a payload for these versions
     """
 
-    payload = io.BytesIO()
-    payload.write(Size.CHAR.pack(len(certs)))
-
-    for cert in certs:
-      payload.write(Size.CHAR.pack(cert.type))
-      payload.write(Size.SHORT.pack(len(cert.value)))
-      payload.write(cert.value)
-
-    return cls._pack(link_version, payload.getvalue())
+    return cls._pack(link_version, Size.CHAR.pack(len(certs)) + b''.join([cert.pack() for cert in certs]))
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -508,14 +500,8 @@ class CertsCell(Cell):
       if not content:
         raise ValueError('CERTS cell indicates it should have %i certificates, but only contained %i' % (cert_count, len(certs)))
 
-      cert_type, content = Size.CHAR.pop(content)
-      cert_size, content = Size.SHORT.pop(content)
-
-      if cert_size > len(content):
-        raise ValueError('CERTS cell should have a certificate with %i bytes, but only had %i remaining' % (cert_size, len(content)))
-
-      cert_bytes, content = split(content, cert_size)
-      certs.append(Certificate(cert_type, cert_bytes))
+      cert, content = Certificate.pop(content)
+      certs.append(cert)
 
     return CertsCell(certs)
 
diff --git a/test/settings.cfg b/test/settings.cfg
index 28694b60..518ae2b2 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -232,6 +232,7 @@ test.unit_tests
 |test.unit.response.mapaddress.TestMapAddressResponse
 |test.unit.client.size.TestSize
 |test.unit.client.address.TestAddress
+|test.unit.client.certificate.TestCertificate
 |test.unit.client.cell.TestCell
 |test.unit.connection.authentication.TestAuthenticate
 |test.unit.connection.connect.TestConnect
diff --git a/test/unit/client/__init__.py b/test/unit/client/__init__.py
index b36b72af..7e745352 100644
--- a/test/unit/client/__init__.py
+++ b/test/unit/client/__init__.py
@@ -5,6 +5,7 @@ Unit tests for stem.client.* contents.
 __all__ = [
   'address',
   'cell',
+  'certificate',
   'size',
 ]
 
diff --git a/test/unit/client/cell.py b/test/unit/client/cell.py
index c9aea5bc..7f3318ee 100644
--- a/test/unit/client/cell.py
+++ b/test/unit/client/cell.py
@@ -6,7 +6,7 @@ import datetime
 import os
 import unittest
 
-from stem.client import ZERO, AddrType, Address, Certificate
+from stem.client import ZERO, AddrType, CertType, Address, Certificate
 from test.unit.client import test_data
 
 from stem.client.cell import (
@@ -44,8 +44,8 @@ VPADDING_CELLS = {
 
 CERTS_CELLS = {
   '\x00\x00\x81\x00\x01\x00': [],
-  '\x00\x00\x81\x00\x04\x01\x01\x00\x00': [Certificate(type = 1, value = '')],
-  '\x00\x00\x81\x00\x05\x01\x01\x00\x01\x08': [Certificate(type = 1, value = '\x08')],
+  '\x00\x00\x81\x00\x04\x01\x01\x00\x00': [Certificate(1, '')],
+  '\x00\x00\x81\x00\x05\x01\x01\x00\x01\x08': [Certificate(1, '\x08')],
 }
 
 AUTH_CHALLENGE_CELLS = {
@@ -79,11 +79,11 @@ class TestCell(unittest.TestCase):
 
   def test_unpack_for_new_link(self):
     expected_certs = (
-      (1, '0\x82\x02F0\x82\x01\xaf'),
-      (2, '0\x82\x01\xc90\x82\x012'),
-      (4, '\x01\x04\x00\x06m\x1f'),
-      (5, '\x01\x05\x00\x06m\n\x01'),
-      (7, '\x1a\xa5\xb3\xbd\x88\xb1C'),
+      (CertType.LINK, 1, '0\x82\x02F0\x82\x01\xaf'),
+      (CertType.IDENTITY, 2, '0\x82\x01\xc90\x82\x012'),
+      (CertType.UNKNOWN, 4, '\x01\x04\x00\x06m\x1f'),
+      (CertType.UNKNOWN, 5, '\x01\x05\x00\x06m\n\x01'),
+      (CertType.UNKNOWN, 7, '\x1a\xa5\xb3\xbd\x88\xb1C'),
     )
 
     content = test_data('new_link_cells')
@@ -95,8 +95,9 @@ class TestCell(unittest.TestCase):
     self.assertEqual(CertsCell, type(certs_cell))
     self.assertEqual(len(expected_certs), len(certs_cell.certificates))
 
-    for i, (cert_type, cert_prefix) in enumerate(expected_certs):
+    for i, (cert_type, cert_type_int, cert_prefix) in enumerate(expected_certs):
       self.assertEqual(cert_type, certs_cell.certificates[i].type)
+      self.assertEqual(cert_type_int, certs_cell.certificates[i].type_int)
       self.assertTrue(certs_cell.certificates[i].value.startswith(cert_prefix))
 
     auth_challenge_cell, content = Cell.unpack(content, 2)
@@ -141,7 +142,7 @@ class TestCell(unittest.TestCase):
 
     # extra bytes after the last certificate should be ignored
 
-    self.assertEqual([Certificate(type = 1, value = '\x08')], Cell.unpack('\x00\x00\x81\x00\x07\x01\x01\x00\x01\x08\x06\x04', 2)[0].certificates)
+    self.assertEqual([Certificate(1, '\x08')], Cell.unpack('\x00\x00\x81\x00\x07\x01\x01\x00\x01\x08\x06\x04', 2)[0].certificates)
 
     # ... but truncated or missing certificates should error
 
diff --git a/test/unit/client/certificate.py b/test/unit/client/certificate.py
new file mode 100644
index 00000000..873de51d
--- /dev/null
+++ b/test/unit/client/certificate.py
@@ -0,0 +1,39 @@
+"""
+Unit tests for stem.client.Certificate.
+"""
+
+import unittest
+
+from stem.client import CertType, Certificate
+
+
+class TestCertificate(unittest.TestCase):
+  def test_constructor(self):
+    test_data = (
+      ((1, '\x7f\x00\x00\x01'), (CertType.LINK, 1, '\x7f\x00\x00\x01')),
+      ((2, '\x7f\x00\x00\x01'), (CertType.IDENTITY, 2, '\x7f\x00\x00\x01')),
+      ((3, '\x7f\x00\x00\x01'), (CertType.AUTHENTICATE, 3, '\x7f\x00\x00\x01')),
+      ((4, '\x7f\x00\x00\x01'), (CertType.UNKNOWN, 4, '\x7f\x00\x00\x01')),
+      ((CertType.IDENTITY, '\x7f\x00\x00\x01'), (CertType.IDENTITY, 2, '\x7f\x00\x00\x01')),
+    )
+    # each entry is ((constructor arguments), (expected type, type_int, value))
+    for (cert_type, cert_value), (expected_type, expected_type_int, expected_value) in test_data:
+      cert = Certificate(cert_type, cert_value)
+      self.assertEqual(expected_type, cert.type)
+      self.assertEqual(expected_type_int, cert.type_int)
+      self.assertEqual(expected_value, cert.value)
+
+  def test_unknown_type(self):
+    cert = Certificate(12, 'hello')
+    self.assertEqual(CertType.UNKNOWN, cert.type)
+    self.assertEqual(12, cert.type_int)
+    self.assertEqual('hello', cert.value)
+
+  def test_packing(self):
+    cert, content = Certificate.pop('\x02\x00\x04\x00\x00\x01\x01\x04\x04aq\x0f\x02\x00\x00\x00\x00')
+    self.assertEqual('\x04\x04aq\x0f\x02\x00\x00\x00\x00', content)
+    # check the popped fields, and that pack() reproduces the bytes pop() consumed
+    self.assertEqual(CertType.IDENTITY, cert.type)
+    self.assertEqual(2, cert.type_int)
+    self.assertEqual('\x00\x00\x01\x01', cert.value)
+    self.assertEqual('\x02\x00\x04\x00\x00\x01\x01', cert.pack())





More information about the tor-commits mailing list