[tor-commits] [sbws/m12] chg: scanner: Move to concurrent.futures

juga at torproject.org juga at torproject.org
Wed Jun 30 08:53:27 UTC 2021


commit e11faefd010ca0447f10d79ca9da1ef71f37a1d9
Author: juga0 <juga at riseup.net>
Date:   Sat Jun 19 16:07:20 2021 +0000

    chg: scanner: Move to concurrent.futures
    
    away from multiprocessing, because it looks like we hit python bug
    22393, in which the pool hangs forever when a worker process dies.
    We don't know the reason why a worker process might die, maybe OOM.
    See https://stackoverflow.com/questions/65115092/occasional-deadlock-in-multiprocessing-pool
    
    We also ran into several other issues in the past with multiprocessing.
    Concurrent.futures has a simpler API and is more modern.
    
    Closes #40092.
---
 sbws/core/scanner.py            | 370 +++++++++++++++++-----------------------
 sbws/globals.py                 |   1 -
 setup.cfg                       |   1 +
 tests/unit/conftest.py          |   7 +
 tests/unit/core/test_scanner.py |  83 ++++++++-
 5 files changed, 241 insertions(+), 221 deletions(-)

diff --git a/sbws/core/scanner.py b/sbws/core/scanner.py
index 7637028..bb723bf 100644
--- a/sbws/core/scanner.py
+++ b/sbws/core/scanner.py
@@ -1,4 +1,5 @@
 """ Measure the relays. """
+import concurrent.futures
 import logging
 import os
 import queue
@@ -10,17 +11,10 @@ import time
 import traceback
 import uuid
 from argparse import ArgumentDefaultsHelpFormatter
-from multiprocessing.context import TimeoutError
-from multiprocessing.dummy import Pool
 
 import sbws.util.requests as requests_utils
 import sbws.util.stem as stem_utils
-from sbws.globals import (
-    HTTP_GET_HEADERS,
-    SOCKET_TIMEOUT,
-    TIMEOUT_MEASUREMENTS,
-    fail_hard,
-)
+from sbws.globals import HTTP_GET_HEADERS, SOCKET_TIMEOUT, fail_hard
 
 from .. import settings
 from ..lib.circuitbuilder import GapsCircuitBuilder as CB
@@ -47,7 +41,6 @@ rng = random.SystemRandom()
 log = logging.getLogger(__name__)
 # Declare the objects that manage the threads global so that sbws can exit
 # gracefully at any time.
-pool = None
 rd = None
 controller = None
 
