Make the BitState bitmap use list heads.

Change-Id: I5b8acd2236291504e8a4ca4a21e99c4e644f1d70
Reviewed-on: https://code-review.googlesource.com/c/re2/+/38990
Reviewed-by: Paul Wankadia <junyer@google.com>
diff --git a/re2/bitstate.cc b/re2/bitstate.cc
index b69d9bf..6f045b1 100644
--- a/re2/bitstate.cc
+++ b/re2/bitstate.cc
@@ -5,10 +5,10 @@
 // Tested by search_test.cc, exhaustive_test.cc, tester.cc
 
 // Prog::SearchBitState is a regular expression search with submatch
-// tracking for small regular expressions and texts.  Like
-// testing/backtrack.cc, it allocates a bit vector with (length of
-// text) * (length of prog) bits, to make sure it never explores the
-// same (character position, instruction) state multiple times.  This
+// tracking for small regular expressions and texts.  Similarly to
+// testing/backtrack.cc, it allocates a bitmap with (count of
+// lists) * (length of text) bits to make sure it never explores the
+// same (instruction list, character position) multiple times.  This
 // limits the search to run in time linear in the length of the text.
 //
 // Unlike testing/backtrack.cc, SearchBitState is not recursive
@@ -64,7 +64,7 @@
 
   // Search state
   static const int VisitedBits = 32;
-  PODArray<uint32_t> visited_;  // bitmap: (Inst*, char*) pairs visited
+  PODArray<uint32_t> visited_;  // bitmap: (list ID, char*) pairs visited
   PODArray<const char*> cap_;   // capture registers
   PODArray<Job> job_;           // stack of text positions to explore
   int njob_;                    // stack size
@@ -80,11 +80,12 @@
     njob_(0) {
 }
 
