Compute first_byte using the Regexp, not the Prog.

Change-Id: Ic893767c752afee4f41d8d39805351cf91c883e7
Reviewed-on: https://code-review.googlesource.com/c/re2/+/56214
Reviewed-by: Paul Wankadia <junyer@google.com>
diff --git a/re2/compile.cc b/re2/compile.cc
index b3e5bdc..2f64a80 100644
--- a/re2/compile.cc
+++ b/re2/compile.cc
@@ -225,8 +225,8 @@
   // Single rune.
   Frag Literal(Rune r, bool foldcase);
 
-  void Setup(Regexp::ParseFlags, int64_t, RE2::Anchor);
-  Prog* Finish();
+  void Setup(Regexp::ParseFlags flags, int64_t max_mem, RE2::Anchor anchor);
+  Prog* Finish(Regexp* re);
 
   // Returns .* where dot = any byte
   Frag DotStar();
@@ -1167,10 +1167,10 @@
   c.prog_->set_start_unanchored(all.begin);
 
   // Hand ownership of prog_ to caller.
-  return c.Finish();
+  return c.Finish(re);
 }
 
-Prog* Compiler::Finish() {
+Prog* Compiler::Finish(Regexp* re) {
   if (failed_)
     return NULL;
 
@@ -1186,7 +1186,13 @@
   prog_->Optimize();
   prog_->Flatten();
   prog_->ComputeByteMap();
-  prog_->ComputeFirstByte();
+
+  // Compute first byte; mask it so a signed char >= 0x80 stays non-negative.
+  std::string prefix;
+  bool prefix_foldcase = false;
+  if (re->RequiredPrefixUnanchored(&prefix, &prefix_foldcase) &&
+      !prefix_foldcase)
+    prog_->set_first_byte(prefix[0] & 0xFF);
 
   // Record remaining memory for DFA.
   if (max_mem_ <= 0) {
@@ -1244,7 +1250,7 @@
   c.prog_->set_start(all.begin);
   c.prog_->set_start_unanchored(all.begin);
 
-  Prog* prog = c.Finish();
+  Prog* prog = c.Finish(re);
   if (prog == NULL)
     return NULL;
 
diff --git a/re2/nfa.cc b/re2/nfa.cc
index 75ca306..720f3e0 100644
--- a/re2/nfa.cc
+++ b/re2/nfa.cc
@@ -629,67 +629,6 @@
   return false;
 }
 
-void Prog::ComputeFirstByte() {
-  SparseSet q(size());
-  q.insert(start());
-  for (SparseSet::iterator it = q.begin(); it != q.end(); ++it) {
-    int id = *it;
-    Prog::Inst* ip = inst(id);
-    switch (ip->opcode()) {
-      default:
-        LOG(DFATAL) << "unhandled " << ip->opcode() << " in ComputeFirstByte";
-        break;
-
-      case kInstMatch:
-        // The empty string matches: no first byte.
-        first_byte_ = -1;
-        return;
-
-      case kInstByteRange:
-        if (!ip->last())
-          q.insert(id+1);
-
-        // Must match only a single byte.
-        if (ip->lo() != ip->hi() ||
-            (ip->foldcase() && 'a' <= ip->lo() && ip->lo() <= 'z')) {
-          first_byte_ = -1;
-          return;
-        }
-        // If we haven't seen any bytes yet, record it;
-        // otherwise must match the one we saw before.
-        if (first_byte_ == -1) {
-          first_byte_ = ip->lo();
-        } else if (first_byte_ != ip->lo()) {
-          first_byte_ = -1;
-          return;
-        }
-        break;
-
-      case kInstNop:
-      case kInstCapture:
-      case kInstEmptyWidth:
-        if (!ip->last())
-          q.insert(id+1);
-
-        // Continue on.
-        // Ignore ip->empty() flags for kInstEmptyWidth
-        // in order to be as conservative as possible
-        // (assume all possible empty-width flags are true).
-        if (ip->out())
-          q.insert(ip->out());
-        break;
-
-      case kInstAltMatch:
-        DCHECK(!ip->last());
-        q.insert(id+1);
-        break;
-
-      case kInstFail:
-        break;
-    }
-  }
-}
-
 bool
 Prog::SearchNFA(const StringPiece& text, const StringPiece& context,
                 Anchor anchor, MatchKind kind,
diff --git a/re2/prog.h b/re2/prog.h
index 4306672..cac107c 100644
--- a/re2/prog.h
+++ b/re2/prog.h
@@ -198,8 +198,8 @@
 
   Inst *inst(int id) { return &inst_[id]; }
   int start() { return start_; }
-  int start_unanchored() { return start_unanchored_; }
   void set_start(int start) { start_ = start; }
+  int start_unanchored() { return start_unanchored_; }
   void set_start_unanchored(int start) { start_unanchored_ = start; }
   int size() { return size_; }
   bool reversed() { return reversed_; }
@@ -207,8 +207,8 @@
   int list_count() { return list_count_; }
   int inst_count(InstOp op) { return inst_count_[op]; }
   uint16_t* list_heads() { return list_heads_.data(); }
-  void set_dfa_mem(int64_t dfa_mem) { dfa_mem_ = dfa_mem; }
   int64_t dfa_mem() { return dfa_mem_; }
+  void set_dfa_mem(int64_t dfa_mem) { dfa_mem_ = dfa_mem; }
   bool anchor_start() { return anchor_start_; }
   void set_anchor_start(bool b) { anchor_start_ = b; }
   bool anchor_end() { return anchor_end_; }
@@ -216,6 +216,7 @@
   int bytemap_range() { return bytemap_range_; }
   const uint8_t* bytemap() { return bytemap_; }
   int first_byte() { return first_byte_; }
+  void set_first_byte(int first_byte) { first_byte_ = first_byte; }
 
   // Returns string representation of program for debugging.
   std::string Dump();
@@ -293,9 +294,6 @@
   // Compute bytemap.
   void ComputeByteMap();
 
-  // Computes whether all matches must begin with the same first byte.
-  void ComputeFirstByte();
-
   // Run peep-hole optimizer on program.
   void Optimize();