Prepare to shard the DFA state cache mutex.

Change-Id: I78220d805bfb545ad4ceb34f3f928d50dc17d804
Reviewed-on: https://code-review.googlesource.com/c/36312
Reviewed-by: Paul Wankadia <junyer@google.com>
diff --git a/re2/dfa.cc b/re2/dfa.cc
index 9bc8499..2df10b2 100644
--- a/re2/dfa.cc
+++ b/re2/dfa.cc
@@ -36,6 +36,7 @@
 #include <utility>
 #include <vector>
 
+#include "util/util.h"
 #include "util/logging.h"
 #include "util/mix.h"
 #include "util/mutex.h"
@@ -267,7 +268,7 @@
     bool run_forward;
     State* start;
     int first_byte;
-    RWLocker *cache_lock;
+    RWLocker* cache_lock;
     bool failed;     // "out" parameter: whether search gave up
     const char* ep;  // "out" parameter: end pointer for match
     SparseSet* matches;
@@ -351,18 +352,32 @@
   int nastack_;
 
   // State* cache.  Many threads use and add to the cache simultaneously,
-  // holding cache_mutex_ for reading and mutex_ (above) when adding.
+  // holding one cache_mutex_ for reading and mutex_ (above) when adding.
   // If the cache fills and needs to be discarded, the discarding is done
-  // while holding cache_mutex_ for writing, to avoid interrupting other
-  // readers.  Any State* pointers are only valid while cache_mutex_
-  // is held.
-  Mutex cache_mutex_;
+  // while holding every cache_mutex_ for writing to avoid disrupting any
+  // readers.  Any State* is valid only while one (or every) cache_mutex_
+  // is held by the thread using the State*.
+  class alignas(CACHELINE_SIZE) AlignedMutex : public Mutex {};
+  AlignedMutex* cache_mutex_;
   int64_t mem_budget_;     // Total memory budget for all States.
   int64_t state_budget_;   // Amount of memory remaining for new States.
   StateSet state_cache_;   // All States computed so far.
   StartInfo start_[kMaxStart];
+
+  // Until we can use C++17, we must handle the alignment ourselves. :(
+  // Someday, std::unique_ptr<AlignedMutex[]> will be quite sufficient.
+  char* cache_mutex_storage_;
+  int cache_mutex_count_;
 };
 
