Browse Source

Refactor the core of choosing by weights into a function

This eliminates duplicated code, and lets us test a hairy piece of
functionality.
Nick Mathewson 12 years ago
parent
commit
07df4dd52d
5 changed files with 160 additions and 96 deletions
  1. 4 0
      changes/bug6538
  2. 66 96
      src/or/routerlist.c
  3. 5 0
      src/or/routerlist.h
  4. 4 0
      src/test/test.h
  5. 81 0
      src/test/test_dir.c

+ 4 - 0
changes/bug6538

@@ -10,3 +10,7 @@
       than it ran through the part of the loop before it had made its
       than it ran through the part of the loop before it had made its
       choice. Fix for bug 6538.
       choice. Fix for bug 6538.
 
 
+  o Code simplifications and refactoring:
+    - Move the core of our "choose a weighted element at random" logic
+      into its own function, and give it unit tests.  Now the logic is
+      testable, and a little less fragile too.

+ 66 - 96
src/or/routerlist.c

@@ -11,6 +11,7 @@
  * servers.
  * servers.
  **/
  **/
 
 
+#define ROUTERLIST_PRIVATE
 #include "or.h"
 #include "or.h"
 #include "circuitbuild.h"
 #include "circuitbuild.h"
 #include "config.h"
 #include "config.h"
@@ -1652,6 +1653,53 @@ router_get_advertised_bandwidth_capped(const routerinfo_t *router)
   return result;
   return result;
 }
 }
 
 
+/** Pick a random element of <b>n_entries</b>-element array <b>entries</b>,
+ * choosing each element with a probability proportional to its value, and
+ * return the index of that element.  If all elements are 0, choose an index
+ * at random. If <b>total_out</b> is provided, set it to the sum of all
+ * elements in the array. Return -1 on error.
+ */
+/* private */ int
+choose_array_element_by_weight(const uint64_t *entries, int n_entries,
+                               uint64_t *total_out)
+{
+  int i, i_chosen=-1, n_chosen=0;
+  uint64_t total_so_far = 0;
+  uint64_t rand_val;
+  uint64_t total = 0;
+
+  for (i = 0; i < n_entries; ++i)
+    total += entries[i];
+
+  if (total_out)
+    *total_out = total;
+
+  if (n_entries < 1)
+    return -1;
+
+  if (total == 0)
+    return crypto_rand_int(n_entries);
+
+  rand_val = crypto_rand_uint64(total);
+
+  for (i = 0; i < n_entries; ++i) {
+    total_so_far += entries[i];
+    if (total_so_far > rand_val) {
+      i_chosen = i;
+      n_chosen++;
+      /* Set rand_val to UINT_MAX rather than stopping the loop. This way,
+       * the time we spend in the loop does not leak which element we chose. */
+      rand_val = UINT64_MAX;
+    }
+  }
+  tor_assert(total_so_far == total);
+  tor_assert(n_chosen == 1);
+  tor_assert(i_chosen >= 0);
+  tor_assert(i_chosen < n_entries);
+
+  return i_chosen;
+}
+
 /** When weighting bridges, enforce these values as lower and upper
 /** When weighting bridges, enforce these values as lower and upper
  * bound for believable bandwidth, because there is no way for us
  * bound for believable bandwidth, because there is no way for us
  * to verify a bridge's bandwidth currently. */
  * to verify a bridge's bandwidth currently. */
