[tor-commits] [stem/master] Descriptor _mappings_for() helper

atagar at torproject.org
Mon Jul 16 21:24:11 UTC 2018


commit 64a77db052c1157c9c13c1f764c8e04e8711ba08
Author: Damian Johnson <atagar at torproject.org>
Date:   Tue Jul 10 17:38:48 2018 -0700

    Descriptor _mappings_for() helper
    
    We parse quite a number of 'key=value' descriptor fields. This is a clear
    spot where we can deduplicate things a bit.
---
 stem/descriptor/__init__.py               |  47 ++++++++---
 stem/descriptor/extrainfo_descriptor.py   | 124 ++++++++++--------------------
 stem/descriptor/networkstatus.py          | 101 +++++++++---------------
 test/unit/descriptor/server_descriptor.py |   2 +-
 4 files changed, 117 insertions(+), 157 deletions(-)
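
To give a feel for the helper (an illustrative sketch rather than part of the
patch; the real callers are in the diff below), a made-up
'bridge-ip-versions v4=16,v6=8' line would be handled roughly like so:

  from stem.descriptor import _mappings_for  # available once this commit lands

  value = 'v4=16,v6=8'
  ip_versions = {}

  # the divider defaults to a space, but this field is comma separated
  for protocol, count in _mappings_for('bridge-ip-versions', value, divider = ','):
    ip_versions[protocol] = int(count)

  print(ip_versions)  # {'v4': 16, 'v6': 8}

An entry without an '=' (say, a bare 'v4') raises a ValueError naming the
field, so callers only need to validate the values themselves.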

diff --git a/stem/descriptor/__init__.py b/stem/descriptor/__init__.py
index 6a54ef29..05255648 100644
--- a/stem/descriptor/__init__.py
+++ b/stem/descriptor/__init__.py
@@ -512,21 +512,17 @@ def _parse_protocol_line(keyword, attribute):
     value = _value(keyword, entries)
     protocols = OrderedDict()
 
-    for entry in value.split():
-      if '=' not in entry:
-        raise ValueError("Protocol entires are expected to be a series of 'key=value' pairs but was: %s %s" % (keyword, value))
-
-      k, v = entry.split('=', 1)
+    for k, v in _mappings_for(keyword, value):
       versions = []
 
       if not v:
         continue
 
-      for subentry in v.split(','):
-        if '-' in subentry:
-          min_value, max_value = subentry.split('-', 1)
+      for entry in v.split(','):
+        if '-' in entry:
+          min_value, max_value = entry.split('-', 1)
         else:
-          min_value = max_value = subentry
+          min_value = max_value = entry
 
         if not min_value.isdigit() or not max_value.isdigit():
           raise ValueError('Protocol values should be a number or number range, but was: %s %s' % (keyword, value))
@@ -555,6 +551,39 @@ def _parse_key_block(keyword, attribute, expected_block_type, value_attribute =
   return _parse
 
 
+def _mappings_for(keyword, value, require_value = False, divider = ' '):
+  """
+  Parses an attribute as a series of 'key=value' mappings. Unlike the _parse_*
+  functions this is a helper, yielding the attribute's key/value pairs rather
+  than setting a descriptor field. This way parsers can perform additional
+  validations.
+
+  :param str keyword: descriptor field being parsed
+  :param str value: 'key=value' mappings to parse
+  :param bool require_value: validates that values are not empty
+  :param str divider: separator between the key/value mappings
+
+  :returns: **generator** with the key/value of the map attribute
+
+  :raises: **ValueError** if descriptor content is invalid
+  """
+
+  if value is None:
+    return  # no descriptor value to process
+  elif value == '':
+    return  # descriptor field was present, but blank
+
+  for entry in value.split(divider):
+    if '=' not in entry:
+      raise ValueError("'%s' should be a series of 'key=value' pairs but was: %s" % (keyword, value))
+
+    k, v = entry.split('=', 1)
+
+    if require_value and not v:
+      raise ValueError("'%s' line's %s mapping had a blank value: %s" % (keyword, k, value))
+
+    yield k, v
+
+
 def _copy(default):
   if default is None or isinstance(default, (bool, stem.exit_policy.ExitPolicy)):
     return default  # immutable
diff --git a/stem/descriptor/extrainfo_descriptor.py b/stem/descriptor/extrainfo_descriptor.py
index a43b637a..1813a06e 100644
--- a/stem/descriptor/extrainfo_descriptor.py
+++ b/stem/descriptor/extrainfo_descriptor.py
@@ -88,6 +88,7 @@ from stem.descriptor import (
   _parse_timestamp_line,
   _parse_forty_character_hex,
   _parse_key_block,
+  _mappings_for,
   _append_router_signature,
   _random_nickname,
   _random_fingerprint,
@@ -321,22 +322,14 @@ def _parse_padding_counts_line(descriptor, entries):
 
   value = _value('padding-counts', entries)
   timestamp, interval, remainder = _parse_timestamp_and_interval('padding-counts', value)
-  entries = {}
+  counts = {}
 
-  for entry in remainder.split(' '):
-    if '=' not in entry:
-      raise ValueError('Entries in padding-counts line should be key=value mappings: padding-counts %s' % value)
-
-    k, v = entry.split('=', 1)
-
-    if not v:
-      raise ValueError('Entry in padding-counts line had a blank value: padding-counts %s' % value)
-
-    entries[k] = int(v) if v.isdigit() else v
+  for k, v in _mappings_for('padding-counts', remainder, require_value = True):
+    counts[k] = int(v) if v.isdigit() else v
 
   setattr(descriptor, 'padding_counts_end', timestamp)
   setattr(descriptor, 'padding_counts_interval', interval)
-  setattr(descriptor, 'padding_counts', entries)
+  setattr(descriptor, 'padding_counts', counts)
 
 
 def _parse_dirreq_line(keyword, recognized_counts_attr, unrecognized_counts_attr, descriptor, entries):
@@ -349,22 +342,15 @@ def _parse_dirreq_line(keyword, recognized_counts_attr, unrecognized_counts_attr
   key_set = DirResponse if is_response_stats else DirStat
 
   key_type = 'STATUS' if is_response_stats else 'STAT'
-  error_msg = '%s lines should contain %s=COUNT mappings: %s %s' % (keyword, key_type, keyword, value)
 
-  if value:
-    for entry in value.split(','):
-      if '=' not in entry:
-        raise ValueError(error_msg)
-
-      status, count = entry.split('=', 1)
+  for status, count in _mappings_for(keyword, value, divider = ','):
+    if not count.isdigit():
+      raise ValueError('%s lines should contain %s=COUNT mappings: %s %s' % (keyword, key_type, keyword, value))
 
-      if count.isdigit():
-        if status in key_set:
-          recognized_counts[status] = int(count)
-        else:
-          unrecognized_counts[status] = int(count)
-      else:
-        raise ValueError(error_msg)
+    if status in key_set:
+      recognized_counts[status] = int(count)
+    else:
+      unrecognized_counts[status] = int(count)
 
   setattr(descriptor, recognized_counts_attr, recognized_counts)
   setattr(descriptor, unrecognized_counts_attr, unrecognized_counts)
@@ -453,22 +439,13 @@ def _parse_port_count_line(keyword, attribute, descriptor, entries):
   # "<keyword>" port=N,port=N,...
 
   value, port_mappings = _value(keyword, entries), {}
-  error_msg = 'Entries in %s line should only be PORT=N entries: %s %s' % (keyword, keyword, value)
-
-  if value:
-    for entry in value.split(','):
-      if '=' not in entry:
-        raise ValueError(error_msg)
 
-      port, stat = entry.split('=', 1)
+  for port, stat in _mappings_for(keyword, value, divider = ','):
+    if (port != 'other' and not stem.util.connection.is_valid_port(port)) or not stat.isdigit():
+      raise ValueError('Entries in %s line should only be PORT=N entries: %s %s' % (keyword, keyword, value))
 
-      if (port == 'other' or stem.util.connection.is_valid_port(port)) and stat.isdigit():
-        if port != 'other':
-          port = int(port)
-
-        port_mappings[port] = int(stat)
-      else:
-        raise ValueError(error_msg)
+    port = int(port) if port.isdigit() else port
+    port_mappings[port] = int(stat)
 
   setattr(descriptor, attribute, port_mappings)
 
@@ -483,19 +460,12 @@ def _parse_geoip_to_count_line(keyword, attribute, descriptor, entries):
   #   ??,"Unknown"
 
   value, locale_usage = _value(keyword, entries), {}
-  error_msg = 'Entries in %s line should only be CC=N entries: %s %s' % (keyword, keyword, value)
-
-  if value:
-    for entry in value.split(','):
-      if '=' not in entry:
-        raise ValueError(error_msg)
 
-      locale, count = entry.split('=', 1)
+  for locale, count in _mappings_for(keyword, value, divider = ','):
+    if not _locale_re.match(locale) or not count.isdigit():
+      raise ValueError('Entries in %s line should only be CC=N entries: %s %s' % (keyword, keyword, value))
 
-      if _locale_re.match(locale) and count.isdigit():
-        locale_usage[locale] = int(count)
-      else:
-        raise ValueError(error_msg)
+    locale_usage[locale] = int(count)
 
   setattr(descriptor, attribute, locale_usage)
 
@@ -503,17 +473,11 @@ def _parse_geoip_to_count_line(keyword, attribute, descriptor, entries):
 def _parse_bridge_ip_versions_line(descriptor, entries):
   value, ip_versions = _value('bridge-ip-versions', entries), {}
 
-  if value:
-    for entry in value.split(','):
-      if '=' not in entry:
-        raise stem.ProtocolError("The bridge-ip-versions should be a comma separated listing of '<protocol>=<count>' mappings: bridge-ip-versions %s" % value)
-
-      protocol, count = entry.split('=', 1)
-
-      if not count.isdigit():
-        raise stem.ProtocolError('IP protocol count was non-numeric (%s): bridge-ip-versions %s' % (count, value))
+  for protocol, count in _mappings_for('bridge-ip-versions', value, divider = ','):
+    if not count.isdigit():
+      raise stem.ProtocolError('IP protocol count was non-numeric (%s): bridge-ip-versions %s' % (count, value))
 
-      ip_versions[protocol] = int(count)
+    ip_versions[protocol] = int(count)
 
   descriptor.ip_versions = ip_versions
 
@@ -521,17 +485,11 @@ def _parse_bridge_ip_versions_line(descriptor, entries):
 def _parse_bridge_ip_transports_line(descriptor, entries):
   value, ip_transports = _value('bridge-ip-transports', entries), {}
 
-  if value:
-    for entry in value.split(','):
-      if '=' not in entry:
-        raise stem.ProtocolError("The bridge-ip-transports should be a comma separated listing of '<protocol>=<count>' mappings: bridge-ip-transports %s" % value)
-
-      protocol, count = entry.split('=', 1)
-
-      if not count.isdigit():
-        raise stem.ProtocolError('Transport count was non-numeric (%s): bridge-ip-transports %s' % (count, value))
+  for protocol, count in _mappings_for('bridge-ip-transports', value, divider = ','):
+    if not count.isdigit():
+      raise stem.ProtocolError('Transport count was non-numeric (%s): bridge-ip-transports %s' % (count, value))
 
-      ip_transports[protocol] = int(count)
+    ip_transports[protocol] = int(count)
 
   descriptor.ip_transports = ip_transports
 
@@ -541,22 +499,22 @@ def _parse_hs_stats(keyword, stat_attribute, extra_attribute, descriptor, entrie
 
   value, stat, extra = _value(keyword, entries), None, {}
 
-  if value is not None:
-    value_comp = value.split()
-
-    if not value_comp:
-      raise ValueError("'%s' line was blank" % keyword)
+  if value is None:
+    pass  # not in the descriptor
+  elif value == '':
+    raise ValueError("'%s' line was blank" % keyword)
+  else:
+    if ' ' in value:
+      stat_value, remainder = value.split(' ', 1)
+    else:
+      stat_value, remainder = value, None
 
     try:
-      stat = int(value_comp[0])
+      stat = int(stat_value)
     except ValueError:
-      raise ValueError("'%s' stat was non-numeric (%s): %s %s" % (keyword, value_comp[0], keyword, value))
-
-    for entry in value_comp[1:]:
-      if '=' not in entry:
-        raise ValueError('Entries after the stat in %s lines should only be key=val entries: %s %s' % (keyword, keyword, value))
+      raise ValueError("'%s' stat was non-numeric (%s): %s %s" % (keyword, stat_value, keyword, value))
 
-      key, val = entry.split('=', 1)
+    for key, val in _mappings_for(keyword, remainder):
       extra[key] = val
 
   setattr(descriptor, stat_attribute, stat)
diff --git a/stem/descriptor/networkstatus.py b/stem/descriptor/networkstatus.py
index eac556a6..f9245b43 100644
--- a/stem/descriptor/networkstatus.py
+++ b/stem/descriptor/networkstatus.py
@@ -76,6 +76,7 @@ from stem.descriptor import (
   _parse_forty_character_hex,
   _parse_protocol_line,
   _parse_key_block,
+  _mappings_for,
   _random_nickname,
   _random_fingerprint,
   _random_ipv4_address,
@@ -642,26 +643,20 @@ def _parse_header_flag_thresholds_line(descriptor, entries):
 
   value, thresholds = _value('flag-thresholds', entries).strip(), {}
 
-  if value:
-    for entry in value.split(' '):
-      if '=' not in entry:
-        raise ValueError("Network status document's 'flag-thresholds' line is expected to be space separated key=value mappings, got: flag-thresholds %s" % value)
-
-      entry_key, entry_value = entry.split('=', 1)
-
-      try:
-        if entry_value.endswith('%'):
-          # opting for string manipulation rather than just
-          # 'float(entry_value) / 100' because floating point arithmetic
-          # will lose precision
-
-          thresholds[entry_key] = float('0.' + entry_value[:-1].replace('.', '', 1))
-        elif '.' in entry_value:
-          thresholds[entry_key] = float(entry_value)
-        else:
-          thresholds[entry_key] = int(entry_value)
-      except ValueError:
-        raise ValueError("Network status document's 'flag-thresholds' line is expected to have float values, got: flag-thresholds %s" % value)
+  for key, val in _mappings_for('flag-thresholds', value):
+    try:
+      if val.endswith('%'):
+        # opting for string manipulation rather than just
+        # 'float(entry_value) / 100' because floating point arithmetic
+        # will lose precision
+
+        thresholds[key] = float('0.' + val[:-1].replace('.', '', 1))
+      elif '.' in val:
+        thresholds[key] = float(val)
+      else:
+        thresholds[key] = int(val)
+    except ValueError:
+      raise ValueError("Network status document's 'flag-thresholds' line is expected to have float values, got: flag-thresholds %s" % value)
 
   descriptor.flag_thresholds = thresholds
 
@@ -716,7 +711,7 @@ def _parse_package_line(descriptor, entries):
   package_versions = []
 
   for value, _, _ in entries['package']:
-    value_comp = value.split()
+    value_comp = value.split(' ', 3)
 
     if len(value_comp) < 3:
       raise ValueError("'package' must at least have a 'PackageName Version URL': %s" % value)
@@ -724,12 +719,9 @@ def _parse_package_line(descriptor, entries):
     name, version, url = value_comp[:3]
     digests = {}
 
-    for digest_entry in value_comp[3:]:
-      if '=' not in digest_entry:
-        raise ValueError("'package' digest entries should be 'key=value' pairs: %s" % value)
-
-      key, value = digest_entry.split('=', 1)
-      digests[key] = value
+    if len(value_comp) == 4:
+      for key, val in _mappings_for('package', value_comp[3]):
+        digests[key] = val
 
     package_versions.append(PackageVersion(name, version, url, digests))
 
@@ -793,18 +785,8 @@ def _parse_bandwidth_file_headers(descriptor, entries):
   value = _value('bandwidth-file-headers', entries)
   results = {}
 
-  for entry in value.split(' '):
-    if not entry:
-      continue
-    elif '=' not in entry:
-      raise ValueError("'bandwidth-file-headers' lines must be a series of 'key=value' pairs: %s" % value)
-
-    k, v = entry.split('=', 1)
-
-    if not v:
-      raise ValueError("'bandwidth-file-headers' mappings should all have values: %s" % value)
-
-    results[k] = v
+  for key, val in _mappings_for('bandwidth-file-headers', value, require_value = True):
+    results[key] = val
 
   descriptor.bandwidth_file_headers = results
 
@@ -1310,35 +1292,26 @@ def _parse_int_mappings(keyword, value, validate):
   # - keys are sorted in lexical order
 
   results, seen_keys = {}, []
-  for entry in value.split(' '):
-    try:
-      if '=' not in entry:
-        raise ValueError("must only have 'key=value' entries")
+  error_template = "Unable to parse network status document's '%s' line (%%s): %s'" % (keyword, value)
 
-      entry_key, entry_value = entry.split('=', 1)
+  for key, val in _mappings_for(keyword, value):
+    if validate:
+      # parameters should be in ascending order by their key
+      for prior_key in seen_keys:
+        if prior_key > key:
+          raise ValueError(error_template % 'parameters must be sorted by their key')
 
-      try:
-        # the int() function accepts things like '+123', but we don't want to
-        if entry_value.startswith('+'):
-          raise ValueError()
+    try:
+      # the int() function accepts things like '+123', but we don't want to
 
-        entry_value = int(entry_value)
-      except ValueError:
-        raise ValueError("'%s' is a non-numeric value" % entry_value)
-
-      if validate:
-        # parameters should be in ascending order by their key
-        for prior_key in seen_keys:
-          if prior_key > entry_key:
-            raise ValueError('parameters must be sorted by their key')
-
-      results[entry_key] = entry_value
-      seen_keys.append(entry_key)
-    except ValueError as exc:
-      if not validate:
-        continue
+      if val.startswith('+'):
+        raise ValueError()
+
+      results[key] = int(val)
+    except ValueError:
+      raise ValueError(error_template % ("'%s' is a non-numeric value" % val))
 
-      raise ValueError("Unable to parse network status document's '%s' line (%s): %s'" % (keyword, exc, value))
+    seen_keys.append(key)
 
   return results
 
diff --git a/test/unit/descriptor/server_descriptor.py b/test/unit/descriptor/server_descriptor.py
index ff1035c4..7a68613d 100644
--- a/test/unit/descriptor/server_descriptor.py
+++ b/test/unit/descriptor/server_descriptor.py
@@ -796,7 +796,7 @@ Qlx9HNCqCY877ztFRC624ja2ql6A2hBcuoYMbkHjcQ4=
     Checks a 'proto' line when it's not key=value pairs.
     """
 
-    exc_msg = "Protocol entires are expected to be a series of 'key=value' pairs but was: proto Desc Link=1-4"
+    exc_msg = "'proto' should be a series of 'key=value' pairs but was: Desc Link=1-4"
     self.assertRaisesRegexp(ValueError, exc_msg, RelayDescriptor.create, {'proto': 'Desc Link=1-4'})
 
   def test_parse_with_non_int_version(self):

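As a follow-up note on the require_value flag used by _parse_padding_counts_line()
and _parse_bandwidth_file_headers() above (again just an illustrative sketch with
made-up field values), blank values pass through by default but are rejected when
the flag is set:

  from stem.descriptor import _mappings_for  # helper added by this commit

  dict(_mappings_for('padding-counts', 'bin-size=100 write-drop='))
  # {'bin-size': '100', 'write-drop': ''}

  dict(_mappings_for('padding-counts', 'bin-size=100 write-drop=', require_value = True))
  # ValueError: 'padding-counts' line's write-drop mapping had a blank value: bin-size=100 write-drop=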


More information about the tor-commits mailing list