+template <typename T>
+static inline T* Align(T* base, size_t align) {
+  intptr_t tmp = reinterpret_cast<intptr_t>(base);
+  tmp += align - 1;
+  tmp &= ~(align - 1);
+  return reinterpret_cast<T*>(tmp);
+}
+
 // Shorthand for casting to uint8_t*.
 static inline const uint8_t* BytePtr(const void* v) {
   return reinterpret_cast<const uint8_t*>(v);
@@ -442,7 +457,9 @@
     q0_(NULL),
     q1_(NULL),
     astack_(NULL),
-    mem_budget_(max_mem) {
+    mem_budget_(max_mem),
+    cache_mutex_storage_(NULL),
+    cache_mutex_count_(1) {
   if (ExtraDebug)
     fprintf(stderr, "\nkind %d\n%s\n", (int)kind_, prog_->DumpUnanchored().c_str());
   int nmark = 0;
@@ -454,11 +471,13 @@
              prog_->inst_count(kInstNop) +
              nmark + 1;  // + 1 for start inst
 
-  // Account for space needed for DFA, q0, q1, astack.
+  // Account for memory needed for DFA, q0, q1, astack, cache mutexes.
   mem_budget_ -= sizeof(DFA);
   mem_budget_ -= (prog_->size() + nmark) *
                  (sizeof(int)+sizeof(int)) * 2;  // q0, q1
   mem_budget_ -= nastack_ * sizeof(int);  // astack
+  mem_budget_ -= alignof(AlignedMutex) +
+                 sizeof(AlignedMutex) * cache_mutex_count_;  // cache mutexes
   if (mem_budget_ < 0) {
     init_failed_ = true;
     return;
@@ -483,12 +502,21 @@
   q0_ = new Workq(prog_->size(), nmark);
   q1_ = new Workq(prog_->size(), nmark);
   astack_ = new int[nastack_];
+
+  char* storage = new char[alignof(AlignedMutex) +
+                           sizeof(AlignedMutex) * cache_mutex_count_];
+  cache_mutex_storage_ = storage;
+  cache_mutex_ = new (Align(storage, alignof(AlignedMutex)))
+      AlignedMutex[cache_mutex_count_]();
 }
 
 DFA::~DFA() {
-  delete q0_;
-  delete q1_;
+  for (int i = 0; i < cache_mutex_count_; i++)
+    cache_mutex_[i].~AlignedMutex();
+  delete[] cache_mutex_storage_;
   delete[] astack_;
+  delete q1_;
+  delete q0_;
   ClearCache();
 }
 
@@ -770,8 +798,8 @@
   mem_budget_ -= mem + kStateCacheOverhead;
 
   // Allocate new state along with room for next_ and inst_.
-  char* space = std::allocator<char>().allocate(mem);
-  State* s = new (space) State;
+  char* storage = std::allocator<char>().allocate(mem);
+  State* s = new (storage) State;
   (void) new (s->next_) std::atomic<State*>[nnext];
   // Work around a unfortunate bug in older versions of libstdc++.
   // (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=64658)
@@ -1111,7 +1139,7 @@
 
 class DFA::RWLocker {
  public:
-  explicit RWLocker(Mutex* mu);
+  explicit RWLocker(DFA* dfa);
   ~RWLocker();
 
   // If the lock is only held for reading right now,
@@ -1121,35 +1149,38 @@
   void LockForWriting();
 
  private:
-  Mutex* mu_;
+  DFA* dfa_;
+  int which_;
   bool writing_;
 
   RWLocker(const RWLocker&) = delete;
   RWLocker& operator=(const RWLocker&) = delete;
 };
 
-DFA::RWLocker::RWLocker(Mutex* mu) : mu_(mu), writing_(false) {
-  mu_->ReaderLock();
+DFA::RWLocker::RWLocker(DFA* dfa) : dfa_(dfa), which_(0), writing_(false) {
+  dfa_->cache_mutex_[which_].ReaderLock();
 }
 
 // This function is marked as NO_THREAD_SAFETY_ANALYSIS because the annotations
 // does not support lock upgrade.
 void DFA::RWLocker::LockForWriting() NO_THREAD_SAFETY_ANALYSIS {
   if (!writing_) {
-    mu_->ReaderUnlock();
-    mu_->WriterLock();
+    dfa_->cache_mutex_[which_].ReaderUnlock();
+    for (int i = 0; i < dfa_->cache_mutex_count_; i++)
+      dfa_->cache_mutex_[i].WriterLock();
     writing_ = true;
   }
 }
 
 DFA::RWLocker::~RWLocker() {
-  if (!writing_)
-    mu_->ReaderUnlock();
-  else
-    mu_->WriterUnlock();
+  if (!writing_) {
+    dfa_->cache_mutex_[which_].ReaderUnlock();
+  } else {
+    for (int i = 0; i < dfa_->cache_mutex_count_; i++)
+      dfa_->cache_mutex_[i].WriterUnlock();
+  }
 }
 
-
 // When the DFA's State cache fills, we discard all the states in the
 // cache and start over.  Many threads can be using and adding to the
 // cache at the same time, so we synchronize using the cache_mutex_
@@ -1775,7 +1806,7 @@
             run_forward, kind_);
   }
 
-  RWLocker l(&cache_mutex_);
+  RWLocker l(this);
   SearchParams params(text, context, &l);
   params.anchored = anchored;
   params.want_earliest_match = want_earliest_match;
@@ -1925,7 +1956,7 @@
 
   // Pick out start state for unanchored search
   // at beginning of text.
-  RWLocker l(&cache_mutex_);
+  RWLocker l(this);
   SearchParams params(StringPiece(), StringPiece(), &l);
   params.anchored = false;
   if (!AnalyzeSearch(&params) ||
@@ -2017,7 +2048,7 @@
   std::unordered_map<State*, int> previously_visited_states;
 
   // Pick out start state for anchored search at beginning of text.
-  RWLocker l(&cache_mutex_);
+  RWLocker l(this);
   SearchParams params(StringPiece(), StringPiece(), &l);
   params.anchored = true;
   if (!AnalyzeSearch(&params))
diff --git a/util/util.h b/util/util.h
index 33d100a..4233956 100644
--- a/util/util.h
+++ b/util/util.h
@@ -21,6 +21,10 @@
 #endif
 #endif
 
+#ifndef CACHELINE_SIZE
+#define CACHELINE_SIZE 64
+#endif
+
 #ifndef FALLTHROUGH_INTENDED
 #if defined(__clang__)
 #define FALLTHROUGH_INTENDED [[clang::fallthrough]]