瀏覽代碼

Merge remote-tracking branch 'public/bug6538'

Conflicts:
	configure.ac
Nick Mathewson 11 年之前
父節點
當前提交
75c9ccd4f8
共有 8 個文件被更改,包括 331 次插入157 次删除
  1. 16 0
      changes/bug6538
  2. 1 0
      configure.ac
  3. 16 1
      src/common/util.c
  4. 1 0
      src/common/util.h
  5. 156 156
      src/or/routerlist.c
  6. 14 0
      src/or/routerlist.h
  7. 4 0
      src/test/test.h
  8. 123 0
      src/test/test_dir.c

+ 16 - 0
changes/bug6538

@@ -0,0 +1,16 @@
+  o Minor bugfixes:
+    - Switch weighted node selection rule from using a list of doubles
+      to using a list of int64_t. This should make the process slightly
+      easier to debug and maintain. Needed for fix for bug 6538.
+
+  o Security features:
+    - Switch to a completely time-invariant approach for picking nodes
+      weighted by bandwidth. Our old approach would run through the
+      part of the loop after it had made its choice slightly slower
+      than it ran through the part of the loop before it had made its
+      choice. Fix for bug 6538.
+
+  o Code simplifications and refactoring:
+    - Move the core of our "choose a weighted element at random" logic
+      into its own function, and give it unit tests.  Now the logic is
+      testable, and a little less fragile too.

+ 1 - 0
configure.ac

@@ -303,6 +303,7 @@ AC_CHECK_FUNCS(
         inet_aton \
         ioctl \
         issetugid \
+        llround \
         localtime_r \
         lround \
         memmem \

+ 16 - 1
src/common/util.c

@@ -333,7 +333,7 @@ tor_mathlog(double d)
 }
 
 /** Return the long integer closest to d.  We define this wrapper here so
- * that not all users of math.h need to use the right incancations to get
+ * that not all users of math.h need to use the right intancations to get
  * the c99 functions. */
 long
 tor_lround(double d)
@@ -347,6 +347,21 @@ tor_lround(double d)
 #endif
 }
 
+/** Return the 64-bit integer closest to d.  We define this wrapper here so
+ * that not all users of math.h need to use the right incantations to get the
+ * c99 functions. */
+int64_t
+tor_llround(double d)
+{
+#if defined(HAVE_LLROUND)
+  return (int64_t)llround(d);
+#elif defined(HAVE_RINT)
+  return (int64_t)rint(d);
+#else
+  return (int64_t)(d > 0 ? d + 0.5 : ceil(d - 0.5));
+#endif
+}
+
 /** Returns floor(log2(u64)).  If u64 is 0, (incorrectly) returns 0. */
 int
 tor_log2(uint64_t u64)

+ 1 - 0
src/common/util.h

@@ -160,6 +160,7 @@ void tor_log_mallinfo(int severity);
 /* Math functions */
 double tor_mathlog(double d) ATTR_CONST;
 long tor_lround(double d) ATTR_CONST;
+int64_t tor_llround(double d) ATTR_CONST;
 int tor_log2(uint64_t u64) ATTR_CONST;
 uint64_t round_to_power_of_2(uint64_t u64);
 unsigned round_to_next_multiple_of(unsigned number, unsigned divisor);

+ 156 - 156
src/or/routerlist.c

@@ -11,6 +11,7 @@
  * servers.
  **/
 
+#define ROUTERLIST_PRIVATE
 #include "or.h"
 #include "circuitbuild.h"
 #include "config.h"
@@ -1647,6 +1648,92 @@ router_get_advertised_bandwidth_capped(const routerinfo_t *router)
   return result;
 }
 
