[tor-commits] [stem/master] Concatenate strings with bytearray rather than BytesIO

atagar at torproject.org atagar at torproject.org
Fri May 25 21:38:48 UTC 2018


commit c23cffc5cbe12b276fa18424a7bc85f91f901a60
Author: Damian Johnson <atagar at torproject.org>
Date:   Fri May 25 14:19:22 2018 -0700

    Concatenate strings with bytearray rather than BytesIO
    
    While skimming through the Python docs I came across a note suggesting
    bytearray for concatenating many bytes objects...
    
      https://docs.python.org/3/faq/programming.html#what-is-the-most-efficient-way-to-concatenate-many-strings-together
    
    I gave this a whirl and it's not only more readable, but about 15% faster
    too in the benchmark below...
    
    ============================================================
    concat_with_bytesio.py
    ============================================================
    
    import io
    import time
    
    start = time.time()
    full_bytes = io.BytesIO()
    
    for i in range(10000000):
      full_bytes.write(b'hiiiiiiiiiiiiiiiiiiiii')
    
    result = full_bytes.getvalue()
    print('total length: %i, took %0.2f seconds' % (len(result), time.time() - start))
    
    ============================================================
    concat_with_bytearray.py
    ============================================================
    
    import time
    
    start = time.time()
    full_bytes = bytearray()
    
    for i in range(10000000):
      full_bytes += b'hiiiiiiiiiiiiiiiiiiiii'
    
    result = bytes(full_bytes)
    print('total length: %i, took %0.2f seconds' % (len(result), time.time() - start))
    
    ============================================================
    Results
    ============================================================
    
    % python2 --version
    Python 2.7.12
    
    % python3 --version
    Python 3.5.2
    
    % for ((n=0;n<3;n++)); do python2 concat_with_bytearray.py; done
    total length: 220000000, took 1.18 seconds
    total length: 220000000, took 1.19 seconds
    total length: 220000000, took 1.17 seconds
    
    % for ((n=0;n<3;n++)); do python3 concat_with_bytearray.py; done
    total length: 220000000, took 1.02 seconds
    total length: 220000000, took 0.99 seconds
    total length: 220000000, took 1.03 seconds
    
    % for ((n=0;n<3;n++)); do python2 concat_with_bytesio.py; done
    total length: 220000000, took 1.40 seconds
    total length: 220000000, took 1.38 seconds
    total length: 220000000, took 1.38 seconds
    
    % for ((n=0;n<3;n++)); do python3 concat_with_bytesio.py; done
    total length: 220000000, took 1.18 seconds
    total length: 220000000, took 1.19 seconds
    total length: 220000000, took 1.20 seconds
---
 stem/client/cell.py     | 58 ++++++++++++++++++++++++-------------------------
 stem/client/datatype.py | 21 +++++++++---------
 stem/socket.py          | 22 +++++++++----------
 3 files changed, 48 insertions(+), 53 deletions(-)

diff --git a/stem/client/cell.py b/stem/client/cell.py
index b40a737f..d0a98ada 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -39,7 +39,6 @@ Messages communicated over a Tor relay's ORPort.
 
 import datetime
 import inspect
-import io
 import os
 import random
 import sys
@@ -192,24 +191,23 @@ class Cell(object):
     if isinstance(cls, CircuitCell) and circ_id is None:
       raise ValueError('%s cells require a circ_id' % cls.NAME)
 
-    cell = io.BytesIO()
-    cell.write(Size.LONG.pack(circ_id) if link_protocol > 3 else Size.SHORT.pack(circ_id))
-    cell.write(Size.CHAR.pack(cls.VALUE))
-    cell.write(b'' if cls.IS_FIXED_SIZE else Size.SHORT.pack(len(payload)))
-    cell.write(payload)
+    cell = bytearray()
+    cell += Size.LONG.pack(circ_id) if link_protocol > 3 else Size.SHORT.pack(circ_id)
+    cell += Size.CHAR.pack(cls.VALUE)
+    cell += b'' if cls.IS_FIXED_SIZE else Size.SHORT.pack(len(payload))
+    cell += payload
 
     # pad fixed sized cells to the required length
 
     if cls.IS_FIXED_SIZE:
