[tor-commits] [tor/master] Refactor the core of choosing by weights into a function

nickm at torproject.org nickm at torproject.org
Tue Sep 11 21:52:39 UTC 2012


commit 07df4dd52d3ab2eea2e8a8fc3222a5d297d077de
Author: Nick Mathewson <nickm at torproject.org>
Date:   Thu Aug 9 13:47:42 2012 -0400

    Refactor the core of choosing by weights into a function
    
    This eliminates duplicated code, and lets us test a hairy piece of
    functionality.
---
 changes/bug6538     |    4 +
 src/or/routerlist.c |  162 +++++++++++++++++++++------------------------------
 src/or/routerlist.h |    5 ++
 src/test/test.h     |    4 +
 src/test/test_dir.c |   81 +++++++++++++++++++++++++
 5 files changed, 160 insertions(+), 96 deletions(-)

diff --git a/changes/bug6538 b/changes/bug6538
index fc9e583..03c168b 100644
--- a/changes/bug6538
+++ b/changes/bug6538
@@ -10,3 +10,7 @@
       than it ran through the part of the loop before it had made its
       choice. Fix for bug 6538.
 
+  o Code simplifications and refactoring:
+    - Move the core of our "choose a weighted element at random" logic
+      into its own function, and give it unit tests.  Now the logic is
+      testable, and a little less fragile too.
diff --git a/src/or/routerlist.c b/src/or/routerlist.c
index 801c496..1c0aca8 100644
--- a/src/or/routerlist.c
+++ b/src/or/routerlist.c
@@ -11,6 +11,7 @@
  * servers.
  **/
 
+#define ROUTERLIST_PRIVATE
 #include "or.h"
 #include "circuitbuild.h"
 #include "config.h"
@@ -1652,6 +1653,53 @@ router_get_advertised_bandwidth_capped(const routerinfo_t *router)
   return result;
 }
 
+/** Pick a random element of <b>n_entries</b>-element array <b>entries</b>,
+ * choosing each element with a probability proportional to its value, and
+ * return the index of that element.  If all elements are 0, choose an index
+ * at random.  If <b>total_out</b> is provided, set it to the sum of all
+ * elements (even when returning -1; the sum is not checked for overflow).
+ * Return -1 on error, which happens only when <b>n_entries</b> < 1. */
+/* private */ int
+choose_array_element_by_weight(const uint64_t *entries, int n_entries,
+                               uint64_t *total_out)
+{
+  int i, i_chosen=-1, n_chosen=0;
+  uint64_t total_so_far = 0;
+  uint64_t rand_val;
+  uint64_t total = 0;
+
+  for (i = 0; i < n_entries; ++i)
+    total += entries[i];
+
+  if (total_out)
+    *total_out = total;
+
+  if (n_entries < 1)
+    return -1;
+
+  if (total == 0)
+    return crypto_rand_int(n_entries);
+
+  rand_val = crypto_rand_uint64(total); /* needs rand_val < total */
+
+  for (i = 0; i < n_entries; ++i) {
+    total_so_far += entries[i];
+    if (total_so_far > rand_val) {
+      i_chosen = i;
+      n_chosen++;
+      /* Set rand_val to UINT64_MAX rather than stopping the loop. This way,
+       * the time we spend in the loop does not leak which element we chose. */
+      rand_val = UINT64_MAX;
+    }
+  }
+  tor_assert(total_so_far == total);
+  tor_assert(n_chosen == 1);
+  tor_assert(i_chosen >= 0);
+  tor_assert(i_chosen < n_entries);
+
+  return i_chosen;
+}
+
 /** When weighting bridges, enforce these values as lower and upper
  * bound for believable bandwidth, because there is no way for us
  * to verify a bridge's bandwidth currently. */