+/** Given an array of double/uint64_t unions that are currently being used as
+ * doubles, convert them to uint64_t, and try to scale them linearly so as to
+ * much of the range of uint64_t. If <b>total_out</b> is provided, set it to
+ * the sum of all elements in the array _before_ scaling. */
+/* private */ void
+scale_array_elements_to_u64(u64_dbl_t *entries, int n_entries,
+                            uint64_t *total_out)
+{
+  double total = 0.0;
+  double scale_factor;
+  int i;
+  /* big, but far away from overflowing an int64_t */
+#define SCALE_TO_U64_MAX (INT64_MAX / 4)
+
+  for (i = 0; i < n_entries; ++i)
+    total += entries[i].dbl;
+
+  scale_factor = SCALE_TO_U64_MAX / total;
+
+  for (i = 0; i < n_entries; ++i)
+    entries[i].u64 = tor_llround(entries[i].dbl * scale_factor);
+
+  if (total_out)
+    *total_out = (uint64_t) total;
+
+#undef SCALE_TO_U64_MAX
+}
+
+/** Time-invariant 64-bit greater-than; works on two integers in the range
+ * (0,INT64_MAX). */
+#if SIZEOF_VOID_P == 8
+#define gt_i64_timei(a,b) ((a) > (b))
+#else
+static INLINE int
+gt_i64_timei(uint64_t a, uint64_t b)
+{
+  int64_t diff = (int64_t) (b - a);
+  int res = diff >> 63;
+  return res & 1;
+}
+#endif
+
+/** Pick a random element of <b>n_entries</b>-element array <b>entries</b>,
+ * choosing each element with a probability proportional to its (uint64_t)
+ * value, and return the index of that element.  If all elements are 0, choose
+ * an index at random. Return -1 on error.
+ */
+/* private */ int
+choose_array_element_by_weight(const u64_dbl_t *entries, int n_entries)
+{
+  int i, i_chosen=-1, n_chosen=0;
+  uint64_t total_so_far = 0;
+  uint64_t rand_val;
+  uint64_t total = 0;
+
+  for (i = 0; i < n_entries; ++i)
+    total += entries[i].u64;
+
+  if (n_entries < 1)
+    return -1;
+
+  if (total == 0)
+    return crypto_rand_int(n_entries);
+
+  tor_assert(total < INT64_MAX);
+
+  rand_val = crypto_rand_uint64(total);
+
+  for (i = 0; i < n_entries; ++i) {
+    total_so_far += entries[i].u64;
+    if (gt_i64_timei(total_so_far, rand_val)) {
+      i_chosen = i;
+      n_chosen++;
+      /* Set rand_val to INT64_MAX rather than stopping the loop. This way,
+       * the time we spend in the loop does not leak which element we chose. */
+      rand_val = INT64_MAX;
+    }
+  }
+  tor_assert(total_so_far == total);
+  tor_assert(n_chosen == 1);
+  tor_assert(i_chosen >= 0);
+  tor_assert(i_chosen < n_entries);
+
+  return i_chosen;
+}
+
 /** When weighting bridges, enforce these values as lower and upper
  * bound for believable bandwidth, because there is no way for us
  * to verify a bridge's bandwidth currently. */