-      cell_size = cell.seek(0, io.SEEK_END)
       fixed_cell_len = 514 if link_protocol > 3 else 512
 
-      if cell_size > fixed_cell_len:
-        raise ValueError('Payload of %s is too large (%i bytes), must be less than %i' % (cls.NAME, cell_size, fixed_cell_len))
+      if len(cell) > fixed_cell_len:
+        raise ValueError('Payload of %s is too large (%i bytes), must be less than %i' % (cls.NAME, len(cell), fixed_cell_len))
 
-      cell.write(ZERO * (fixed_cell_len - cell_size))
+      cell += ZERO * (fixed_cell_len - len(cell))
 
-    return cell.getvalue()
+    return bytes(cell)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_protocol):
@@ -330,15 +328,15 @@ class RelayCell(CircuitCell):
         raise ValueError('%s relay cells concern the circuit itself and cannot have a stream id' % self.command)
 
   def pack(self, link_protocol):
-    payload = io.BytesIO()
-    payload.write(Size.CHAR.pack(self.command_int))
-    payload.write(Size.SHORT.pack(self.recognized))
-    payload.write(Size.SHORT.pack(self.stream_id))
-    payload.write(Size.LONG.pack(self.digest))
-    payload.write(Size.SHORT.pack(len(self.data)))
-    payload.write(self.data)
+    payload = bytearray()
+    payload += Size.CHAR.pack(self.command_int)
+    payload += Size.SHORT.pack(self.recognized)
+    payload += Size.SHORT.pack(self.stream_id)
+    payload += Size.LONG.pack(self.digest)
+    payload += Size.SHORT.pack(len(self.data))
+    payload += self.data
 
-    return RelayCell._pack(link_protocol, payload.getvalue(), self.circ_id)
+    return RelayCell._pack(link_protocol, bytes(payload), self.circ_id)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_protocol):
@@ -521,15 +519,15 @@ class NetinfoCell(Cell):
     self.sender_addresses = sender_addresses
 
   def pack(self, link_protocol):
-    payload = io.BytesIO()
-    payload.write(Size.LONG.pack(int(datetime_to_unix(self.timestamp))))
-    payload.write(self.receiver_address.pack())
-    payload.write(Size.CHAR.pack(len(self.sender_addresses)))
+    payload = bytearray()
+    payload += Size.LONG.pack(int(datetime_to_unix(self.timestamp)))
+    payload += self.receiver_address.pack()
+    payload += Size.CHAR.pack(len(self.sender_addresses))
 
     for addr in self.sender_addresses:
-      payload.write(addr.pack())
+      payload += addr.pack()
 
-    return NetinfoCell._pack(link_protocol, payload.getvalue())
+    return NetinfoCell._pack(link_protocol, bytes(payload))
 
   @classmethod
   def _unpack(cls, content, circ_id, link_protocol):
@@ -664,14 +662,14 @@ class AuthChallengeCell(Cell):
     self.methods = methods
 
   def pack(self, link_protocol):
-    payload = io.BytesIO()
-    payload.write(self.challenge)
-    payload.write(Size.SHORT.pack(len(self.methods)))
+    payload = bytearray()
+    payload += self.challenge
+    payload += Size.SHORT.pack(len(self.methods))
 
     for method in self.methods:
-      payload.write(Size.SHORT.pack(method))
+      payload += Size.SHORT.pack(method)
 
-    return AuthChallengeCell._pack(link_protocol, payload.getvalue())
+    return AuthChallengeCell._pack(link_protocol, bytes(payload))
 
   @classmethod
   def _unpack(cls, content, circ_id, link_protocol):