@@ -1702,15 +1750,10 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
                                            bandwidth_weight_rule_t rule)
 {
   int64_t weight_scale;
-  uint64_t rand_bw;
   double Wg = -1, Wm = -1, We = -1, Wd = -1;
   double Wgb = -1, Wmb = -1, Web = -1, Wdb = -1;
-  uint64_t weighted_bw = 0, unweighted_bw = 0;
+  uint64_t weighted_bw = 0;
   uint64_t *bandwidths;
-  uint64_t tmp;
-  unsigned int i;
-  unsigned int i_chosen;
-  int have_unknown = 0; /* true iff sl contains element not in consensus. */
 
   /* Can't choose exit and guard at same time */
   tor_assert(rule == NO_WEIGHTING ||
@@ -1814,7 +1857,6 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
     } else if (node->ri) {
       /* bridge or other descriptor not in our consensus */
       this_bw = bridge_get_advertised_bandwidth_bounded(node->ri);
-      have_unknown = 1;
     } else {
       /* We can't use this one. */
       continue;
@@ -1838,69 +1880,22 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
       weight = 0.0;
 
     bandwidths[node_sl_idx] = tor_llround(weight*this_bw + 0.5);
-    weighted_bw += bandwidths[node_sl_idx];
-    unweighted_bw += this_bw;
     if (is_me)
       sl_last_weighted_bw_of_me = bandwidths[node_sl_idx];
   } SMARTLIST_FOREACH_END(node);
 
-  /* XXXX this is a kludge to expose these values. */
-  sl_last_total_weighted_bw = weighted_bw;
-
   log_debug(LD_CIRC, "Choosing node for rule %s based on weights "
             "Wg=%f Wm=%f We=%f Wd=%f with total bw "U64_FORMAT,
             bandwidth_weight_rule_to_string(rule),
             Wg, Wm, We, Wd, U64_PRINTF_ARG(weighted_bw));
 
-  /* If there is no bandwidth, choose at random */
-  if (weighted_bw == 0) {
-    /* Don't warn when using bridges/relays not in the consensus */
-    if (!have_unknown) {
-#define ZERO_BANDWIDTH_WARNING_INTERVAL (15)
-      static ratelim_t zero_bandwidth_warning_limit =
-        RATELIM_INIT(ZERO_BANDWIDTH_WARNING_INTERVAL);
-      char *msg;
-      if ((msg = rate_limit_log(&zero_bandwidth_warning_limit,
-                                approx_time()))) {
-        log_warn(LD_CIRC,
-                 "Weighted bandwidth is "U64_FORMAT" in node selection for "
-                 "rule %s (unweighted was "U64_FORMAT") %s",
-                 U64_PRINTF_ARG(weighted_bw),
-                 bandwidth_weight_rule_to_string(rule),
-                 U64_PRINTF_ARG(unweighted_bw), msg);
-      }
-    }
+  {
+    int idx = choose_array_element_by_weight(bandwidths,
+                                             smartlist_len(sl),
+                                             &sl_last_total_weighted_bw);
     tor_free(bandwidths);
-    return smartlist_choose(sl);
-  }
-
-  rand_bw = crypto_rand_uint64(weighted_bw);
-
-  /* Last, count through sl until we get to the element we picked */
-  i_chosen = (unsigned)smartlist_len(sl);
-  tmp = 0;
-  for (i=0; i < (unsigned)smartlist_len(sl); i++) {
-    tmp += bandwidths[i];
-    if (tmp > rand_bw) {
-      i_chosen = i;
-      rand_bw = UINT64_MAX;
-    }
-  }
-  i = i_chosen;
-
-  if (i == (unsigned)smartlist_len(sl)) {
-    /* This was once possible due to round-off error, but shouldn't be able
-     * to occur any longer. */
-    tor_fragile_assert();
-    --i;
-    log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on "
-             " which router we chose. Please tell the developers. "
-             U64_FORMAT" "U64_FORMAT" "U64_FORMAT,
-             U64_PRINTF_ARG(tmp), U64_PRINTF_ARG(rand_bw),
-             U64_PRINTF_ARG(weighted_bw));
+    return idx < 0 ? NULL : smartlist_get(sl, idx);
   }
-  tor_free(bandwidths);
-  return smartlist_get(sl, i);
 }
 
 /** Helper function:
@@ -1921,14 +1916,12 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
                                    bandwidth_weight_rule_t rule)
 {
   unsigned int i;
-  unsigned int i_chosen;
   uint64_t *bandwidths;
   int is_exit;
   int is_guard;
   int is_fast;
-  uint64_t total_nonexit_bw = 0, total_exit_bw = 0, total_bw = 0;
+  uint64_t total_nonexit_bw = 0, total_exit_bw = 0;
   uint64_t total_nonguard_bw = 0, total_guard_bw = 0;
-  uint64_t rand_bw, tmp;
   double exit_weight;
   double guard_weight;
   int n_unknown = 0;
@@ -2073,7 +2066,6 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
     if (guard_weight <= 0.0)
       guard_weight = 0.0;
 
-    total_bw = 0;
     sl_last_weighted_bw_of_me = 0;
     for (i=0; i < (unsigned)smartlist_len(sl); i++) {
       tor_assert(bandwidths[i] < UINT64_MAX);
@@ -2087,15 +2079,12 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
       else if (is_exit)
         bandwidths[i] = tor_llround(bandwidths[i] * exit_weight);
 
-      total_bw += bandwidths[i];
       if (i == (unsigned) me_idx)
         sl_last_weighted_bw_of_me = bandwidths[i];
     }
   }
 
-  /* XXXX this is a kludge to expose these values. */
-  sl_last_total_weighted_bw = total_bw;
-
+#if 0
   log_debug(LD_CIRC, "Total weighted bw = "U64_FORMAT
             ", exit bw = "U64_FORMAT
             ", nonexit bw = "U64_FORMAT", exit weight = %f "
@@ -2108,37 +2097,18 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
             exit_weight, (int)(rule == WEIGHT_FOR_EXIT),
             U64_PRINTF_ARG(total_guard_bw), U64_PRINTF_ARG(total_nonguard_bw),
             guard_weight, (int)(rule == WEIGHT_FOR_GUARD));
