commit 7981a765e15c0e4b7b4154c5c236974b8a3cdfea
Author: Damian Johnson <atagar(a)torproject.org>
Date: Sat Jan 20 15:26:32 2018 -0800
Change Certificate to a Field subclass
Move packing and unpacking of certificates into the Certificate type itself so
it follows the same Field pattern as Size and Address.
---
stem/client/__init__.py | 71 +++++++++++++++++++++++++++++------------
stem/client/cell.py | 20 ++----------
test/settings.cfg | 1 +
test/unit/client/__init__.py | 1 +
test/unit/client/cell.py | 21 ++++++------
test/unit/client/certificate.py | 39 ++++++++++++++++++++++
6 files changed, 106 insertions(+), 47 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 13952a00..1fd05f9d 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -51,7 +51,6 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
===================== ===========
"""
-import collections
import io
import struct
@@ -198,24 +197,6 @@ class Size(Field):
return self.unpack(packed[:self.size]), packed[self.size:]
-class Certificate(collections.namedtuple('Certificate', ['type', 'value'])):
- """
- Relay certificate as defined in tor-spec section 4.2. Certificate types
- are...
-
- ==================== ===========
- Type Value Description
- ==================== ===========
- 1 Link key certificate certified by RSA1024 identity
- 2 RSA1024 Identity certificate
- 3 RSA1024 AUTHENTICATE cell link certificate
- ==================== ===========
-
- :var int type: certificate type
- :var bytes value: certificate value
- """
-
-
class Address(Field):
"""
Relay address.
@@ -299,10 +280,60 @@ class Address(Field):
return Address(addr_type, addr_value), content
def __hash__(self):
- # no need to include value or type since they're derived from these
return _hash_attr(self, 'type_int', 'value_bin')
+class Certificate(Field):
+  """
+  Relay certificate as defined in tor-spec section 4.2. Certificate types
+  are...
+
+  ============ ============ ===========
+  Type Value   CertType     Description
+  ============ ============ ===========
+  1            LINK         Link key certificate certified by RSA1024 identity
+  2            IDENTITY     RSA1024 Identity certificate
+  3            AUTHENTICATE RSA1024 AUTHENTICATE cell link certificate
+  ============ ============ ===========
+
+  Any other integer type is accepted and reported as **CertType.UNKNOWN**,
+  with the original integer retained as **type_int**.
+
+  :var stem.client.CertType type: certificate type
+  :var int type_int: integer value of the certificate type
+  :var bytes value: certificate value
+  """
+
+  TYPE_FOR_INT = {
+    1: CertType.LINK,
+    2: CertType.IDENTITY,
+    3: CertType.AUTHENTICATE,
+  }
+
+  # reverse lookup, used when we're constructed from a CertType rather than an int
+  INT_FOR_TYPE = dict((v, k) for k, v in TYPE_FOR_INT.items())
+
+  def __init__(self, cert_type, value):
+    """
+    :param int,stem.client.CertType cert_type: certificate type, either as its
+      integer value or a CertType
+    :param bytes value: certificate value
+
+    :raises: **ValueError** if cert_type is neither an int nor a CertType
+    """
+
+    if isinstance(cert_type, int):
+      self.type = Certificate.TYPE_FOR_INT.get(cert_type, CertType.UNKNOWN)
+      self.type_int = cert_type
+    elif cert_type in CertType:
+      self.type = cert_type
+
+      # CertTypes without an integer code (such as UNKNOWN) map to -1.
+      # NOTE(review): such certificates presumably can't be packed since the
+      # type is written as an unsigned CHAR - confirm.
+      self.type_int = Certificate.INT_FOR_TYPE.get(cert_type, -1)
+    else:
+      # wrapped in a tuple so a tuple cert_type isn't mistaken for format
+      # arguments (which would raise TypeError rather than this ValueError)
+      raise ValueError('Invalid certificate type: %s' % (cert_type,))
+
+    self.value = value
+
+  def pack(self):
+    """
+    Encodes this certificate as a one byte type, a two byte value length, and
+    then the value itself.
+
+    :returns: **bytes** with the encoded certificate
+    """
+
+    cell = io.BytesIO()
+    cell.write(Size.CHAR.pack(self.type_int))
+    cell.write(Size.SHORT.pack(len(self.value)))
+    cell.write(self.value)
+    return cell.getvalue()
+
+  @staticmethod
+  def pop(content):
+    """
+    Parses a certificate from the start of the given content.
+
+    :param bytes content: encoded data that begins with a certificate
+
+    :returns: **tuple** of the form (certificate, remaining content)
+
+    :raises: **ValueError** if the content is shorter than the certificate's
+      stated length
+    """
+
+    cert_type, content = Size.CHAR.pop(content)
+    cert_size, content = Size.SHORT.pop(content)
+
+    if cert_size > len(content):
+      raise ValueError('CERTS cell should have a certificate with %i bytes, but only had %i remaining' % (cert_size, len(content)))
+
+    cert_bytes, content = split(content, cert_size)
+    return Certificate(cert_type, cert_bytes), content
+
+  def __hash__(self):
+    # no need to include type since it's derived from type_int
+    return _hash_attr(self, 'type_int', 'value')
+
+
setattr(Size, 'CHAR', Size('CHAR', 1, '!B'))
setattr(Size, 'SHORT', Size('SHORT', 2, '!H'))
setattr(Size, 'LONG', Size('LONG', 4, '!L'))
diff --git a/stem/client/cell.py b/stem/client/cell.py
index 332e3f16..074d2079 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -489,15 +489,7 @@ class CertsCell(Cell):
:returns: **bytes** with a payload for these versions
"""
- payload = io.BytesIO()
- payload.write(Size.CHAR.pack(len(certs)))
-
- for cert in certs:
- payload.write(Size.CHAR.pack(cert.type))
- payload.write(Size.SHORT.pack(len(cert.value)))
- payload.write(cert.value)
-
- return cls._pack(link_version, payload.getvalue())
+    # Certificate.pack() returns bytes, so the join separator must be bytes as
+    # well (a str '' raises TypeError under python 3)
+    return cls._pack(link_version, Size.CHAR.pack(len(certs)) + b''.join([cert.pack() for cert in certs]))
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -508,14 +500,8 @@ class CertsCell(Cell):
if not content:
raise ValueError('CERTS cell indicates it should have %i certificates, but only contained %i' % (cert_count, len(certs)))
- cert_type, content = Size.CHAR.pop(content)
- cert_size, content = Size.SHORT.pop(content)
-
- if cert_size > len(content):
- raise ValueError('CERTS cell should have a certificate with %i bytes, but only had %i remaining' % (cert_size, len(content)))
-
- cert_bytes, content = split(content, cert_size)
- certs.append(Certificate(cert_type, cert_bytes))
+ cert, content = Certificate.pop(content)
+ certs.append(cert)
return CertsCell(certs)
diff --git a/test/settings.cfg b/test/settings.cfg
index 28694b60..518ae2b2 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -232,6 +232,7 @@ test.unit_tests
|test.unit.response.mapaddress.TestMapAddressResponse
|test.unit.client.size.TestSize
|test.unit.client.address.TestAddress
+|test.unit.client.certificate.TestCertificate
|test.unit.client.cell.TestCell
|test.unit.connection.authentication.TestAuthenticate
|test.unit.connection.connect.TestConnect
diff --git a/test/unit/client/__init__.py b/test/unit/client/__init__.py
index b36b72af..7e745352 100644
--- a/test/unit/client/__init__.py
+++ b/test/unit/client/__init__.py
@@ -5,6 +5,7 @@ Unit tests for stem.client.* contents.
__all__ = [
'address',
'cell',
+ 'certificate',
'size',
]
diff --git a/test/unit/client/cell.py b/test/unit/client/cell.py
index c9aea5bc..7f3318ee 100644
--- a/test/unit/client/cell.py
+++ b/test/unit/client/cell.py
@@ -6,7 +6,7 @@ import datetime
import os
import unittest
-from stem.client import ZERO, AddrType, Address, Certificate
+from stem.client import ZERO, AddrType, CertType, Address, Certificate
from test.unit.client import test_data
from stem.client.cell import (
@@ -44,8 +44,8 @@ VPADDING_CELLS = {
CERTS_CELLS = {
'\x00\x00\x81\x00\x01\x00': [],
- '\x00\x00\x81\x00\x04\x01\x01\x00\x00': [Certificate(type = 1, value = '')],
- '\x00\x00\x81\x00\x05\x01\x01\x00\x01\x08': [Certificate(type = 1, value = '\x08')],
+ '\x00\x00\x81\x00\x04\x01\x01\x00\x00': [Certificate(1, '')],
+ '\x00\x00\x81\x00\x05\x01\x01\x00\x01\x08': [Certificate(1, '\x08')],
}
AUTH_CHALLENGE_CELLS = {
@@ -79,11 +79,11 @@ class TestCell(unittest.TestCase):
def test_unpack_for_new_link(self):
expected_certs = (
- (1, '0\x82\x02F0\x82\x01\xaf'),
- (2, '0\x82\x01\xc90\x82\x012'),
- (4, '\x01\x04\x00\x06m\x1f'),
- (5, '\x01\x05\x00\x06m\n\x01'),
- (7, '\x1a\xa5\xb3\xbd\x88\xb1C'),
+ (CertType.LINK, 1, '0\x82\x02F0\x82\x01\xaf'),
+ (CertType.IDENTITY, 2, '0\x82\x01\xc90\x82\x012'),
+ (CertType.UNKNOWN, 4, '\x01\x04\x00\x06m\x1f'),
+ (CertType.UNKNOWN, 5, '\x01\x05\x00\x06m\n\x01'),
+ (CertType.UNKNOWN, 7, '\x1a\xa5\xb3\xbd\x88\xb1C'),
)
content = test_data('new_link_cells')
@@ -95,8 +95,9 @@ class TestCell(unittest.TestCase):
self.assertEqual(CertsCell, type(certs_cell))
self.assertEqual(len(expected_certs), len(certs_cell.certificates))
- for i, (cert_type, cert_prefix) in enumerate(expected_certs):
+ for i, (cert_type, cert_type_int, cert_prefix) in enumerate(expected_certs):
self.assertEqual(cert_type, certs_cell.certificates[i].type)
+ self.assertEqual(cert_type_int, certs_cell.certificates[i].type_int)
self.assertTrue(certs_cell.certificates[i].value.startswith(cert_prefix))
auth_challenge_cell, content = Cell.unpack(content, 2)
@@ -141,7 +142,7 @@ class TestCell(unittest.TestCase):
# extra bytes after the last certificate should be ignored
- self.assertEqual([Certificate(type = 1, value = '\x08')], Cell.unpack('\x00\x00\x81\x00\x07\x01\x01\x00\x01\x08\x06\x04', 2)[0].certificates)
+ self.assertEqual([Certificate(1, '\x08')], Cell.unpack('\x00\x00\x81\x00\x07\x01\x01\x00\x01\x08\x06\x04', 2)[0].certificates)
# ... but truncated or missing certificates should error
diff --git a/test/unit/client/certificate.py b/test/unit/client/certificate.py
new file mode 100644
index 00000000..873de51d
--- /dev/null
+++ b/test/unit/client/certificate.py
@@ -0,0 +1,39 @@
+"""
+Unit tests for stem.client.Certificate.
+"""
+
+import unittest
+
+from stem.client import CertType, Certificate
+
+
+class TestCertificate(unittest.TestCase):
+  # certificate content is binary, so fixtures are bytes literals (identical
+  # to str on python 2, and required by Certificate.pack/pop on python 3)
+
+  def test_constructor(self):
+    test_data = (
+      ((1, b'\x7f\x00\x00\x01'), (CertType.LINK, 1, b'\x7f\x00\x00\x01')),
+      ((2, b'\x7f\x00\x00\x01'), (CertType.IDENTITY, 2, b'\x7f\x00\x00\x01')),
+      ((3, b'\x7f\x00\x00\x01'), (CertType.AUTHENTICATE, 3, b'\x7f\x00\x00\x01')),
+      ((4, b'\x7f\x00\x00\x01'), (CertType.UNKNOWN, 4, b'\x7f\x00\x00\x01')),
+      ((CertType.IDENTITY, b'\x7f\x00\x00\x01'), (CertType.IDENTITY, 2, b'\x7f\x00\x00\x01')),
+    )
+
+    for (cert_type, cert_value), (expected_type, expected_type_int, expected_value) in test_data:
+      cert = Certificate(cert_type, cert_value)
+      self.assertEqual(expected_type, cert.type)
+      self.assertEqual(expected_type_int, cert.type_int)
+      self.assertEqual(expected_value, cert.value)
+
+  def test_unknown_type(self):
+    cert = Certificate(12, b'hello')
+    self.assertEqual(CertType.UNKNOWN, cert.type)
+    self.assertEqual(12, cert.type_int)
+    self.assertEqual(b'hello', cert.value)
+
+  def test_invalid_type(self):
+    # neither an int nor a CertType
+    self.assertRaises(ValueError, Certificate, None, b'hello')
+
+  def test_packing(self):
+    cert, content = Certificate.pop(b'\x02\x00\x04\x00\x00\x01\x01\x04\x04aq\x0f\x02\x00\x00\x00\x00')
+    self.assertEqual(b'\x04\x04aq\x0f\x02\x00\x00\x00\x00', content)
+
+    self.assertEqual(CertType.IDENTITY, cert.type)
+    self.assertEqual(2, cert.type_int)
+    self.assertEqual(b'\x00\x00\x01\x01', cert.value)
+    self.assertEqual(b'\x02\x00\x04\x00\x00\x01\x01', cert.pack())
+
+    # constructing from a CertType packs with its matching integer type
+    self.assertEqual(b'\x02\x00\x04\x00\x00\x01\x01', Certificate(CertType.IDENTITY, b'\x00\x00\x01\x01').pack())
+
+  def test_truncated(self):
+    # length field claims four bytes but only two follow
+    self.assertRaises(ValueError, Certificate.pop, b'\x02\x00\x04\x00\x00')