@@ -1702,15 +1750,10 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
                                            bandwidth_weight_rule_t rule)
                                            bandwidth_weight_rule_t rule)
 {
 {
   int64_t weight_scale;
   int64_t weight_scale;
-  uint64_t rand_bw;
   double Wg = -1, Wm = -1, We = -1, Wd = -1;
   double Wg = -1, Wm = -1, We = -1, Wd = -1;
   double Wgb = -1, Wmb = -1, Web = -1, Wdb = -1;
   double Wgb = -1, Wmb = -1, Web = -1, Wdb = -1;
-  uint64_t weighted_bw = 0, unweighted_bw = 0;
+  uint64_t weighted_bw = 0;
   uint64_t *bandwidths;
   uint64_t *bandwidths;
-  uint64_t tmp;
-  unsigned int i;
-  unsigned int i_chosen;
-  int have_unknown = 0; /* true iff sl contains element not in consensus. */
 
 
   /* Can't choose exit and guard at same time */
   /* Can't choose exit and guard at same time */
   tor_assert(rule == NO_WEIGHTING ||
   tor_assert(rule == NO_WEIGHTING ||
@@ -1814,7 +1857,6 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
     } else if (node->ri) {
     } else if (node->ri) {
       /* bridge or other descriptor not in our consensus */
       /* bridge or other descriptor not in our consensus */
       this_bw = bridge_get_advertised_bandwidth_bounded(node->ri);
       this_bw = bridge_get_advertised_bandwidth_bounded(node->ri);
-      have_unknown = 1;
     } else {
     } else {
       /* We can't use this one. */
       /* We can't use this one. */
       continue;
       continue;
@@ -1838,69 +1880,22 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
       weight = 0.0;
       weight = 0.0;
 
 
     bandwidths[node_sl_idx] = tor_llround(weight*this_bw + 0.5);
     bandwidths[node_sl_idx] = tor_llround(weight*this_bw + 0.5);
-    weighted_bw += bandwidths[node_sl_idx];
-    unweighted_bw += this_bw;
     if (is_me)
     if (is_me)
       sl_last_weighted_bw_of_me = bandwidths[node_sl_idx];
       sl_last_weighted_bw_of_me = bandwidths[node_sl_idx];
   } SMARTLIST_FOREACH_END(node);
   } SMARTLIST_FOREACH_END(node);
 
 
-  /* XXXX this is a kludge to expose these values. */
-  sl_last_total_weighted_bw = weighted_bw;
-
   log_debug(LD_CIRC, "Choosing node for rule %s based on weights "
   log_debug(LD_CIRC, "Choosing node for rule %s based on weights "
             "Wg=%f Wm=%f We=%f Wd=%f with total bw "U64_FORMAT,
             "Wg=%f Wm=%f We=%f Wd=%f with total bw "U64_FORMAT,
             bandwidth_weight_rule_to_string(rule),
             bandwidth_weight_rule_to_string(rule),
             Wg, Wm, We, Wd, U64_PRINTF_ARG(weighted_bw));
             Wg, Wm, We, Wd, U64_PRINTF_ARG(weighted_bw));
 
 
-  /* If there is no bandwidth, choose at random */
-  if (weighted_bw == 0) {
-    /* Don't warn when using bridges/relays not in the consensus */
-    if (!have_unknown) {
-#define ZERO_BANDWIDTH_WARNING_INTERVAL (15)
-      static ratelim_t zero_bandwidth_warning_limit =
-        RATELIM_INIT(ZERO_BANDWIDTH_WARNING_INTERVAL);
-      char *msg;
-      if ((msg = rate_limit_log(&zero_bandwidth_warning_limit,
-                                approx_time()))) {
-        log_warn(LD_CIRC,
-                 "Weighted bandwidth is "U64_FORMAT" in node selection for "
-                 "rule %s (unweighted was "U64_FORMAT") %s",
-                 U64_PRINTF_ARG(weighted_bw),
-                 bandwidth_weight_rule_to_string(rule),
-                 U64_PRINTF_ARG(unweighted_bw), msg);
-      }
-    }
+  {
+    int idx = choose_array_element_by_weight(bandwidths,
+                                             smartlist_len(sl),
+                                             &sl_last_total_weighted_bw);
     tor_free(bandwidths);
     tor_free(bandwidths);
-    return smartlist_choose(sl);
-  }
-
-  rand_bw = crypto_rand_uint64(weighted_bw);
-
-  /* Last, count through sl until we get to the element we picked */
-  i_chosen = (unsigned)smartlist_len(sl);
-  tmp = 0;
-  for (i=0; i < (unsigned)smartlist_len(sl); i++) {
-    tmp += bandwidths[i];
-    if (tmp > rand_bw) {
-      i_chosen = i;
-      rand_bw = UINT64_MAX;
-    }
-  }
-  i = i_chosen;
-
-  if (i == (unsigned)smartlist_len(sl)) {
-    /* This was once possible due to round-off error, but shouldn't be able
-     * to occur any longer. */
-    tor_fragile_assert();
-    --i;
-    log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on "
-             " which router we chose. Please tell the developers. "
-             U64_FORMAT" "U64_FORMAT" "U64_FORMAT,
-             U64_PRINTF_ARG(tmp), U64_PRINTF_ARG(rand_bw),
-             U64_PRINTF_ARG(weighted_bw));
+    return idx < 0 ? NULL : smartlist_get(sl, idx);
   }
   }
-  tor_free(bandwidths);
-  return smartlist_get(sl, i);
 }
 }
 
 
 /** Helper function:
 /** Helper function:
@@ -1921,14 +1916,12 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
                                    bandwidth_weight_rule_t rule)
                                    bandwidth_weight_rule_t rule)
 {
 {
   unsigned int i;
   unsigned int i;
-  unsigned int i_chosen;
   uint64_t *bandwidths;
   uint64_t *bandwidths;
   int is_exit;
   int is_exit;
   int is_guard;
   int is_guard;
   int is_fast;
   int is_fast;
-  uint64_t total_nonexit_bw = 0, total_exit_bw = 0, total_bw = 0;
+  uint64_t total_nonexit_bw = 0, total_exit_bw = 0;
   uint64_t total_nonguard_bw = 0, total_guard_bw = 0;
   uint64_t total_nonguard_bw = 0, total_guard_bw = 0;
-  uint64_t rand_bw, tmp;
   double exit_weight;
   double exit_weight;
   double guard_weight;
   double guard_weight;
   int n_unknown = 0;
   int n_unknown = 0;
@@ -2073,7 +2066,6 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
     if (guard_weight <= 0.0)
     if (guard_weight <= 0.0)
       guard_weight = 0.0;
       guard_weight = 0.0;
 
 
-    total_bw = 0;
     sl_last_weighted_bw_of_me = 0;
     sl_last_weighted_bw_of_me = 0;
     for (i=0; i < (unsigned)smartlist_len(sl); i++) {
     for (i=0; i < (unsigned)smartlist_len(sl); i++) {
       tor_assert(bandwidths[i] < UINT64_MAX);
       tor_assert(bandwidths[i] < UINT64_MAX);
@@ -2087,15 +2079,12 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
       else if (is_exit)
       else if (is_exit)
         bandwidths[i] = tor_llround(bandwidths[i] * exit_weight);
         bandwidths[i] = tor_llround(bandwidths[i] * exit_weight);
 
 
-      total_bw += bandwidths[i];
       if (i == (unsigned) me_idx)
       if (i == (unsigned) me_idx)
         sl_last_weighted_bw_of_me = bandwidths[i];
         sl_last_weighted_bw_of_me = bandwidths[i];
     }
     }
   }
   }
 
 
-  /* XXXX this is a kludge to expose these values. */
-  sl_last_total_weighted_bw = total_bw;
-
+#if 0
   log_debug(LD_CIRC, "Total weighted bw = "U64_FORMAT
   log_debug(LD_CIRC, "Total weighted bw = "U64_FORMAT
             ", exit bw = "U64_FORMAT
             ", exit bw = "U64_FORMAT
             ", nonexit bw = "U64_FORMAT", exit weight = %f "
             ", nonexit bw = "U64_FORMAT", exit weight = %f "
@@ -2108,37 +2097,18 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
             exit_weight, (int)(rule == WEIGHT_FOR_EXIT),
             exit_weight, (int)(rule == WEIGHT_FOR_EXIT),
             U64_PRINTF_ARG(total_guard_bw), U64_PRINTF_ARG(total_nonguard_bw),
             U64_PRINTF_ARG(total_guard_bw), U64_PRINTF_ARG(total_nonguard_bw),
             guard_weight, (int)(rule == WEIGHT_FOR_GUARD));
             guard_weight, (int)(rule == WEIGHT_FOR_GUARD));