+#endif
 
-  /* Almost done: choose a random value from the bandwidth weights. */
-  rand_bw = crypto_rand_uint64(total_bw);
-
-  /* Last, count through sl until we get to the element we picked */
-  tmp = 0;
-  i_chosen = (unsigned)smartlist_len(sl);
-  for (i=0; i < (unsigned)smartlist_len(sl); i++) {
-    tmp += bandwidths[i];
-
-    if (tmp > rand_bw) {
-      i_chosen = i;
-      rand_bw = UINT64_MAX;
-    }
+  {
+    int idx = choose_array_element_by_weight(bandwidths,
+                                             smartlist_len(sl),
+                                             &sl_last_total_weighted_bw);
+    tor_free(bandwidths);
+    tor_free(fast_bits);
+    tor_free(exit_bits);
+    tor_free(guard_bits);
+    return idx < 0 ? NULL : smartlist_get(sl, idx);
   }
-  i = i_chosen;
-  if (i == (unsigned)smartlist_len(sl)) {
-    /* This was once possible due to round-off error, but shouldn't be able
-     * to occur any longer. */
-    tor_fragile_assert();
-    --i;
-    log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on "
-             " which router we chose. Please tell the developers. "
-             U64_FORMAT " " U64_FORMAT " " U64_FORMAT, U64_PRINTF_ARG(tmp),
-             U64_PRINTF_ARG(rand_bw), U64_PRINTF_ARG(total_bw));
-  }
-  tor_free(bandwidths);
-  tor_free(fast_bits);
-  tor_free(exit_bits);
-  tor_free(guard_bits);
-  return smartlist_get(sl, i);
 }
 
 /** Choose a random element of status list <b>sl</b>, weighted by
diff --git a/src/or/routerlist.h b/src/or/routerlist.h
index 8dcc6eb..0b9b297 100644
--- a/src/or/routerlist.h
+++ b/src/or/routerlist.h
@@ -216,5 +216,10 @@ int hex_digest_nickname_decode(const char *hexdigest,
                                char *nickname_qualifier_out,
                                char *nickname_out);
 
+#ifdef ROUTERLIST_PRIVATE
+int choose_array_element_by_weight(const uint64_t *entries, int n_entries,
+                                   uint64_t *total_out);
+#endif
+
 #endif
 
diff --git a/src/test/test.h b/src/test/test.h
index 0b6e6c6..6dcb949 100644
--- a/src/test/test.h
+++ b/src/test/test.h
@@ -65,6 +65,10 @@
 
 #define test_memeq_hex(expr1, hex) test_mem_op_hex(expr1, ==, hex)
 
+#define tt_double_op(a,op,b) /* check a op b as doubles */              \
+  tt_assert_test_type(a,b,#a" "#op" "#b,double,(val1_ op val2_),"%f",   \
+                      TT_EXIT_TEST_FUNCTION)
+
 const char *get_fname(const char *name);
 crypto_pk_t *pk_generate(int idx);
 