@@ -1697,16 +1784,10 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
                                            bandwidth_weight_rule_t rule)
 {
   int64_t weight_scale;
-  int64_t rand_bw;
   double Wg = -1, Wm = -1, We = -1, Wd = -1;
   double Wgb = -1, Wmb = -1, Web = -1, Wdb = -1;
-  double weighted_bw = 0, unweighted_bw = 0;
-  double *bandwidths;
-  double tmp = 0;
-  unsigned int i;
-  unsigned int i_chosen;
-  unsigned int i_has_been_chosen;
-  int have_unknown = 0; /* true iff sl contains element not in consensus. */
+  uint64_t weighted_bw = 0;
+  u64_dbl_t *bandwidths;
 
   /* Can't choose exit and guard at same time */
   tor_assert(rule == NO_WEIGHTING ||
@@ -1787,7 +1868,7 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
   Web /= weight_scale;
   Wdb /= weight_scale;
 
-  bandwidths = tor_malloc_zero(sizeof(double)*smartlist_len(sl));
+  bandwidths = tor_malloc_zero(sizeof(u64_dbl_t)*smartlist_len(sl));
 
   // Cycle through smartlist and total the bandwidth.
   SMARTLIST_FOREACH_BEGIN(sl, const node_t *, node) {
@@ -1810,7 +1891,6 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
     } else if (node->ri) {
       /* bridge or other descriptor not in our consensus */
       this_bw = bridge_get_advertised_bandwidth_bounded(node->ri);
-      have_unknown = 1;
     } else {
       /* We can't use this one. */
       continue;
@@ -1826,72 +1906,32 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
     } else { // middle
       weight = (is_dir ? Wmb*Wm : Wm);
     }
-
-    bandwidths[node_sl_idx] = weight*this_bw;
-    weighted_bw += weight*this_bw;
-    unweighted_bw += this_bw;
+    /* These should be impossible; but overflows here would be bad, so let's
+     * make sure. */
+    if (this_bw < 0)
+      this_bw = 0;
+    if (weight < 0.0)
+      weight = 0.0;
+
+    bandwidths[node_sl_idx].dbl = weight*this_bw + 0.5;
     if (is_me)
-      sl_last_weighted_bw_of_me = weight*this_bw;
+      sl_last_weighted_bw_of_me = (uint64_t) bandwidths[node_sl_idx].dbl;
   } SMARTLIST_FOREACH_END(node);
 
-  /* XXXX this is a kludge to expose these values. */
-  sl_last_total_weighted_bw = weighted_bw;
-
   log_debug(LD_CIRC, "Choosing node for rule %s based on weights "
-            "Wg=%f Wm=%f We=%f Wd=%f with total bw %f",
+            "Wg=%f Wm=%f We=%f Wd=%f with total bw "U64_FORMAT,
             bandwidth_weight_rule_to_string(rule),
-            Wg, Wm, We, Wd, weighted_bw);
-
-  /* If there is no bandwidth, choose at random */
-  if (DBL_TO_U64(weighted_bw) == 0) {
-    /* Don't warn when using bridges/relays not in the consensus */
-    if (!have_unknown) {
-#define ZERO_BANDWIDTH_WARNING_INTERVAL (15)
-      static ratelim_t zero_bandwidth_warning_limit =
-        RATELIM_INIT(ZERO_BANDWIDTH_WARNING_INTERVAL);
-      char *msg;
-      if ((msg = rate_limit_log(&zero_bandwidth_warning_limit,
-                                approx_time()))) {
-        log_warn(LD_CIRC,
-                 "Weighted bandwidth is %f in node selection for rule %s "
-                 "(unweighted was %f) %s",
-                 weighted_bw, bandwidth_weight_rule_to_string(rule),
-                 unweighted_bw, msg);
-      }
-    }
-    tor_free(bandwidths);
-    return smartlist_choose(sl);
-  }
+            Wg, Wm, We, Wd, U64_PRINTF_ARG(weighted_bw));
 
-  rand_bw = crypto_rand_uint64(DBL_TO_U64(weighted_bw));
-  rand_bw++; /* crypto_rand_uint64() counts from 0, and we need to count
-              * from 1 below. See bug 1203 for details. */
+  scale_array_elements_to_u64(bandwidths, smartlist_len(sl),
+                              &sl_last_total_weighted_bw);
 
-  /* Last, count through sl until we get to the element we picked */
-  i_chosen = (unsigned)smartlist_len(sl);
-  i_has_been_chosen = 0;
-  tmp = 0.0;
-  for (i=0; i < (unsigned)smartlist_len(sl); i++) {
-    tmp += bandwidths[i];
-    if (tmp >= rand_bw && !i_has_been_chosen) {
-      i_chosen = i;
-      i_has_been_chosen = 1;
-    }
-  }
-  i = i_chosen;
-
-  if (i == (unsigned)smartlist_len(sl)) {
-    /* This was once possible due to round-off error, but shouldn't be able
-     * to occur any longer. */
-    tor_fragile_assert();
-    --i;
-    log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on "
-             " which router we chose. Please tell the developers. "
-             "%f " U64_FORMAT " %f", tmp, U64_PRINTF_ARG(rand_bw),
-             weighted_bw);
+  {
+    int idx = choose_array_element_by_weight(bandwidths,
+                                             smartlist_len(sl));
+    tor_free(bandwidths);
+    return idx < 0 ? NULL : smartlist_get(sl, idx);
   }
