commit 7981a765e15c0e4b7b4154c5c236974b8a3cdfea Author: Damian Johnson <atagar@torproject.org> Date: Sat Jan 20 15:26:32 2018 -0800
Change Certificate to a Field subclass
Move the packing and unpacking of certificates into the Certificate type itself, so it follows the same Field pattern (pack/pop on the type) that Size and Address already use. --- stem/client/__init__.py | 71 +++++++++++++++++++++++++++++------------ stem/client/cell.py | 20 ++---------- test/settings.cfg | 1 + test/unit/client/__init__.py | 1 + test/unit/client/cell.py | 21 ++++++------ test/unit/client/certificate.py | 39 ++++++++++++++++++++++ 6 files changed, 106 insertions(+), 47 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py index 13952a00..1fd05f9d 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -51,7 +51,6 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as ===================== =========== """
-import collections import io import struct
@@ -198,24 +197,6 @@ class Size(Field): return self.unpack(packed[:self.size]), packed[self.size:]
-class Certificate(collections.namedtuple('Certificate', ['type', 'value'])): - """ - Relay certificate as defined in tor-spec section 4.2. Certificate types - are... - - ==================== =========== - Type Value Description - ==================== =========== - 1 Link key certificate certified by RSA1024 identity - 2 RSA1024 Identity certificate - 3 RSA1024 AUTHENTICATE cell link certificate - ==================== =========== - - :var int type: certificate type - :var bytes value: certificate value - """ - - class Address(Field): """ Relay address. @@ -299,10 +280,60 @@ class Address(Field): return Address(addr_type, addr_value), content
def __hash__(self): - # no need to include value or type since they're derived from these return _hash_attr(self, 'type_int', 'value_bin')
+class Certificate(Field): + """ + Relay certificate as defined in tor-spec section 4.2. + + :var stem.client.CertType type: certificate type + :var int type_int: integer value of the certificate type + :var bytes value: certificate value + """ + + TYPE_FOR_INT = { + 1: CertType.LINK, + 2: CertType.IDENTITY, + 3: CertType.AUTHENTICATE, + } + + INT_FOR_TYPE = dict((v, k) for k, v in TYPE_FOR_INT.items()) + + def __init__(self, cert_type, value): + if isinstance(cert_type, int): + self.type = Certificate.TYPE_FOR_INT.get(cert_type, CertType.UNKNOWN) + self.type_int = cert_type + elif cert_type in CertType: + self.type = cert_type + self.type_int = Certificate.INT_FOR_TYPE.get(cert_type, -1) + else: + raise ValueError('Invalid certificate type: %s' % cert_type) + + self.value = value + + def pack(self): + cell = io.BytesIO() + cell.write(Size.CHAR.pack(self.type_int)) + cell.write(Size.SHORT.pack(len(self.value))) + cell.write(self.value) + return cell.getvalue() + + @staticmethod + def pop(content): + cert_type, content = Size.CHAR.pop(content) + cert_size, content = Size.SHORT.pop(content) + + if cert_size > len(content): + raise ValueError('CERTS cell should have a certificate with %i bytes, but only had %i remaining' % (cert_size, len(content))) + + cert_bytes, content = split(content, cert_size) + return Certificate(cert_type, cert_bytes), content + + def __hash__(self): + return _hash_attr(self, 'type_int', 'value') + + setattr(Size, 'CHAR', Size('CHAR', 1, '!B')) setattr(Size, 'SHORT', Size('SHORT', 2, '!H')) setattr(Size, 'LONG', Size('LONG', 4, '!L')) diff --git a/stem/client/cell.py b/stem/client/cell.py index 332e3f16..074d2079 100644 --- a/stem/client/cell.py +++ b/stem/client/cell.py @@ -489,15 +489,7 @@ class CertsCell(Cell): :returns: **bytes** with a payload for these versions """
- payload = io.BytesIO() - payload.write(Size.CHAR.pack(len(certs))) - - for cert in certs: - payload.write(Size.CHAR.pack(cert.type)) - payload.write(Size.SHORT.pack(len(cert.value))) - payload.write(cert.value) - - return cls._pack(link_version, payload.getvalue()) + return cls._pack(link_version, Size.CHAR.pack(len(certs)) + b''.join([cert.pack() for cert in certs]))
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -508,14 +500,8 @@ class CertsCell(Cell): if not content: raise ValueError('CERTS cell indicates it should have %i certificates, but only contained %i' % (cert_count, len(certs)))
- cert_type, content = Size.CHAR.pop(content) - cert_size, content = Size.SHORT.pop(content) - - if cert_size > len(content): - raise ValueError('CERTS cell should have a certificate with %i bytes, but only had %i remaining' % (cert_size, len(content))) - - cert_bytes, content = split(content, cert_size) - certs.append(Certificate(cert_type, cert_bytes)) + cert, content = Certificate.pop(content) + certs.append(cert)
return CertsCell(certs)
diff --git a/test/settings.cfg b/test/settings.cfg index 28694b60..518ae2b2 100644 --- a/test/settings.cfg +++ b/test/settings.cfg @@ -232,6 +232,7 @@ test.unit_tests |test.unit.response.mapaddress.TestMapAddressResponse |test.unit.client.size.TestSize |test.unit.client.address.TestAddress +|test.unit.client.certificate.TestCertificate |test.unit.client.cell.TestCell |test.unit.connection.authentication.TestAuthenticate |test.unit.connection.connect.TestConnect diff --git a/test/unit/client/__init__.py b/test/unit/client/__init__.py index b36b72af..7e745352 100644 --- a/test/unit/client/__init__.py +++ b/test/unit/client/__init__.py @@ -5,6 +5,7 @@ Unit tests for stem.client.* contents. __all__ = [ 'address', 'cell', + 'certificate', 'size', ]
diff --git a/test/unit/client/cell.py b/test/unit/client/cell.py index c9aea5bc..7f3318ee 100644 --- a/test/unit/client/cell.py +++ b/test/unit/client/cell.py @@ -6,7 +6,7 @@ import datetime import os import unittest
-from stem.client import ZERO, AddrType, Address, Certificate +from stem.client import ZERO, AddrType, CertType, Address, Certificate from test.unit.client import test_data
from stem.client.cell import ( @@ -44,8 +44,8 @@ VPADDING_CELLS = {
CERTS_CELLS = { '\x00\x00\x81\x00\x01\x00': [], - '\x00\x00\x81\x00\x04\x01\x01\x00\x00': [Certificate(type = 1, value = '')], - '\x00\x00\x81\x00\x05\x01\x01\x00\x01\x08': [Certificate(type = 1, value = '\x08')], + '\x00\x00\x81\x00\x04\x01\x01\x00\x00': [Certificate(1, '')], + '\x00\x00\x81\x00\x05\x01\x01\x00\x01\x08': [Certificate(1, '\x08')], }
AUTH_CHALLENGE_CELLS = { @@ -79,11 +79,11 @@ class TestCell(unittest.TestCase):
def test_unpack_for_new_link(self): expected_certs = ( - (1, '0\x82\x02F0\x82\x01\xaf'), - (2, '0\x82\x01\xc90\x82\x012'), - (4, '\x01\x04\x00\x06m\x1f'), - (5, '\x01\x05\x00\x06m\n\x01'), - (7, '\x1a\xa5\xb3\xbd\x88\xb1C'), + (CertType.LINK, 1, '0\x82\x02F0\x82\x01\xaf'), + (CertType.IDENTITY, 2, '0\x82\x01\xc90\x82\x012'), + (CertType.UNKNOWN, 4, '\x01\x04\x00\x06m\x1f'), + (CertType.UNKNOWN, 5, '\x01\x05\x00\x06m\n\x01'), + (CertType.UNKNOWN, 7, '\x1a\xa5\xb3\xbd\x88\xb1C'), )
content = test_data('new_link_cells') @@ -95,8 +95,9 @@ class TestCell(unittest.TestCase): self.assertEqual(CertsCell, type(certs_cell)) self.assertEqual(len(expected_certs), len(certs_cell.certificates))
- for i, (cert_type, cert_prefix) in enumerate(expected_certs): + for i, (cert_type, cert_type_int, cert_prefix) in enumerate(expected_certs): self.assertEqual(cert_type, certs_cell.certificates[i].type) + self.assertEqual(cert_type_int, certs_cell.certificates[i].type_int) self.assertTrue(certs_cell.certificates[i].value.startswith(cert_prefix))
auth_challenge_cell, content = Cell.unpack(content, 2) @@ -141,7 +142,7 @@ class TestCell(unittest.TestCase):
# extra bytes after the last certificate should be ignored
- self.assertEqual([Certificate(type = 1, value = '\x08')], Cell.unpack('\x00\x00\x81\x00\x07\x01\x01\x00\x01\x08\x06\x04', 2)[0].certificates) + self.assertEqual([Certificate(1, '\x08')], Cell.unpack('\x00\x00\x81\x00\x07\x01\x01\x00\x01\x08\x06\x04', 2)[0].certificates)
# ... but truncated or missing certificates should error
diff --git a/test/unit/client/certificate.py b/test/unit/client/certificate.py new file mode 100644 index 00000000..873de51d --- /dev/null +++ b/test/unit/client/certificate.py @@ -0,0 +1,39 @@ +""" +Unit tests for stem.client.Certificate. +""" + +import unittest + +from stem.client import CertType, Certificate + + +class TestCertificate(unittest.TestCase): + def test_constructor(self): + test_data = ( + ((1, '\x7f\x00\x00\x01'), (CertType.LINK, 1, '\x7f\x00\x00\x01')), + ((2, '\x7f\x00\x00\x01'), (CertType.IDENTITY, 2, '\x7f\x00\x00\x01')), + ((3, '\x7f\x00\x00\x01'), (CertType.AUTHENTICATE, 3, '\x7f\x00\x00\x01')), + ((4, '\x7f\x00\x00\x01'), (CertType.UNKNOWN, 4, '\x7f\x00\x00\x01')), + ((CertType.IDENTITY, '\x7f\x00\x00\x01'), (CertType.IDENTITY, 2, '\x7f\x00\x00\x01')), + ) + + for (cert_type, cert_value), (expected_type, expected_type_int, expected_value) in test_data: + cert = Certificate(cert_type, cert_value) + self.assertEqual(expected_type, cert.type) + self.assertEqual(expected_type_int, cert.type_int) + self.assertEqual(expected_value, cert.value) + + def test_unknown_type(self): + cert = Certificate(12, 'hello') + self.assertEqual(CertType.UNKNOWN, cert.type) + self.assertEqual(12, cert.type_int) + self.assertEqual('hello', cert.value) + + def test_packing(self): + cert, content = Certificate.pop('\x02\x00\x04\x00\x00\x01\x01\x04\x04aq\x0f\x02\x00\x00\x00\x00') + self.assertEqual('\x04\x04aq\x0f\x02\x00\x00\x00\x00', content) + + self.assertEqual(CertType.IDENTITY, cert.type) + self.assertEqual(2, cert.type_int) + self.assertEqual('\x00\x00\x01\x01', cert.value) + self.assertEqual('\x02\x00\x04\x00\x00\x01\x01', cert.pack())