diff --git a/stem/client/datatype.py b/stem/client/datatype.py
index 038229d2..75fae662 100644
--- a/stem/client/datatype.py
+++ b/stem/client/datatype.py
@@ -112,7 +112,6 @@ users.** See our :class:`~stem.client.Relay` the API you probably want.
 
 import collections
 import hashlib
-import io
 import struct
 
 import stem.prereq
@@ -381,11 +380,11 @@ class Address(Field):
       self.value_bin = value
 
   def pack(self):
-    cell = io.BytesIO()
-    cell.write(Size.CHAR.pack(self.type_int))
-    cell.write(Size.CHAR.pack(len(self.value_bin)))
-    cell.write(self.value_bin)
-    return cell.getvalue()
+    cell = bytearray()
+    cell += Size.CHAR.pack(self.type_int)
+    cell += Size.CHAR.pack(len(self.value_bin))
+    cell += self.value_bin
+    return bytes(cell)
 
   @staticmethod
   def pop(content):
@@ -422,11 +421,11 @@ class Certificate(Field):
     self.value = value
 
   def pack(self):
-    cell = io.BytesIO()
-    cell.write(Size.CHAR.pack(self.type_int))
-    cell.write(Size.SHORT.pack(len(self.value)))
-    cell.write(self.value)
-    return cell.getvalue()
+    cell = bytearray()
+    cell += Size.CHAR.pack(self.type_int)
+    cell += Size.SHORT.pack(len(self.value))
+    cell += self.value
+    return bytes(cell)
 
   @staticmethod
   def pop(content):
diff --git a/stem/socket.py b/stem/socket.py
index 710ab657..3da24c29 100644
--- a/stem/socket.py
+++ b/stem/socket.py
@@ -72,7 +72,6 @@ Tor...
 
 from __future__ import absolute_import
 
-import io
 import re
 import socket
 import ssl
@@ -713,9 +712,9 @@ def recv_message(control_file):
         _log_trace(line)
         return stem.response.ControlMessage([(status_code, divider, content)], line)
       else:
-        parsed_content, raw_content, first_line = [], io.BytesIO(), False
+        parsed_content, raw_content, first_line = [], bytearray(), False
 
-    raw_content.write(line)
+    raw_content += line
 
     if divider == '-':
       # mid-reply line, keep pulling for more content
@@ -723,25 +722,24 @@ def recv_message(control_file):
     elif divider == ' ':
       # end of the message, return the message
       parsed_content.append((status_code, divider, content))
-      _log_trace(raw_content.getvalue())
-      return stem.response.ControlMessage(parsed_content, raw_content.getvalue())
+      _log_trace(bytes(raw_content))
+      return stem.response.ControlMessage(parsed_content, bytes(raw_content))
     elif divider == '+':
       # data entry, all of the following lines belong to the content until we
       # get a line with just a period
 
-      content_block = io.BytesIO()
-      content_block.write(content)
+      content_block = bytearray(content)
 
       while True:
         try:
           line = control_file.readline()
-          raw_content.write(line)
+          raw_content += line
         except socket.error as exc:
-          log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(raw_content.getvalue()))))
+          log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content)))))
           raise stem.SocketClosed(exc)
 
         if not line.endswith(b'\r\n'):
-          log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(raw_content.getvalue())))
+          log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content))))
           raise stem.ProtocolError('All lines should end with CRLF')
         elif line == b'.\r\n':
           break  # data block termination
@@ -754,12 +752,12 @@ def recv_message(control_file):
         if line.startswith(b'..'):
           line = line[1:]
 
-        content_block.write(b'\n' + line)
+        content_block += b'\n' + line
 
       # joins the content using a newline rather than CRLF separator (more
       # conventional for multi-line string content outside the windows world)
 
-      parsed_content.append((status_code, divider, content_block.getvalue()))
+      parsed_content.append((status_code, divider, bytes(content_block)))
     else:
       # this should never be reached due to the prefix regex, but might as well
       # be safe...



More information about the tor-commits mailing list