-  tor_free(bandwidths);
-  return smartlist_get(sl, i);
 }
 
 /** Helper function:
@@ -1912,17 +1952,16 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
                                    bandwidth_weight_rule_t rule)
 {
   unsigned int i;
-  unsigned int i_chosen;
-  unsigned int i_has_been_chosen;
-  int32_t *bandwidths;
+  u64_dbl_t *bandwidths;
   int is_exit;
   int is_guard;
-  uint64_t total_nonexit_bw = 0, total_exit_bw = 0, total_bw = 0;
-  uint64_t total_nonguard_bw = 0, total_guard_bw = 0;
-  uint64_t rand_bw, tmp;
+  int is_fast;
+  double total_nonexit_bw = 0, total_exit_bw = 0;
+  double total_nonguard_bw = 0, total_guard_bw = 0;
   double exit_weight;
   double guard_weight;
   int n_unknown = 0;
+  bitarray_t *fast_bits;
   bitarray_t *exit_bits;
   bitarray_t *guard_bits;
   int me_idx = -1;
@@ -1946,10 +1985,9 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
   }
 
   /* First count the total bandwidth weight, and make a list
-   * of each value.  <0 means "unknown; no routerinfo."  We use the
-   * bits of negative values to remember whether the router was fast (-x)&1
-   * and whether it was an exit (-x)&2 or guard (-x)&4.  Yes, it's a hack. */
-  bandwidths = tor_malloc(sizeof(int32_t)*smartlist_len(sl));
+   * of each value.  We use UINT64_MAX to indicate "unknown". */
+  bandwidths = tor_malloc_zero(sizeof(u64_dbl_t)*smartlist_len(sl));
+  fast_bits = bitarray_init_zero(smartlist_len(sl));
   exit_bits = bitarray_init_zero(smartlist_len(sl));
   guard_bits = bitarray_init_zero(smartlist_len(sl));
 
@@ -1957,7 +1995,6 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
   SMARTLIST_FOREACH_BEGIN(sl, const node_t *, node) {
     /* first, learn what bandwidth we think i has */
     int is_known = 1;
-    int32_t flags = 0;
     uint32_t this_bw = 0;
     i = node_sl_idx;
 
@@ -1970,12 +2007,7 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
       if (node->rs->has_bandwidth) {
         this_bw = kb_to_bytes(node->rs->bandwidth);
       } else { /* guess */
-        /* XXX024 once consensuses always list bandwidths, we can take
-         * this guessing business out. -RD */
         is_known = 0;
-        flags = node->rs->is_fast ? 1 : 0;
-        flags |= is_exit ? 2 : 0;
-        flags |= is_guard ? 4 : 0;
       }
     } else if (node->ri) {
       /* Must be a bridge if we're willing to use it */
@@ -1986,12 +2018,11 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
       bitarray_set(exit_bits, i);
     if (is_guard)
       bitarray_set(guard_bits, i);
+    if (node->is_fast)
+      bitarray_set(fast_bits, i);
+
     if (is_known) {
-      bandwidths[i] = (int32_t) this_bw;
-      /* Casting this_bw to int32_t is safe because both kb_to_bytes
-         and bridge_get_advertised_bandwidth_bounded limit it to below
-         INT32_MAX. */
-      tor_assert(bandwidths[i] >= 0);
+      bandwidths[i].dbl = this_bw;
       if (is_guard)
         total_guard_bw += this_bw;
       else
@@ -2002,14 +2033,16 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
         total_nonexit_bw += this_bw;
     } else {
       ++n_unknown;
-      bandwidths[node_sl_idx] = -flags;
+      bandwidths[i].dbl = -1.0;
     }
   } SMARTLIST_FOREACH_END(node);
 
