Prune `PrefilterTree` edges instead of nodes.

Credit and gratitude to Daniel Keller for devising and refining the
probability-based heuristics.

Change-Id: Id99ffb0be27a2549f1ee0af1f7925027a3076127
Reviewed-on: https://code-review.googlesource.com/c/re2/+/60050
Reviewed-by: Perry Lorier <perryl@google.com>
Reviewed-by: Paul Wankadia <junyer@google.com>
diff --git a/re2/prefilter_tree.cc b/re2/prefilter_tree.cc
index 1e4eaff..af50b31 100644
--- a/re2/prefilter_tree.cc
+++ b/re2/prefilter_tree.cc
@@ -4,6 +4,7 @@
 
 #include "re2/prefilter_tree.h"
 
+#include <math.h>
 #include <stddef.h>
 #include <algorithm>
 #include <map>
@@ -65,33 +66,6 @@
 
   NodeMap nodes;
   AssignUniqueIds(&nodes, atom_vec);
-
-  // Identify nodes that are too common among prefilters and are
-  // triggering too many parents. Then get rid of them if possible.
-  // Note that getting rid of a prefilter node simply means they are
-  // no longer necessary for their parent to trigger; that is, we do
-  // not miss out on any regexps triggering by getting rid of a
-  // prefilter node.
-  for (size_t i = 0; i < entries_.size(); i++) {
-    std::vector<int>& parents = entries_[i].parents;
-    if (parents.size() > 8) {
-      // This one triggers too many things. If all the parents are AND
-      // nodes and have other things guarding them, then get rid of
-      // this trigger. TODO(vsri): Adjust the threshold appropriately,
-      // make it a function of total number of nodes?
-      bool have_other_guard = true;
-      for (int parent : parents) {
-        have_other_guard = have_other_guard &&
-            (entries_[parent].propagate_up_at_count > 1);
-      }
-      if (have_other_guard) {
-        for (int parent : parents)
-          entries_[parent].propagate_up_at_count -= 1;
-        parents.clear();  // Forget the parents
-      }
-    }
-  }
-
   if (ExtraDebug)
     PrintDebugInfo(&nodes);
 }
@@ -215,12 +189,9 @@
     Prefilter* prefilter = v[i];
     if (prefilter == NULL)
       continue;
-
     if (CanonicalNode(nodes, prefilter) != prefilter)
       continue;
-
     int id = prefilter->unique_id();
-
     switch (prefilter->op()) {
       default:
         LOG(DFATAL) << "Unexpected op: " << prefilter->op();
@@ -261,6 +232,51 @@
     Entry* entry = &entries_[id];
     entry->regexps.push_back(static_cast<int>(i));
   }
+
+  // Lastly, using probability-based heuristics, we identify nodes
+  // that trigger too many parents and then we try to prune edges.
+  // We use logarithms below to avoid the likelihood of underflow.
+  double log_num_regexps = std::log(prefilter_vec_.size() - unfiltered_.size());
+  // Hoisted this above the loop so that we don't thrash the heap.
+  std::vector<std::pair<int, int>> entries_by_num_edges;
+  for (int i = static_cast<int>(v.size()) - 1; i >= 0; i--) {
+    Prefilter* prefilter = v[i];
+    // Pruning applies only to AND nodes because it "just" reduces
+    // precision; applied to OR nodes, it would break correctness.
+    if (prefilter == NULL || prefilter->op() != Prefilter::AND)
+      continue;
+    if (CanonicalNode(nodes, prefilter) != prefilter)
+      continue;
+    int id = prefilter->unique_id();
+
+    // Sort the current node's children by the numbers of parents.
+    entries_by_num_edges.clear();
+    for (size_t j = 0; j < prefilter->subs()->size(); j++) {
+      int child_id = (*prefilter->subs())[j]->unique_id();
+      const std::vector<int>& parents = entries_[child_id].parents;
+      entries_by_num_edges.emplace_back(parents.size(), child_id);
+    }
+    std::stable_sort(entries_by_num_edges.begin(), entries_by_num_edges.end());
+
+    // A running estimate of how many regexps will be triggered by
+    // pruning the remaining children's edges to the current node.
+    // Our nominal target is one, so the threshold is log(1) == 0;
+    // pruning occurs iff the child has more than nine edges left.
+    double log_num_triggered = log_num_regexps;
+    for (const auto& pair : entries_by_num_edges) {
+      int child_id = pair.second;
+      std::vector<int>& parents = entries_[child_id].parents;
+      if (log_num_triggered > 0.) {
+        log_num_triggered += std::log(parents.size());
+        log_num_triggered -= log_num_regexps;
+      } else if (parents.size() > 9) {
+        auto it = std::find(parents.begin(), parents.end(), id);
+        DCHECK(it != parents.end());
+        parents.erase(it);
+        entries_[id].propagate_up_at_count--;
+      }
+    }
+  }
 }
 
 // Functions for triggering during search.