Alex's Anthology of Algorithms Common Code for Contests in Concise C++
Strings / Suffix Arrays and LCP

3.4.2 Suffix Array and LCP (Counting Sort)

3-Strings/3.4.2_Suffix_Array_and_LCP_(Counting_Sort).cpp

View on GitHub

Given a string $s$, a suffix array is the array of the starting positions of the suffixes of $s$, listed in lexicographic order of the suffixes. That is, the $i$-th position of the suffix array stores the starting position of the $i$-th lexicographically smallest suffix of $s$. For example, $s$ = "cab" has the suffixes "cab", "ab", and "b". When sorted, the suffixes are "ab", "b", and "cab", which start at indices 1, 2, and 0, so the suffix array (assuming 0-based indices) is $[1, 2, 0]$.

For a string $s$ of length $n$, the longest common prefix (LCP) array of length $n - 1$ stores the lengths of the longest common prefixes between all pairs of lexicographically adjacent suffixes in $s$. For example, "baa" has the sorted suffixes "a", "aa", and "baa", with an LCP array of $[1, 0]$.

The prefix-doubling approach of Manber and Myers sorts the suffixes by their first $2^k$ characters for increasing $k$: each round orders suffixes by their pair of ranks from the previous round, so comparing two suffixes costs $O(1)$ and the full order emerges after $O(\log n)$ rounds. Sorting each round with a comparison sort gives $O(n \log^2 n)$ time overall; reordering each round with a counting sort over the previous round's ranks instead removes a logarithmic factor, for $O(n \log n)$ total.

  • SuffixArrayCountingSort(s) constructs a suffix array from string s.
  • get_sa() returns the constructed suffix array.
  • get_lcp() returns the corresponding LCP array for the suffix array.
  • find(needle) returns one position at which needle occurs in s (not necessarily the first), or std::string::npos if it does not occur; an empty needle returns 0. For a needle of length $m$, this implementation uses an $O(m \log n)$ binary search, but can be optimized to $O(m + \log n)$ by first computing the LCP-LR array from the LCP array.

Implementation

#include <algorithm>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
using std::string;

// Suffix array built by prefix doubling, where each doubling round reorders the
// suffixes with a counting sort keyed on the previous round's ranks. Also
// provides the LCP array (Kasai's algorithm) and a binary-search substring find.
class SuffixArrayCountingSort {
  std::string s;
  // sa[i] is the starting index of the i-th smallest suffix of s. Once the
  // constructor returns, rk is the inverse permutation of sa: rk[sa[i]] == i.
  std::vector<int> sa, rk;

 public:
  // Builds the suffix array of s: an initial O(n log n) stable sort by first
  // character, then O(log n) doubling rounds of O(n) work each.
  explicit SuffixArrayCountingSort(const std::string &s) : s(s), sa(s.size()), rk(s.size()) {
    int n = static_cast<int>(s.size());
    // Seed sa in decreasing index order so that, after the stable sort below,
    // suffixes sharing a first character appear from shortest to longest.
    std::iota(sa.rbegin(), sa.rend(), 0);
    for (int i = 0; i < n; i++) {
      rk[i] = static_cast<unsigned char>(s[i]);
    }
    std::stable_sort(sa.begin(), sa.end(), [&](int i, int j) {
      return static_cast<unsigned char>(s[i]) < static_cast<unsigned char>(s[j]);
    });
    // Scratch buffers reused across rounds instead of reallocated every round.
    std::vector<int> prev_rk, prev_sa, cnt(n);
    for (int gap = 1; gap < n; gap *= 2) {
      // On entry, sa is ordered by the first `gap` characters of each suffix.
      prev_rk = rk;
      prev_sa = sa;
      std::iota(cnt.begin(), cnt.end(), 0);
      // Relabel each suffix with the sa index of the first member of its class.
      // Adjacent suffixes stay in one class only if their previous ranks at
      // offsets 0 and gap / 2 both match and the earlier suffix is longer than
      // `gap`; that length check acts as an implicit end-of-string terminator.
      for (int i = 0; i < n; i++) {
        rk[sa[i]] = (i > 0 && prev_rk[sa[i - 1]] == prev_rk[sa[i]] && sa[i - 1] + gap < n &&
                     prev_rk[sa[i - 1] + gap / 2] == prev_rk[sa[i] + gap / 2])
                        ? rk[sa[i - 1]]
                        : i;
      }
      // Refine to 2 * gap characters: scanning prev_sa visits suffix p + gap in
      // order of its first `gap` characters, i.e. suffix p in order of its second
      // half; cnt[c] is the next free slot of class c. Suffixes of length <= gap
      // are never written here and keep their current slots.
      for (int i = 0; i < n; i++) {
        int s1 = prev_sa[i] - gap;
        if (s1 >= 0) {
          sa[cnt[rk[s1]]++] = s1;
        }
      }
    }
    // The last round's labels describe only `gap` characters and may contain
    // ties (e.g. for "aaaaab"), so rebuild rk as the exact inverse of sa, which
    // get_lcp() relies on.
    for (int i = 0; i < n; i++) {
      rk[sa[i]] = i;
    }
  }

  // Returns the suffix array: element i is the start of the i-th smallest suffix.
  const std::vector<int> &get_sa() const { return sa; }

  // Returns the LCP array of length n - 1 (empty for an empty string): element i
  // is the length of the longest common prefix of the suffixes starting at sa[i]
  // and sa[i + 1]. Kasai's algorithm visits suffixes in text order, so the match
  // length k drops by at most one between steps, giving O(n) total time.
  std::vector<int> get_lcp() const {
    int n = static_cast<int>(s.size());
    if (n == 0) {
      return {};  // Avoid constructing a vector of size (size_t)(-1).
    }
    std::vector<int> lcp(n - 1);
    for (int i = 0, k = 0; i < n; i++) {
      if (rk[i] < n - 1) {  // The largest suffix has no successor to compare to.
        int j = sa[rk[i] + 1];
        while (std::max(i, j) + k < n && s[i + k] == s[j + k]) {
          k++;
        }
        lcp[rk[i]] = k;
        if (k > 0) {
          k--;
        }
      }
    }
    return lcp;
  }

  // Returns the start of some occurrence of needle in s (not necessarily the
  // leftmost), or std::string::npos if there is none; an empty needle yields 0.
  // Binary-searches sa on each suffix's first needle.size() characters, which
  // are nondecreasing in suffix order: O(m log n) for a needle of length m.
  size_t find(const std::string &needle) const {
    if (needle.empty()) {
      return 0;
    }
    int lo = 0, hi = static_cast<int>(s.size()) - 1;
    while (lo <= hi) {
      int mid = lo + (hi - lo) / 2;
      int cmp = s.compare(sa[mid], needle.size(), needle);
      if (cmp < 0) {
        lo = mid + 1;
      } else if (cmp > 0) {
        hi = mid - 1;
      } else {
        return sa[mid];
      }
    }
    return std::string::npos;
  }
};

Example Usage

#include <cassert>
using namespace std;

int main() {
  SuffixArrayCountingSort sa("banana");
  vector<int> sarr = sa.get_sa(), lcp = sa.get_lcp();
  vector<int> sarr_expected{5, 3, 1, 0, 4, 2};
  vector<int> lcp_expected{1, 3, 0, 0, 2};
  assert(sarr == sarr_expected);
  assert(lcp == lcp_expected);
  assert(sa.find("ana") == 1);
  assert(sa.find("x") == string::npos);
  return 0;
}