+#define EPSILON .1
+
   /* Now, fill in the unknown values. */
   if (n_unknown) {
     int32_t avg_fast, avg_slow;
-    if (total_exit_bw+total_nonexit_bw) {
+    if (total_exit_bw+total_nonexit_bw < EPSILON) {
       /* if there's some bandwidth, there's at least one known router,
        * so no worries about div by 0 here */
       int n_known = smartlist_len(sl)-n_unknown;
@@ -2020,26 +2053,27 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
       avg_slow = 20000;
     }
     for (i=0; i<(unsigned)smartlist_len(sl); ++i) {
-      int32_t bw = bandwidths[i];
-      if (bw>=0)
+      if (bandwidths[i].dbl >= 0.0)
         continue;
-      is_exit = ((-bw)&2);
-      is_guard = ((-bw)&4);
-      bandwidths[i] = ((-bw)&1) ? avg_fast : avg_slow;
+      is_fast = bitarray_is_set(fast_bits, i);
+      is_exit = bitarray_is_set(exit_bits, i);
+      is_guard = bitarray_is_set(guard_bits, i);
+      bandwidths[i].dbl = is_fast ? avg_fast : avg_slow;
       if (is_exit)
-        total_exit_bw += bandwidths[i];
+        total_exit_bw += bandwidths[i].dbl;
       else
-        total_nonexit_bw += bandwidths[i];
+        total_nonexit_bw += bandwidths[i].dbl;
       if (is_guard)
-        total_guard_bw += bandwidths[i];
+        total_guard_bw += bandwidths[i].dbl;
       else
-        total_nonguard_bw += bandwidths[i];
+        total_nonguard_bw += bandwidths[i].dbl;
     }
   }
 
   /* If there's no bandwidth at all, pick at random. */
-  if (!(total_exit_bw+total_nonexit_bw)) {
+  if (total_exit_bw+total_nonexit_bw < EPSILON) {
     tor_free(bandwidths);
+    tor_free(fast_bits);
     tor_free(exit_bits);
     tor_free(guard_bits);
     return smartlist_choose(sl);
@@ -2054,12 +2088,12 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
      * For detailed derivation of this formula, see
      *   http://archives.seul.org/or/dev/Jul-2007/msg00056.html
      */
-    if (rule == WEIGHT_FOR_EXIT || !total_exit_bw)
+    if (rule == WEIGHT_FOR_EXIT || total_exit_bw<EPSILON)
       exit_weight = 1.0;
     else
       exit_weight = 1.0 - all_bw/(3.0*exit_bw);
 
-    if (rule == WEIGHT_FOR_GUARD || !total_guard_bw)
+    if (rule == WEIGHT_FOR_GUARD || total_guard_bw<EPSILON)
       guard_weight = 1.0;
     else
       guard_weight = 1.0 - all_bw/(3.0*guard_bw);
@@ -2070,29 +2104,25 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
     if (guard_weight <= 0.0)
       guard_weight = 0.0;
 
-    total_bw = 0;
     sl_last_weighted_bw_of_me = 0;
     for (i=0; i < (unsigned)smartlist_len(sl); i++) {
-      uint64_t bw;
+      tor_assert(bandwidths[i].dbl >= 0.0);
+
       is_exit = bitarray_is_set(exit_bits, i);
       is_guard = bitarray_is_set(guard_bits, i);
       if (is_exit && is_guard)
-        bw = ((uint64_t)(bandwidths[i] * exit_weight * guard_weight));
+        bandwidths[i].dbl *= exit_weight * guard_weight;
       else if (is_guard)
-        bw = ((uint64_t)(bandwidths[i] * guard_weight));
+        bandwidths[i].dbl *= guard_weight;
       else if (is_exit)
-        bw = ((uint64_t)(bandwidths[i] * exit_weight));
-      else
-        bw = bandwidths[i];
-      total_bw += bw;
+        bandwidths[i].dbl *= exit_weight;
+
       if (i == (unsigned) me_idx)
-        sl_last_weighted_bw_of_me = bw;
+        sl_last_weighted_bw_of_me = (uint64_t) bandwidths[i].dbl;
     }
   }
 
-  /* XXXX this is a kludge to expose these values. */
-  sl_last_total_weighted_bw = total_bw;
-
+#if 0
   log_debug(LD_CIRC, "Total weighted bw = "U64_FORMAT
             ", exit bw = "U64_FORMAT
             ", nonexit bw = "U64_FORMAT", exit weight = %f "
@@ -2105,50 +2135,20 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
             exit_weight, (int)(rule == WEIGHT_FOR_EXIT),
             U64_PRINTF_ARG(total_guard_bw), U64_PRINTF_ARG(total_nonguard_bw),
             guard_weight, (int)(rule == WEIGHT_FOR_GUARD));
+#endif
 
-  /* Almost done: choose a random value from the bandwidth weights. */
-  rand_bw = crypto_rand_uint64(total_bw);
-  rand_bw++; /* crypto_rand_uint64() counts from 0, and we need to count
-              * from 1 below. See bug 1203 for details. */
-
-  /* Last, count through sl until we get to the element we picked */
-  tmp = 0;
-  i_chosen = (unsigned)smartlist_len(sl);
-  i_has_been_chosen = 0;
-  for (i=0; i < (unsigned)smartlist_len(sl); i++) {
-    is_exit = bitarray_is_set(exit_bits, i);
-    is_guard = bitarray_is_set(guard_bits, i);
-
-    /* Weights can be 0 if not counting guards/exits */
-    if (is_exit && is_guard)
-      tmp += ((uint64_t)(bandwidths[i] * exit_weight * guard_weight));
-    else if (is_guard)
-      tmp += ((uint64_t)(bandwidths[i] * guard_weight));
-    else if (is_exit)
-      tmp += ((uint64_t)(bandwidths[i] * exit_weight));
-    else
-      tmp += bandwidths[i];
+  scale_array_elements_to_u64(bandwidths, smartlist_len(sl),
+                              &sl_last_total_weighted_bw);
 
-    if (tmp >= rand_bw && !i_has_been_chosen) {
-      i_chosen = i;
-      i_has_been_chosen = 1;
-    }
+  {
+    int idx = choose_array_element_by_weight(bandwidths,
+                                             smartlist_len(sl));
+    tor_free(bandwidths);
+    tor_free(fast_bits);
+    tor_free(exit_bits);
+    tor_free(guard_bits);
+    return idx < 0 ? NULL : smartlist_get(sl, idx);
   }
-  i = i_chosen;
-  if (i == (unsigned)smartlist_len(sl)) {
-    /* This was once possible due to round-off error, but shouldn't be able
-     * to occur any longer. */
-    tor_fragile_assert();
-    --i;
-    log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on "
-             " which router we chose. Please tell the developers. "
-             U64_FORMAT " " U64_FORMAT " " U64_FORMAT, U64_PRINTF_ARG(tmp),
-             U64_PRINTF_ARG(rand_bw), U64_PRINTF_ARG(total_bw));
-  }
-  tor_free(bandwidths);
-  tor_free(exit_bits);
-  tor_free(guard_bits);
-  return smartlist_get(sl, i);
 }
 
 /** Choose a random element of status list <b>sl</b>, weighted by

+ 14 - 0
src/or/routerlist.h

@@ -216,5 +216,19 @@ int hex_digest_nickname_decode(const char *hexdigest,
                                char *nickname_qualifier_out,
                                char *nickname_out);
 
+#ifdef ROUTERLIST_PRIVATE
+/** Helper type for choosing routers by bandwidth: contains a union of
+ * double and uint64_t. Before we call scale_array_elements_to_u64, it holds
+ * a double; after, it holds a uint64_t. */
+typedef union u64_dbl_t {
+  uint64_t u64;
+  double dbl;
+} u64_dbl_t;
+
+int choose_array_element_by_weight(const u64_dbl_t *entries, int n_entries);
+void scale_array_elements_to_u64(u64_dbl_t *entries, int n_entries,
+                                 uint64_t *total_out);
+#endif
+
 #endif
 

+ 4 - 0
src/test/test.h

@@ -65,6 +65,10 @@
 
 #define test_memeq_hex(expr1, hex) test_mem_op_hex(expr1, ==, hex)
 
+#define tt_double_op(a,op,b)                                            \
+  tt_assert_test_type(a,b,#a" "#op" "#b,double,(val1_ op val2_),"%f",   \
+                      TT_EXIT_TEST_FUNCTION)
+
 const char *get_fname(const char *name);
 crypto_pk_t *pk_generate(int idx);
 

+ 123 - 0
src/test/test_dir.c

@@ -4,9 +4,12 @@
 /* See LICENSE for licensing information */
 
 #include "orconfig.h"
+#include <math.h>
+
 #define DIRSERV_PRIVATE
 #define DIRVOTE_PRIVATE
 #define ROUTER_PRIVATE
+#define ROUTERLIST_PRIVATE
 #define HIBERNATE_PRIVATE
 #include "or.h"
 #include "directory.h"
@@ -1389,6 +1392,124 @@ test_dir_v3_networkstatus(void)
     ns_detached_signatures_free(dsig2);
 }
 