@@ -58,13 +51,10 @@ traceback."""
 
 
 def stop_threads(signal, frame, exit_code=0):
-    global rd, pool
+    global rd
     log.debug("Stopping sbws.")
     # Avoid new threads to start.
     settings.set_end_event()
-    # Stop Pool threads
-    pool.close()
-    pool.join()
     # Stop ResultDump thread
     rd.thread.join()
     # Stop Tor thread
@@ -609,51 +599,48 @@ def _next_expected_amount(
     return expected_amount
 
 
-def result_putter(result_dump):
-    """Create a function that takes a single argument -- the measurement
-    result -- and return that function so it can be used by someone else"""
-
-    def closure(measurement_result):
-        # Since result_dump thread is calling queue.get() every second,
-        # the queue should be full for only 1 second.
-        # This call blocks at maximum timeout seconds.
-        try:
-            result_dump.queue.put(measurement_result, timeout=3)
-        except queue.Full:
-            # The result would be lost, the scanner will continue working.
-            log.warning(
-                "The queue with measurements is full, when adding %s.\n"
-                "It is possible that the thread that get them to "
-                "write them to the disk (ResultDump.enter) is stalled.",
-                measurement_result,
-            )
-
-    return closure
-
+def result_putter(result_dump, measurement):
+    # Since result_dump thread is calling queue.get() every second,
+    # the queue should be full for only 1 second.
+    # This call blocks at maximum timeout seconds.
+    try:
+        result_dump.queue.put(measurement, timeout=3)
+    except queue.Full:
+        # The result would be lost, the scanner will continue working.
+        log.warning(
+            "The queue with measurements is full, when adding %s.\n"
+            "It is possible that the thread that get them to "
+            "write them to the disk (ResultDump.enter) is stalled.",
+            measurement,
+        )
 
-def result_putter_error(target):
-    """Create a function that takes a single argument -- an error from a
-    measurement -- and return that function so it can be used by someone else
-    """
 
-    def closure(object):
-        if settings.end_event.is_set():
-            return
-        # The only object that can be here if there is not any uncatched
-        # exception is stem.SocketClosed when stopping sbws
-        # An exception here means that the worker thread finished.
-        log.warning(FILLUP_TICKET_MSG)
-        # To print the traceback that happened in the thread, not here in
-        # the main process.
-        log.warning(
-            "".join(
-                traceback.format_exception(
-                    type(object), object, object.__traceback__
-                )
+def result_putter_error(target, exception):
+    print("in result putter error")
+    if settings.end_event.is_set():
+        return
+    # The only object that can be here if there is not any uncatched
+    # exception is stem.SocketClosed when stopping sbws
+    # An exception here means that the worker thread finished.
+    log.warning(FILLUP_TICKET_MSG)
+    # To print the traceback that happened in the thread, not here in
+    # the main process.
+    log.warning(
+        "".join(
+            traceback.format_exception(
+                type(exception), exception, exception.__traceback__
             )
         )
-
-    return closure
+    )
+    log.debug(
+        "".join(
+            target.fingerprint,
+            target.nickname,
+            traceback.format_exception(
+                type(exception), exception, exception.__traceback__
+            ),
+        )
+    )
 
 
 def main_loop(
@@ -666,41 +653,23 @@ def main_loop(
     relay_prioritizer,
     destinations,
 ):
-    """Starts and reuse the threads that measure the relays forever.
+    r"""Create the queue of future measurements for every relay to measure.
 
     It starts a loop that will be run while there is not and event signaling
     that sbws is stopping (because of SIGTERM or SIGINT).
 
-    Then, it starts a second loop with an ordered list (generator) of relays
-    to measure that might a subset of all the current relays in the Network.
+    Then the ``ThreadPoolExecutor`` (executor) queues all the relays to
+    measure in ``Future`` objects. These objects have an ``state``.
 
-    For every relay, it starts a new thread which runs ``measure_relay`` to
-    measure the relay until there are ``max_pending_results`` threads.
+    The executor starts a new thread for every relay to measure, which runs
+    ``measure_relay`` until there are ``max_pending_results`` threads.
     After that, it will reuse a thread that has finished for every relay to
     measure.
-    It is the the pool method ``apply_async`` which starts or reuse a thread.
-    This method returns an ``ApplyResult`` immediately, which has a ``ready``
-    methods that tells whether the thread has finished or not.
 
-    When the thread finish, ie. ``ApplyResult`` is ``ready``, it triggers
-    ``result_putter`` callback, which put the ``Result`` in ``ResultDump``
-    queue and complete immediately.
-
-    ``ResultDump`` thread (started before and out of this function) will get
-    the ``Result`` from the queue and write it to disk, so this doesn't block
-    the measurement threads.
-
-    If there was an exception not caught by ``measure_relay``, it will call
-    instead ``result_putter_error``, which logs the error and complete
-    immediately.
-
-    Before the outer loop iterates, it waits (non blocking) that all
-    the ``Results`` are ready calling ``wait_for_results``.
-    This avoid to start measuring the same relay which might still being
-    measured.
+    Then ``wait_for_results`` is call, to obtain the results in the completed
+    ``future``\s.
 
     """
-    global pool
     log.info("Started the main loop to measure the relays.")
     hbeat = Heartbeat(conf.getpath("paths", "state_fname"))
 