+#endif
 
 
-  /* Almost done: choose a random value from the bandwidth weights. */
-  rand_bw = crypto_rand_uint64(total_bw);
-
-  /* Last, count through sl until we get to the element we picked */
-  tmp = 0;
-  i_chosen = (unsigned)smartlist_len(sl);
-  for (i=0; i < (unsigned)smartlist_len(sl); i++) {
-    tmp += bandwidths[i];
-
-    if (tmp > rand_bw) {
-      i_chosen = i;
-      rand_bw = UINT64_MAX;
-    }
+  {
+    int idx = choose_array_element_by_weight(bandwidths,
+                                             smartlist_len(sl),
+                                             &sl_last_total_weighted_bw);
+    tor_free(bandwidths);
+    tor_free(fast_bits);
+    tor_free(exit_bits);
+    tor_free(guard_bits);
+    return idx < 0 ? NULL : smartlist_get(sl, idx);
   }
   }
-  i = i_chosen;
-  if (i == (unsigned)smartlist_len(sl)) {
-    /* This was once possible due to round-off error, but shouldn't be able
-     * to occur any longer. */
-    tor_fragile_assert();
-    --i;
-    log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on "
-             " which router we chose. Please tell the developers. "
-             U64_FORMAT " " U64_FORMAT " " U64_FORMAT, U64_PRINTF_ARG(tmp),
-             U64_PRINTF_ARG(rand_bw), U64_PRINTF_ARG(total_bw));
-  }
-  tor_free(bandwidths);
-  tor_free(fast_bits);
-  tor_free(exit_bits);
-  tor_free(guard_bits);
-  return smartlist_get(sl, i);
 }
 }
 
 
 /** Choose a random element of status list <b>sl</b>, weighted by
 /** Choose a random element of status list <b>sl</b>, weighted by

+ 5 - 0
src/or/routerlist.h

@@ -216,5 +216,10 @@ int hex_digest_nickname_decode(const char *hexdigest,
                                char *nickname_qualifier_out,
                                char *nickname_qualifier_out,
                                char *nickname_out);
                                char *nickname_out);
 
 
+#ifdef ROUTERLIST_PRIVATE
+int choose_array_element_by_weight(const uint64_t *entries, int n_entries,
+                                   uint64_t *total_out);
+#endif
+
 #endif
 #endif
 
 

+ 4 - 0
src/test/test.h

@@ -65,6 +65,10 @@
 
 
 #define test_memeq_hex(expr1, hex) test_mem_op_hex(expr1, ==, hex)
 #define test_memeq_hex(expr1, hex) test_mem_op_hex(expr1, ==, hex)
 
 
+#define tt_double_op(a,op,b)                                            \
+  tt_assert_test_type(a,b,#a" "#op" "#b,double,(val1_ op val2_),"%f",   \
+                      TT_EXIT_TEST_FUNCTION)
+
 const char *get_fname(const char *name);
 const char *get_fname(const char *name);
 crypto_pk_t *pk_generate(int idx);
 crypto_pk_t *pk_generate(int idx);
 
 

+ 81 - 0
src/test/test_dir.c

@@ -7,6 +7,7 @@
 #define DIRSERV_PRIVATE
 #define DIRSERV_PRIVATE
 #define DIRVOTE_PRIVATE
 #define DIRVOTE_PRIVATE
 #define ROUTER_PRIVATE
 #define ROUTER_PRIVATE
+#define ROUTERLIST_PRIVATE
 #define HIBERNATE_PRIVATE
 #define HIBERNATE_PRIVATE
 #include "or.h"
 #include "or.h"
 #include "directory.h"
 #include "directory.h"
@@ -1381,6 +1382,85 @@ test_dir_v3_networkstatus(void)
     ns_detached_signatures_free(dsig2);
     ns_detached_signatures_free(dsig2);
 }
 }
 
 
+static void
+test_dir_random_weighted(void *testdata)
+{
+  int histogram[10];
+  uint64_t vals[10] = {3,1,2,4,6,0,7,5,8,9}, total=0;
+  uint64_t zeros[5] = {0,0,0,0,0};
+  int i, choice;
+  const int n = 50000;
+  double max_sq_error;
+  (void) testdata;
+
+  /* Try a ten-element array with values from 0 through 10. The values are
+   * in a scrambled order to make sure we don't depend on order. */
+  memset(histogram,0,sizeof(histogram));
+  for (i=0; i<10; ++i)
+    total += vals[i];
+  tt_int_op(total, ==, 45);
+  for (i=0; i<n; ++i) {
+    uint64_t t;
+    choice = choose_array_element_by_weight(vals, 10, &t);
+    tt_int_op(t, ==, total);
+    tt_int_op(choice, >=, 0);
+    tt_int_op(choice, <, 10);
+    histogram[choice]++;
+  }
+
+  /* Now see if we chose things about frequently enough. */
+  max_sq_error = 0;
+  for (i=0; i<10; ++i) {
+    int expected = (int)(n*vals[i]/total);
+    double frac_diff = 0, sq;
+    TT_BLATHER(("  %d : %5d vs %5d\n", (int)vals[i], histogram[i], expected));
+    if (expected)
+      frac_diff = (histogram[i] - expected) / ((double)expected);
+    else
+      tt_int_op(histogram[i], ==, 0);
+
+    sq = frac_diff * frac_diff;
+    if (sq > max_sq_error)
+      max_sq_error = sq;
+  }
+  /* It should almost always be much much less than this.  If you want to
+   * figure out the odds, please feel free. */
+  tt_double_op(max_sq_error, <, .05);
+
+  /* Now try a singleton; do we choose it? */
+  for (i = 0; i < 100; ++i) {
+    choice = choose_array_element_by_weight(vals, 1, NULL);
+    tt_int_op(choice, ==, 0);
+  }
+
+  /* Now try an array of zeros.  We should choose randomly. */
+  memset(histogram,0,sizeof(histogram));
+  for (i = 0; i < n; ++i) {
+    uint64_t t;
+    choice = choose_array_element_by_weight(zeros, 5, &t);
+    tt_int_op(t, ==, 0);
+    tt_int_op(choice, >=, 0);
+    tt_int_op(choice, <, 5);
+    histogram[choice]++;
+  }
+  /* Now see if we chose things about frequently enough. */
+  max_sq_error = 0;
+  for (i=0; i<5; ++i) {
+    int expected = n/5;
+    double frac_diff = 0, sq;
+    TT_BLATHER(("  %d : %5d vs %5d\n", (int)vals[i], histogram[i], expected));
+    frac_diff = (histogram[i] - expected) / ((double)expected);
+    sq = frac_diff * frac_diff;
+    if (sq > max_sq_error)
+      max_sq_error = sq;
+  }
+  /* It should almost always be much much less than this.  If you want to
+   * figure out the odds, please feel free. */
+  tt_double_op(max_sq_error, <, .05);
+ done:
+  ;
+}
+
 #define DIR_LEGACY(name)                                                   \
 #define DIR_LEGACY(name)                                                   \
   { #name, legacy_test_helper, TT_FORK, &legacy_setup, test_dir_ ## name }
   { #name, legacy_test_helper, TT_FORK, &legacy_setup, test_dir_ ## name }
 
 
@@ -1396,6 +1476,7 @@ struct testcase_t dir_tests[] = {
   DIR_LEGACY(measured_bw),
   DIR_LEGACY(measured_bw),
   DIR_LEGACY(param_voting),
   DIR_LEGACY(param_voting),
   DIR_LEGACY(v3_networkstatus),
   DIR_LEGACY(v3_networkstatus),
+  DIR(random_weighted),
   END_OF_TESTCASES
   END_OF_TESTCASES
 };
 };