Переглянути джерело

Merge remote-tracking branch 'andrea/bug7191_v2'

Nick Mathewson 11 роки тому
батько
коміт
04a509e04b
3 змінених файлів з 134 додано та 17 видалено
  1. 3 0
      changes/bug7191
  2. 95 17
      src/common/container.c
  3. 36 0
      src/test/test_containers.c

+ 3 - 0
changes/bug7191

@@ -0,0 +1,3 @@
+ o Bugfixes
+   - The smartlist_bsearch_idx() function was broken for lists of length zero
+     or one; fix it.  This fixes bug7191.

+ 95 - 17
src/common/container.c

@@ -571,31 +571,109 @@ smartlist_bsearch_idx(const smartlist_t *sl, const void *key,
                       int (*compare)(const void *key, const void **member),
                       int *found_out)
 {
-  int hi = smartlist_len(sl) - 1, lo = 0, cmp, mid;
+  int hi, lo, cmp, mid, len, diff;
+
+  tor_assert(sl);
+  tor_assert(compare);
+  tor_assert(found_out);
+
+  len = smartlist_len(sl);
+
+  /* Check for the trivial case of a zero-length list */
+  if (len == 0) {
+    *found_out = 0;
+    /* We already know smartlist_len(sl) is 0 in this case */
+    return 0;
+  }
+
+  /* Okay, we have a real search to do */
+  tor_assert(len > 0);
+  lo = 0;
+  hi = len - 1;
+
+  /*
+   * These invariants are always true:
+   *
+   * For all i such that 0 <= i < lo, sl[i] < key
+   * For all i such that hi < i <= len, sl[i] > key
+   */
 
   while (lo <= hi) {
-    mid = (lo + hi) / 2;
+    diff = hi - lo;
+    /*
+     * We want mid = (lo + hi) / 2, but that could lead to overflow, so
+     * instead diff = hi - lo (non-negative because of loop condition), and
+     * then hi = lo + diff, mid = (lo + lo + diff) / 2 = lo + (diff / 2).
+     */
+    mid = lo + (diff / 2);
     cmp = compare(key, (const void**) &(sl->list[mid]));
-    if (cmp>0) { /* key > sl[mid] */
-      lo = mid+1;
-    } else if (cmp<0) { /* key < sl[mid] */
-      hi = mid-1;
-    } else { /* key == sl[mid] */
+    if (cmp == 0) {
+      /* sl[mid] == key; we found it */
       *found_out = 1;
       return mid;
-    }
-  }
-  /* lo > hi. */
-  {
-    tor_assert(lo >= 0);
-    if (lo < smartlist_len(sl)) {
-      cmp = compare(key, (const void**) &(sl->list[lo]));
+    } else if (cmp > 0) {
+      /*
+       * key > sl[mid] and an index i such that sl[i] == key must
+       * have i > mid if it exists.
+       */
+
+      /*
+       * Since lo <= mid <= hi, hi can only decrease on each iteration (by
+       * being set to mid - 1) and hi is initially len - 1, mid < len should
+       * always hold, and this is not symmetric with the left end of list
+       * mid > 0 test below.  A key greater than the right end of the list
+       * should eventually lead to lo == hi == mid == len - 1, and then
+       * we set lo to len below and fall out to the same exit we hit for
+       * a key in the middle of the list but not matching.  Thus, we just
+       * assert for consistency here rather than handle a mid == len case.
+       */
+      tor_assert(mid < len);
+      /* Move lo to the element immediately after sl[mid] */
+      lo = mid + 1;
+    } else {
+      /* This should always be true in this case */
       tor_assert(cmp < 0);
-    } else if (smartlist_len(sl)) {
-      cmp = compare(key, (const void**) &(sl->list[smartlist_len(sl)-1]));
-      tor_assert(cmp > 0);
+
+      /*
+       * key < sl[mid] and an index i such that sl[i] == key must
+       * have i < mid if it exists.
+       */
+
+      if (mid > 0) {
+        /* Normal case, move hi to the element immediately before sl[mid] */
+        hi = mid - 1;
+      } else {
+        /* These should always be true in this case */
+        tor_assert(mid == lo);
+        tor_assert(mid == 0);
+        /*
+         * We were at the beginning of the list and concluded that every
+         * element e compares e > key.
+         */
+        *found_out = 0;
+        return 0;
+      }
     }
   }
+
+  /*
+   * lo > hi; we have no element matching key but we have elements falling
+   * on both sides of it.  The lo index points to the first element > key.
+   */
+  tor_assert(lo == hi + 1); /* All other cases should have been handled */
+  tor_assert(lo >= 0);
+  tor_assert(lo <= len);
+  tor_assert(hi >= 0);
+  tor_assert(hi <= len);
+
+  if (lo < len) {
+    cmp = compare(key, (const void **) &(sl->list[lo]));
+    tor_assert(cmp < 0);
+  } else {
+    cmp = compare(key, (const void **) &(sl->list[len-1]));
+    tor_assert(cmp > 0);
+  }
+
   *found_out = 0;
   return lo;
 }

+ 36 - 0
src/test/test_containers.c

@@ -16,6 +16,15 @@ compare_strs_(const void **a, const void **b)
   return strcmp(s1, s2);
 }
 