@@ -710,59 +679,57 @@ def main_loop(
     while not settings.end_event.is_set():
         log.debug("Starting a new measurement loop.")
         num_relays = 0
-        # Since loop might finish before pending_results is 0 due waiting too
-        # long, set it here and not outside the loop.
-        pending_results = []
         loop_tstart = time.time()
 
         # Register relay fingerprints to the heartbeat module
         hbeat.register_consensus_fprs(relay_list.relays_fingerprints)
-
-        for target in relay_prioritizer.best_priority():
-            # Don't start measuring a relay if sbws is stopping.
-            if settings.end_event.is_set():
-                break
-            # 40023, disable to decrease state.dat json lines
-            # relay_list.increment_recent_measurement_attempt()
-            target.increment_relay_recent_measurement_attempt()
-            num_relays += 1
-            # callback and callback_err must be non-blocking
-            callback = result_putter(result_dump)
-            callback_err = result_putter_error(target)
-            async_result = pool.apply_async(
-                dispatch_worker_thread,
-                [
+        # num_threads
+        max_pending_results = conf.getint("scanner", "measurement_threads")
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=max_pending_results, thread_name_prefix="measurer"
+        ) as executor:
+            log.info("In the executor, queue all future measurements.")
+            # With futures, there's no need for callback, what it was the
+            # callback with multiprocessing library can be just a function
+            # that gets executed when the future result is obtained.
+            pending_results = {
+                executor.submit(
+                    dispatch_worker_thread,
                     args,
                     conf,
                     destinations,
                     circuit_builder,
                     relay_list,
                     target,
-                ],
-                {},
-                callback,
-                callback_err,
+                ): target
+                for target in relay_prioritizer.best_priority()
+            }
+            log.debug("Measurements queued.")
+            # After the submitting all the targets to the executor, the pool
+            # has queued all the relays and pending_results has the list of all
+            # `Future`s.
+
+            # Each target relay_recent_measurement_attempt is incremented in
+            # `wait_for_results` as well as hbeat measured fingerprints.
+            num_relays = len(pending_results)
+            # Without a callback, it's needed to pass `result_dump` here to
+            # call the function that writes the measurement when it's
+            # finished.
+            wait_for_results(
+                executor,
+                hbeat,
+                result_dump,
+                pending_results,
             )
-            pending_results.append(async_result)
-
-            # Register this measurement to the heartbeat module
-            hbeat.register_measured_fpr(target.fingerprint)
-
-        log.debug("Measurements queued.")
-        # After the for has finished, the pool has queued all the relays
-        # and pending_results has the list of all the AsyncResults.
-        # It could also be obtained with pool._cache, which contains
-        # a dictionary with AsyncResults as items.
-        num_relays_to_measure = len(pending_results)
-        wait_for_results(num_relays_to_measure, pending_results)
+            force_get_results(pending_results)
 
         # Print the heartbeat message
         hbeat.print_heartbeat_message()
 
         loop_tstop = time.time()
         loop_tdelta = (loop_tstop - loop_tstart) / 60
-        # At this point, we know the relays that were queued to be measured.
-        # That does not mean they were actually measured.
+        # At this point, we know the relays that were queued to be
+        # measured.
         log.debug(
             "Attempted to measure %s relays in %s minutes",
             num_relays,
@@ -775,113 +742,88 @@ def main_loop(
             stop_threads(signal.SIGTERM, None)
 
 
-def wait_for_results(num_relays_to_measure, pending_results):
-    """Wait for the pool to finish and log progress.
-
-    While there are relays being measured, just log the progress
-    and sleep :const:`~sbws.globals.TIMEOUT_MEASUREMENTS` (3mins),
-    which is approximately the time it can take to measure a relay in
-    the worst case.
-
-    When there has not been any relay measured in ``TIMEOUT_MEASUREMENTS``
-    and there are still relays pending to be measured, it means there is no
-    progress and call :func:`~sbws.core.scanner.force_get_results`.
-
-    This can happen in the case of a bug that makes either
-    :func:`~sbws.core.scanner.measure_relay`,
-    :func:`~sbws.core.scanner.result_putter` (callback) and/or
-    :func:`~sbws.core.scanner.result_putter_error` (callback error) stall.
+def wait_for_results(executor, hbeat, result_dump, pending_results):
+    """Obtain the relays' measurements as they finish.
 
-    .. note:: in a future refactor, this could be simpler by:
+    For every ``Future`` measurements that gets completed, obtain the
+    ``result`` and call ``result_putter``, which put the ``Result`` in
+    ``ResultDump`` queue and complete immediately.
 
-      1. Initializing the pool at the begingging of each loop
-      2. Callling :meth:`~Pool.close`; :meth:`~Pool.join` after
-         :meth:`~Pool.apply_async`,
-         to ensure no new jobs are added until the pool has finished with all
-         the ones in the queue.
-
-      As currently, there would be still two cases when the pool could stall:
-
-      1. There's an exception in ``measure_relay`` and another in
-         ``callback_err``
-      2. There's an exception ``callback``.
+    ``ResultDump`` thread (started before and out of this function) will get
+    the ``Result`` from the queue and write it to disk, so this doesn't block
+    the measurement threads.
 
-      This could also be simpler by not having callback and callback error in
-      ``apply_async`` and instead just calling callback with the
-      ``pending_results``.
+    If there was an exception not caught by ``measure_relay``, it will call
+    instead ``result_putter_error``, which logs the error and complete
+    immediately.
 
-      (callback could be also simpler by not having a thread and queue and
-      just storing to disk, since the time to write to disk is way smaller
-      than the time to request over the network.)
     """
-    num_last_measured = 1
-    while num_last_measured > 0 and not settings.end_event.is_set():
-        log.info(
-            "Pending measurements: %s out of %s: ",
-            len(pending_results),
-            num_relays_to_measure,
-        )
-        log.info("Last measured: %s", num_last_measured)
-        time.sleep(TIMEOUT_MEASUREMENTS)
-        old_pending_results = pending_results
-        pending_results = [r for r in pending_results if not r.ready()]
-        num_last_measured = len(old_pending_results) - len(pending_results)
-    if len(pending_results) > 0:
-        force_get_results(pending_results)
-
-
-def force_get_results(pending_results):
-    """Try to get either the result or an exception, which gets logged.
-
-    It is call by :func:`~sbws.core.scanner.wait_for_results` when
-    the time waiting for the results was long.
-
-    To get either the :class:`~sbws.lib.resultdump.Result` or an exception,
-    call :meth:`~AsyncResult.get` with timeout.
-    Timeout is low since we already waited.
+    num_relays_to_measure = num_pending_results = len(pending_results)
+    with executor:
+        for future_measurement in concurrent.futures.as_completed(
+            pending_results
+        ):
+            target = pending_results[future_measurement]
+            # 40023, disable to decrease state.dat json lines
+            # relay_list.increment_recent_measurement_attempt()
+            target.increment_relay_recent_measurement_attempt()
 
-    ``get`` is not call before, because it blocks and the callbacks
-    are not call.
-    """
-    global pool
-    log.debug("Forcing get")
-    # In case there are no finished AsyncResults, print the cache here
-    # at level info so that is visible even if debug is not enabled.
-    log.info("Pool cache %s", pool._cache)
-    for r in pending_results:
-        try:
-            # HTTP timeout is 10
-            result = r.get(timeout=SOCKET_TIMEOUT + 10)
-            log.warning("Result %s was not stored, it took too long.", result)
-        # TimeoutError is raised when the result is not ready, ie. has not
-        # been processed yet
-        except TimeoutError:
-            log.warning("A result was not stored, it was not ready.")
-            # This is the only place where using psutil so far.
-            import psutil
-
-            log.warning(psutil.Process(os.getpid()).memory_full_info())
-            virtualMemoryInfo = psutil.virtual_memory()
-            availableMemory = virtualMemoryInfo.available
-            log.warning("Memory available %s MB.", availableMemory / 1024 ** 2)
-            dumpstacks()
-        # If the result raised an exception, `get` returns it,
-        # then log any exception so that it can be fixed.
-        # This should not happen, since `callback_err` would have been call
-        # first.
-        except Exception as e:
-            log.critical(FILLUP_TICKET_MSG)
-            # If the exception happened in the threads, `log.exception` does
-            # not have the traceback.
-            # Using `format_exception` instead of of `print_exception` to show
-            # the traceback in all the log handlers.
-            log.warning(
-                "".join(
-                    traceback.format_exception(type(e), e, e.__traceback__)
+            # Register this measurement to the heartbeat module
+            hbeat.register_measured_fpr(target.fingerprint)
+            log.debug(
+                "Future measurement for target %s (%s) is done: %s",
+                target.fingerprint,
+                target.nickname,
+                future_measurement.done(),
+            )
+            try:
+                measurement = future_measurement.result()
+            except Exception as e:
+                result_putter_error(target, e)
+                import psutil
+
+                log.warning(psutil.Process(os.getpid()).memory_full_info())
+                virtualMemoryInfo = psutil.virtual_memory()
+                availableMemory = virtualMemoryInfo.available
+                log.warning(
+                    "Memory available %s MB.", availableMemory / 1024 ** 2
                 )
+                dumpstacks()
+            else:
+                log.info("Measurement ready: %s" % (measurement))
+                result_putter(result_dump, measurement)
+            # `pending_results` has all the initial queued `Future`s,
+            # they don't decrease as they get completed, but we know 1 has be
+            # completed in each loop,
+            num_pending_results -= 1
+            log.info(
+                "Pending measurements: %s out of %s: ",
+                num_pending_results,
+                num_relays_to_measure,
             )
 
 
+def force_get_results(pending_results):
+    """Wait for last futures to finish, before starting new loop."""
+    log.info("Wait for any remaining measurements.")
+    done, not_done = concurrent.futures.wait(
+        pending_results,
+        timeout=SOCKET_TIMEOUT + 10,  # HTTP timeout is 10
+        return_when=concurrent.futures.ALL_COMPLETED,
+    )
+    log.info("Completed futures: %s", len(done))
+    # log.debug([f.__dict__ for f in done])
+    cancelled = [f for f in done if f.cancelled()]
+    if cancelled:
+        log.warning("Cancelled futures: %s", len(cancelled))
+        for f, t in cancelled:
+            log.debug(t.fingerprint)
+    if not_done:
+        log.warning("Not completed futures: %s", len(not_done))
+        for f, t in not_done:
+            log.debug(t.fingerprint)
+
+
 def run_speedtest(args, conf):
     """Initializes all the data and threads needed to measure the relays.
 
@@ -899,7 +841,7 @@ def run_speedtest(args, conf):
     Finally, it calls the function that will manage the measurement threads.
 
     """
-    global rd, pool, controller
+    global rd, controller
 
     controller = stem_utils.launch_or_connect_to_tor(conf)
 
@@ -927,8 +869,6 @@ def run_speedtest(args, conf):
     )
     if not destinations:
         fail_hard(error_msg)
-    max_pending_results = conf.getint("scanner", "measurement_threads")
-    pool = Pool(max_pending_results)
     try:
         main_loop(args, conf, controller, rl, cb, rd, rp, destinations)
     except KeyboardInterrupt:
diff --git a/sbws/globals.py b/sbws/globals.py
index 6954fd1..d628bf0 100644
--- a/sbws/globals.py
+++ b/sbws/globals.py
@@ -73,7 +73,6 @@ SUPERVISED_USER_CONFIG_PATH = "/etc/sbws/sbws.ini"
 SUPERVISED_RUN_DPATH = "/run/sbws/tor"
 
 SOCKET_TIMEOUT = 60  # seconds
-TIMEOUT_MEASUREMENTS = 60 * 3  # 3 minutes
 
 SBWS_SCALE_CONSTANT = 7500
 TORFLOW_SCALING = 1
diff --git a/setup.cfg b/setup.cfg
index 88d4613..7f9d7fb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -49,6 +49,7 @@ test =
   isort
   ; pylint  ; when we ever fix all the errors it throughs
   pytest
+  pytest-mock
   tox
   sphinx
 doc =
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 79d8f37..dafbb43 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -229,6 +229,13 @@ def conf(sbwshome_empty, tmpdir):
     conf = _get_default_config()
     conf["paths"]["sbws_home"] = sbwshome_empty
     conf["paths"]["state_fpath"] = str(tmpdir.join(".sbws", "state.dat"))
+    conf["destinations"]["local"] = "on"
+    conf["destinations.local"] = {
+        "url": "http://127.0.0.1:28888/sbws.bin",
+        "verify": False,
+        "country": "ZZ",
+    }
+
     return conf
 
 
diff --git a/tests/unit/core/test_scanner.py b/tests/unit/core/test_scanner.py
index 3f84472..f7ec69d 100644
--- a/tests/unit/core/test_scanner.py
+++ b/tests/unit/core/test_scanner.py
@@ -1,28 +1,101 @@
 """Unit tests for scanner.py."""
+import concurrent.futures
+import logging
+
 import pytest
+from freezegun import freeze_time
+
+from sbws.core import scanner
+from sbws.lib.circuitbuilder import CircuitBuilder
+from sbws.lib.destination import DestinationList
+from sbws.lib.heartbeat import Heartbeat
+from sbws.lib.relayprioritizer import RelayPrioritizer
 
-from sbws.core.scanner import result_putter
+log = logging.getLogger(__name__)
 
 
 def test_result_putter(sbwshome_only_datadir, result_success, rd, end_event):
     if rd is None:
         pytest.skip("ResultDump is None")
     # Put one item in the queue
-    callback = result_putter(rd)
-    callback(result_success)
+    scanner.result_putter(rd, result_success)
     assert rd.queue.qsize() == 1
 
     # Make queue maxsize 1, so that it'll be full after the first callback.
     # The second callback will wait 1 second, then the queue will be empty
     # again.
     rd.queue.maxsize = 1
-    callback(result_success)
+    scanner.result_putter(rd, result_success)
     # after putting 1 result, the queue will be full
     assert rd.queue.qsize() == 1
     assert rd.queue.full()
     # it's still possible to put another results, because the callback will
     # wait 1 second and the queue will be empty again.
-    callback(result_success)
+    scanner.result_putter(rd, result_success)
     assert rd.queue.qsize() == 1
     assert rd.queue.full()
     end_event.set()
+
+
+def test_complete_measurements(
+    args,
+    conf,
+    sbwshome_only_datadir,
+    controller,
+    relay_list,
+    result_dump,
+    rd,
+    mocker,
+):
+    """
+    Test that the ``ThreadPoolExecutor``` creates the epexted number of
+    futures, ``wait_for_results``process all of them and ``force_get_results``
+    completes them if they were not already completed by the time
+    ``wait_for_results`` has already processed them.
+    There are not real measurements done and the ``results`` are None objects.
+    Running the scanner with the test network, test the real measurements.
+
+    """
+    with freeze_time("2020-02-29 10:00:00"):
+        hbeat = Heartbeat(conf.getpath("paths", "state_fname"))
+        # rl = RelayList(args, conf, controller, measurements_period, state)
+        circuit_builder = CircuitBuilder(args, conf, controller, relay_list)
+        # rd = ResultDump(args, conf)
+        relay_prioritizer = RelayPrioritizer(args, conf, relay_list, rd)
+        destinations, error_msg = DestinationList.from_config(
+            conf, circuit_builder, relay_list, controller
+        )
+        num_threads = conf.getint("scanner", "measurement_threads")
+
+        mocker.patch(
+            "sbws.lib.destination.DestinationList.functional_destinations",
+            side_effect=[d for d in destinations._all_dests],
+        )
+        print("start threads")
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=num_threads, thread_name_prefix="measurer"
+        ) as executor:
+            pending_results = {
+                executor.submit(
+                    scanner.dispatch_worker_thread,
+                    args,
+                    conf,
+                    destinations,
+                    circuit_builder,
+                    relay_list,
+                    target,
+                ): target
+                for target in relay_prioritizer.best_priority()
+            }
+
+            assert len(pending_results) == 321
+            assert len(hbeat.measured_fp_set) == 0
+            log.debug("Before wait_for_results.")
+            scanner.wait_for_results(executor, hbeat, rd, pending_results)
+            log.debug("After wait_for_results")
+            for pending_result in pending_results:
+                assert pending_result.done() is True
+            assert len(hbeat.measured_fp_set) == 321
+            scanner.force_get_results(pending_results)
+            log.debug("After force_get_results.")
+            assert concurrent.futures.ALL_COMPLETED





More information about the tor-commits mailing list