-// Should the search visit the (id, p) pair?
+// Given id, which *must* be a list head, we can look up its list ID.
+// Then the question is: Should the search visit the (list ID, p) pair?
 // If so, remember that it was visited so that the next time,
 // we don't repeat the visit.
 bool BitState::ShouldVisit(int id, const char* p) {
-  int n = id * static_cast<int>(text_.size()+1) +
+  int n = prog_->list_heads()[id] * static_cast<int>(text_.size()+1) +
           static_cast<int>(p-text_.begin());
   if (visited_[n/VisitedBits] & (1 << (n & (VisitedBits-1))))
     return false;
@@ -302,7 +303,7 @@
     submatch_[i] = StringPiece();
 
   // Allocate scratch space.
-  int nvisited = prog_->size() * static_cast<int>(text.size()+1);
+  int nvisited = prog_->list_count() * static_cast<int>(text.size()+1);
   nvisited = (nvisited + VisitedBits-1) / VisitedBits;
   visited_ = PODArray<uint32_t>(nvisited);
   memset(visited_.data(), 0, nvisited*sizeof visited_[0]);
diff --git a/re2/compile.cc b/re2/compile.cc
index 3f8e0cc..ab18cef 100644
--- a/re2/compile.cc
+++ b/re2/compile.cc
@@ -1202,7 +1202,10 @@
   if (max_mem_ <= 0) {
     prog_->set_dfa_mem(1<<20);
   } else {
-    int64_t m = max_mem_ - sizeof(Prog) - prog_->size_*sizeof(Prog::Inst);
+    int64_t m = max_mem_ - sizeof(Prog);
+    m -= prog_->size_*sizeof(Prog::Inst);  // account for inst_
+    if (prog_->CanBitState())
+      m -= prog_->size_*sizeof(uint16_t);  // account for list_heads_
     if (m < 0)
       m = 0;
     prog_->set_dfa_mem(m);
diff --git a/re2/prog.cc b/re2/prog.cc
index 102e6d5..4b3ea67 100644
--- a/re2/prog.cc
+++ b/re2/prog.cc
@@ -630,7 +630,17 @@
   // Finally, replace the old instructions with the new instructions.
   size_ = static_cast<int>(flat.size());
   inst_ = PODArray<Inst>(size_);
-  memmove(inst_.data(), flat.data(), size_*sizeof(inst_[0]));
+  memmove(inst_.data(), flat.data(), size_*sizeof inst_[0]);
+
+  // Populate the list heads for BitState.
+  // 512 instructions limits the memory footprint to 1KiB.
+  if (size_ <= 512) {
+    list_heads_ = PODArray<uint16_t>(size_);
+    // 0xFF makes it more obvious if we try to look up a non-head.
+    memset(list_heads_.data(), 0xFF, size_*sizeof list_heads_[0]);
+    for (int i = 0; i < list_count_; ++i)
+      list_heads_[flatmap[i]] = i;
+  }
 }
 
 void Prog::MarkSuccessors(SparseArray<int>* rootmap,
diff --git a/re2/prog.h b/re2/prog.h
index 332b5fc..68df8d6 100644
--- a/re2/prog.h
+++ b/re2/prog.h
@@ -206,6 +206,7 @@
   void set_reversed(bool reversed) { reversed_ = reversed; }
   int list_count() { return list_count_; }
   int inst_count(InstOp op) { return inst_count_[op]; }
+  uint16_t* list_heads() { return list_heads_.data(); }
   void set_dfa_mem(int64_t dfa_mem) { dfa_mem_ = dfa_mem; }
   int64_t dfa_mem() { return dfa_mem_; }
   int flags() { return flags_; }
@@ -312,7 +313,8 @@
                      StringPiece* match, int nmatch);
 
   // Bit-state backtracking.  Fast on small cases but uses memory
-  // proportional to the product of the program size and the text size.
+  // proportional to the product of the list count and the text size.
+  bool CanBitState() { return list_heads_.data() != NULL; }
   bool SearchBitState(const StringPiece& text, const StringPiece& context,
                       Anchor anchor, MatchKind kind,
                       StringPiece* match, int nmatch);
@@ -403,10 +405,12 @@
   int first_byte_;          // required first byte for match, or -1 if none
   int flags_;               // regexp parse flags
 
-  int list_count_;            // count of lists (see above)
-  int inst_count_[kNumInst];  // count of instructions by opcode
+  int list_count_;                 // count of lists (see above)
+  int inst_count_[kNumInst];       // count of instructions by opcode
+  PODArray<uint16_t> list_heads_;  // sparse array enumerating list heads
+                                   // not populated if size_ is overly large
 
-  PODArray<Inst> inst_;     // pointer to instruction array
+  PODArray<Inst> inst_;              // pointer to instruction array
   PODArray<uint8_t> onepass_nodes_;  // data for OnePass nodes
 
   int64_t dfa_mem_;         // Maximum memory for DFAs.
diff --git a/re2/re2.cc b/re2/re2.cc
index f101f74..1c85d3d 100644
--- a/re2/re2.cc
+++ b/re2/re2.cc
@@ -645,15 +645,13 @@
 
   bool can_one_pass = (is_one_pass_ && ncap <= Prog::kMaxOnePassCapture);
 
-  // SearchBitState allocates a bit vector of size prog_->size() * text.size().
+  // BitState allocates a bitmap of size prog_->list_count() * text.size().
   // It also allocates a stack of 3-word structures which could potentially
-  // grow as large as prog_->size() * text.size() but in practice is much
-  // smaller.
-  // Conditions for using SearchBitState:
-  const int MaxBitStateProg = 500;   // prog_->size() <= Max.
-  const int MaxBitStateVector = 256*1024;  // bit vector size <= Max (bits)
-  bool can_bit_state = prog_->size() <= MaxBitStateProg;
-  size_t bit_state_text_max = MaxBitStateVector / prog_->size();
+  // grow as large as prog_->list_count() * text.size(), but in practice is
+  // much smaller.
+  const int kMaxBitStateBitmapSize = 256*1024;  // bitmap size <= max (bits)
+  bool can_bit_state = prog_->CanBitState();
+  size_t bit_state_text_max = kMaxBitStateBitmapSize / prog_->list_count();
 
   bool dfa_failed = false;
   switch (re_anchor) {
diff --git a/re2/testing/regexp_benchmark.cc b/re2/testing/regexp_benchmark.cc
index 8b82e0b..68ab6d8 100644
--- a/re2/testing/regexp_benchmark.cc
+++ b/re2/testing/regexp_benchmark.cc
@@ -34,6 +34,7 @@
   Prog* prog = re->CompileToProg(0);
   CHECK(prog);
   CHECK(prog->IsOnePass());
+  CHECK(prog->CanBitState());
   const char* text = "650-253-0001";
   StringPiece sp[4];
   CHECK(prog->SearchOnePass(text, text, Prog::kAnchored, Prog::kFullMatch, sp, 4));
@@ -61,6 +62,7 @@
     Prog* prog = re->CompileToProg(0);
     CHECK(prog);
     CHECK(prog->IsOnePass());
+    CHECK(prog->CanBitState());
     fprintf(stderr, "Prog:   %7lld bytes (peak=%lld)\n", mc.HeapGrowth(), mc.PeakHeapGrowth());
     mc.Reset();
 
@@ -932,6 +934,7 @@
     CHECK(re);
     Prog* prog = re->CompileToProg(0);
     CHECK(prog);
+    CHECK(prog->CanBitState());
     CHECK_EQ(prog->SearchBitState(text, text, anchor, Prog::kFirstMatch, NULL, 0),
              expect_match);
     delete prog;
@@ -1019,6 +1022,7 @@
   CHECK(re);
   Prog* prog = re->CompileToProg(0);
   CHECK(prog);
+  CHECK(prog->CanBitState());
   for (int i = 0; i < iters; i++)
     CHECK_EQ(prog->SearchBitState(text, text, anchor, Prog::kFirstMatch, NULL, 0),
              expect_match);
@@ -1088,6 +1092,7 @@
     CHECK(re);
     Prog* prog = re->CompileToProg(0);
     CHECK(prog);
+    CHECK(prog->CanBitState());
     StringPiece sp[4];  // 4 because sp[0] is whole match.
     CHECK(prog->SearchBitState(text, text, Prog::kAnchored, Prog::kFullMatch, sp, 4));
     delete prog;
@@ -1158,6 +1163,7 @@
   CHECK(re);
   Prog* prog = re->CompileToProg(0);
   CHECK(prog);
+  CHECK(prog->CanBitState());
   StringPiece sp[4];  // 4 because sp[0] is whole match.
   for (int i = 0; i < iters; i++)
     CHECK(prog->SearchBitState(text, text, Prog::kAnchored, Prog::kFullMatch, sp, 4));
@@ -1233,6 +1239,7 @@
     CHECK(re);
     Prog* prog = re->CompileToProg(0);
     CHECK(prog);
+    CHECK(prog->CanBitState());
     StringPiece sp[2];  // 2 because sp[0] is whole match.
     CHECK(prog->SearchBitState(text, text, Prog::kAnchored, Prog::kFullMatch, sp, 2));
     delete prog;
@@ -1290,6 +1297,7 @@
   CHECK(re);
   Prog* prog = re->CompileToProg(0);
   CHECK(prog);
+  CHECK(prog->CanBitState());
   StringPiece sp[2];  // 2 because sp[0] is whole match.
   for (int i = 0; i < iters; i++)
     CHECK(prog->SearchBitState(text, text, Prog::kAnchored, Prog::kFullMatch, sp, 2));
diff --git a/re2/testing/tester.cc b/re2/testing/tester.cc
index c37aada..66d8d4f 100644
--- a/re2/testing/tester.cc
+++ b/re2/testing/tester.cc
@@ -364,8 +364,8 @@
 
     case kEngineOnePass:
       if (prog_ == NULL ||
-          anchor == Prog::kUnanchored ||
           !prog_->IsOnePass() ||
+          anchor == Prog::kUnanchored ||
           nsubmatch > Prog::kMaxOnePassCapture) {
         result->skipped = true;
         break;
@@ -376,7 +376,8 @@
       break;
 
     case kEngineBitState:
-      if (prog_ == NULL) {
+      if (prog_ == NULL ||
+          !prog_->CanBitState()) {
         result->skipped = true;
         break;
       }