+static void
+test_dir_scale_bw(void *testdata)
+{
+  double v[8] = { 2.0/3,
+                  7.0,
+                  1.0,
+                  3.0,
+                  1.0/5,
+                  1.0/7,
+                  12.0,
+                  24.0 };
+  u64_dbl_t vals[8];
+  uint64_t total;
+  int i;
+
+  (void) testdata;
+
+  for (i=0; i<8; ++i)
+    vals[i].dbl = v[i];
+
+  scale_array_elements_to_u64(vals, 8, &total);
+
+  tt_int_op((int)total, ==, 48);
+  total = 0;
+  for (i=0; i<8; ++i) {
+    total += vals[i].u64;
+  }
+  tt_assert(total >= (U64_LITERAL(1)<<60));
+  tt_assert(total <= (U64_LITERAL(1)<<62));
+
+  for (i=0; i<8; ++i) {
+    double ratio = ((double)vals[i].u64) / vals[2].u64;
+    tt_double_op(fabs(ratio - v[i]), <, .00001);
+  }
+
+ done:
+  ;
+}
+
+static void
+test_dir_random_weighted(void *testdata)
+{
+  int histogram[10];
+  uint64_t vals[10] = {3,1,2,4,6,0,7,5,8,9}, total=0;
+  u64_dbl_t inp[10];
+  int i, choice;
+  const int n = 50000;
+  double max_sq_error;
+  (void) testdata;
+
+  /* Try a ten-element array with values from 0 through 10. The values are
+   * in a scrambled order to make sure we don't depend on order. */
+  memset(histogram,0,sizeof(histogram));
+  for (i=0; i<10; ++i) {
+    inp[i].u64 = vals[i];
+    total += vals[i];
+  }
+  tt_int_op(total, ==, 45);
+  for (i=0; i<n; ++i) {
+    choice = choose_array_element_by_weight(inp, 10);
+    tt_int_op(choice, >=, 0);
+    tt_int_op(choice, <, 10);
+    histogram[choice]++;
+  }
+
+  /* Now see if we chose things about frequently enough. */
+  max_sq_error = 0;
+  for (i=0; i<10; ++i) {
+    int expected = (int)(n*vals[i]/total);
+    double frac_diff = 0, sq;
+    TT_BLATHER(("  %d : %5d vs %5d\n", (int)vals[i], histogram[i], expected));
+    if (expected)
+      frac_diff = (histogram[i] - expected) / ((double)expected);
+    else
+      tt_int_op(histogram[i], ==, 0);
+
+    sq = frac_diff * frac_diff;
+    if (sq > max_sq_error)
+      max_sq_error = sq;
+  }
+  /* It should almost always be much much less than this.  If you want to
+   * figure out the odds, please feel free. */
+  tt_double_op(max_sq_error, <, .05);
+
+  /* Now try a singleton; do we choose it? */
+  for (i = 0; i < 100; ++i) {
+    choice = choose_array_element_by_weight(inp, 1);
+    tt_int_op(choice, ==, 0);
+  }
+
+  /* Now try an array of zeros.  We should choose randomly. */
+  memset(histogram,0,sizeof(histogram));
+  for (i = 0; i < 5; ++i)
+    inp[i].u64 = 0;
+  for (i = 0; i < n; ++i) {
+    choice = choose_array_element_by_weight(inp, 5);
+    tt_int_op(choice, >=, 0);
+    tt_int_op(choice, <, 5);
+    histogram[choice]++;
+  }
+  /* Now see if we chose things about frequently enough. */
+  max_sq_error = 0;
+  for (i=0; i<5; ++i) {
+    int expected = n/5;
+    double frac_diff = 0, sq;
+    TT_BLATHER(("  %d : %5d vs %5d\n", (int)vals[i], histogram[i], expected));
+    frac_diff = (histogram[i] - expected) / ((double)expected);
+    sq = frac_diff * frac_diff;
+    if (sq > max_sq_error)
+      max_sq_error = sq;
+  }
+  /* It should almost always be much much less than this.  If you want to
+   * figure out the odds, please feel free. */
+  tt_double_op(max_sq_error, <, .05);
+ done:
+  ;
+}
+
 #define DIR_LEGACY(name)                                                   \
   { #name, legacy_test_helper, TT_FORK, &legacy_setup, test_dir_ ## name }
 
@@ -1404,6 +1525,8 @@ struct testcase_t dir_tests[] = {
   DIR_LEGACY(measured_bw),
   DIR_LEGACY(param_voting),
   DIR_LEGACY(v3_networkstatus),
+  DIR(random_weighted),
+  DIR(scale_bw),
   END_OF_TESTCASES
 };