+/** Helper: return a tristate based on comparing the strings in <b>a</b> and
+ * *<b>b</b>. */
+static int
+compare_strs_for_bsearch_(const void *a, const void **b)
+{
+  const char *s1 = a, *s2 = *b;
+  return strcmp(s1, s2);
+}
+
 /** Helper: return a tristate based on comparing the strings in *<b>a</b> and
  * *<b>b</b>, excluding a's first character, and ignoring case. */
 static int
@@ -204,6 +213,8 @@ test_container_smartlist_strings(void)
   /* Test bsearch_idx */
   {
     int f;
+    smartlist_t *tmp = NULL;
+
     test_eq(0, smartlist_bsearch_idx(sl," aaa",compare_without_first_ch_,&f));
     test_eq(f, 0);
     test_eq(0, smartlist_bsearch_idx(sl," and",compare_without_first_ch_,&f));
@@ -216,6 +227,31 @@ test_container_smartlist_strings(void)
     test_eq(f, 0);
     test_eq(7, smartlist_bsearch_idx(sl," zzzz",compare_without_first_ch_,&f));
     test_eq(f, 0);
+
+    /* Test trivial cases for list of length 0 or 1 */
+    tmp = smartlist_new();
+    test_eq(0, smartlist_bsearch_idx(tmp, "foo",
+                                     compare_strs_for_bsearch_, &f));
+    test_eq(f, 0);
+    smartlist_insert(tmp, 0, (void *)("bar"));
+    test_eq(1, smartlist_bsearch_idx(tmp, "foo",
+                                     compare_strs_for_bsearch_, &f));
+    test_eq(f, 0);
+    test_eq(0, smartlist_bsearch_idx(tmp, "aaa",
+                                     compare_strs_for_bsearch_, &f));
+    test_eq(f, 0);
+    test_eq(0, smartlist_bsearch_idx(tmp, "bar",
+                                     compare_strs_for_bsearch_, &f));
+    test_eq(f, 1);
+    /* ... and one for length 2 */
+    smartlist_insert(tmp, 1, (void *)("foo"));
+    test_eq(1, smartlist_bsearch_idx(tmp, "foo",
+                                     compare_strs_for_bsearch_, &f));
+    test_eq(f, 1);
+    test_eq(2, smartlist_bsearch_idx(tmp, "goo",
+                                     compare_strs_for_bsearch_, &f));
+    test_eq(f, 0);
+    smartlist_free(tmp);
   }
 
   /* Test reverse() and pop_last() */