diff --git a/src/test/test_dir.c b/src/test/test_dir.c
index 83c6120..ed0c5a1 100644
--- a/src/test/test_dir.c
+++ b/src/test/test_dir.c
@@ -7,6 +7,7 @@
 #define DIRSERV_PRIVATE
 #define DIRVOTE_PRIVATE
 #define ROUTER_PRIVATE
+#define ROUTERLIST_PRIVATE
 #define HIBERNATE_PRIVATE
 #include "or.h"
 #include "directory.h"
@@ -1381,6 +1382,85 @@ test_dir_v3_networkstatus(void)
     ns_detached_signatures_free(dsig2);
 }
 
+/** Unit test for choose_array_element_by_weight(). */
+static void
+test_dir_random_weighted(void *testdata)
+{
+  int histogram[10];
+  uint64_t vals[10] = {3,1,2,4,6,0,7,5,8,9}, total=0;
+  uint64_t zeros[5] = {0,0,0,0,0};
+  int i, choice;
+  const int n = 50000;
+  double max_sq_error;
+  (void) testdata;
+
+  /* Try a ten-element array with values from 0 through 9. The values are
+   * in a scrambled order to make sure we don't depend on order. */
+  memset(histogram,0,sizeof(histogram));
+  for (i=0; i<10; ++i)
+    total += vals[i];
+  tt_int_op(total, ==, 45);
+  for (i=0; i<n; ++i) {
+    uint64_t t;
+    choice = choose_array_element_by_weight(vals, 10, &t);
+    tt_int_op(t, ==, total);
+    tt_int_op(choice, >=, 0);
+    tt_int_op(choice, <, 10);
+    histogram[choice]++;
+  }
+
+  /* Now see if we chose things about frequently enough. */
+  max_sq_error = 0;
+  for (i=0; i<10; ++i) {
+    int expected = (int)(n*vals[i]/total);
+    double frac_diff = 0, sq;
+    TT_BLATHER(("  %d : %5d vs %5d\n", (int)vals[i], histogram[i], expected));
+    if (expected)
+      frac_diff = (histogram[i] - expected) / ((double)expected);
+    else
+      tt_int_op(histogram[i], ==, 0);
+
+    sq = frac_diff * frac_diff;
+    if (sq > max_sq_error)
+      max_sq_error = sq;
+  }
+  /* Loose bound: the real error should almost always be far below this. */
+  tt_double_op(max_sq_error, <, .05);
+
+  /* Now try a singleton; do we choose it? */
+  for (i = 0; i < 100; ++i) {
+    choice = choose_array_element_by_weight(vals, 1, NULL);
+    tt_int_op(choice, ==, 0);
+  }
+
+  /* Now try an array of zeros.  We should choose randomly. */
+  memset(histogram,0,sizeof(histogram));
+  for (i = 0; i < n; ++i) {
+    uint64_t t;
+    choice = choose_array_element_by_weight(zeros, 5, &t);
+    tt_int_op(t, ==, 0);
+    tt_int_op(choice, >=, 0);
+    tt_int_op(choice, <, 5);
+    histogram[choice]++;
+  }
+  /* Now see if we chose things about frequently enough. */
+  max_sq_error = 0;
+  for (i=0; i<5; ++i) {
+    int expected = n/5;
+    double frac_diff = (histogram[i] - expected) / ((double)expected), sq;
+    TT_BLATHER(("  %d : %5d vs %5d\n", (int)zeros[i], histogram[i],
+                expected));
+    sq = frac_diff * frac_diff;
+    if (sq > max_sq_error)
+      max_sq_error = sq;
+  }
+  /* Same loose bound as above.  An empty array must yield -1. */
+  tt_double_op(max_sq_error, <, .05);
+  tt_int_op(choose_array_element_by_weight(vals, 0, NULL), ==, -1);
+ done:
+  ;
+}
+
 #define DIR_LEGACY(name)                                                   \
   { #name, legacy_test_helper, TT_FORK, &legacy_setup, test_dir_ ## name }
 
@@ -1396,6 +1476,7 @@ struct testcase_t dir_tests[] = {
   DIR_LEGACY(measured_bw),
   DIR_LEGACY(param_voting),
   DIR_LEGACY(v3_networkstatus),
+  DIR(random_weighted),
   END_OF_TESTCASES
 };
 





More information about the tor-commits mailing list