Alex's Anthology of Algorithms Common Code for Contests in Concise C++

3.3.1 String Searching (KMP)

3-Strings/3.3.1_String_Searching_(KMP).cpp

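Given a single string (the needle) and subsequent queries of texts (haystacks) to be searched, determine the first position at which the needle occurs within each haystack in linear time using the Knuth-Morris-Pratt algorithm. Preprocessing a needle of length $m$ takes $O(m)$ time, and each query on a haystack of length $n$ takes $O(n)$ time. In comparison, the C++ standard does not prescribe an algorithm for std::string::find, and typical implementations try every starting position, which takes $O(nm)$ time in the worst case (for example, when searching for "aaaab" in a long run of "a" characters). KMP precomputes, for every prefix of the needle, the length of its longest proper border (a proper prefix that is also a suffix; for example, the longest proper border of "ABCDAB" is "AB"). On a mismatch, the needle falls back to that border instead of restarting, so the position in the haystack never moves backward. Each fallback shortens the current partial match, and the match grows by at most one character per haystack character, so a query performs at most $n$ fallbacks in total.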
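To make the table concrete, the hypothetical helper naive_borders below computes the longest proper border of every prefix directly from the definition, by comparing each candidate prefix with the suffix of the same length. It is a sketch for illustration only: it takes $O(m^{3})$ time in the worst case and is not how the KMP constructor builds its table.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Illustration only: border[i] is the length of the longest proper border of
// the prefix needle[0..i], found by brute force from the definition.
std::vector<int> naive_borders(const std::string &needle) {
  std::vector<int> border(needle.size(), 0);
  for (std::size_t i = 0; i < needle.size(); i++) {
    const std::size_t plen = i + 1;  // length of the prefix needle[0..i]
    // Try candidate lengths from the longest proper one (plen - 1) down to 1;
    // the first whose prefix equals the suffix of the same length wins.
    for (std::size_t len = plen - 1; len > 0; len--) {
      if (needle.compare(0, len, needle, plen - len, len) == 0) {
        border[i] = static_cast<int>(len);
        break;
      }
    }
  }
  return border;
}
```

For the example needle "ABCDABD" this yields 0, 0, 0, 0, 1, 2, 0: the prefix "ABCDA" has border "A", "ABCDAB" has border "AB", and the full needle has none because it starts with 'A' and ends with 'D'. These are the same values the linear-time constructor stores in table.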

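  • KMP(needle) constructs the partial match table (the border lengths described above) for needle in $O(m)$ time, so that needle can then be searched for in any number of haystacks.
  • find_in(haystack) returns the first position at which needle occurs in haystack, or std::string::npos if it does not occur, in $O(n)$ time. As with std::string::find, an empty needle is found at position 0. To report all matches instead of returning early, let the loop continue after a full match, record i + 1 - m, and set j = table[j - 1] so that scanning resumes from the longest border; this also finds overlapping matches.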
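The all-matches variant of find_in can be sketched as a standalone function. The name find_all and its signature are hypothetical; it rebuilds the same partial match table as the constructor and, after each full match, falls back with j = table[j - 1] so that j never indexes past the end of the needle.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: every starting index at which needle occurs in
// haystack, overlapping occurrences included, in O(n + m) time.
std::vector<std::size_t> find_all(const std::string &needle,
                                  const std::string &haystack) {
  std::vector<std::size_t> matches;
  const int m = static_cast<int>(needle.size());
  if (m == 0) {
    // Consistent with find_in reporting 0: an empty needle matches at every
    // index 0..n.
    for (std::size_t i = 0; i <= haystack.size(); i++) {
      matches.push_back(i);
    }
    return matches;
  }
  // Partial match table, built exactly as in the KMP constructor.
  std::vector<int> table(m, 0);
  for (int i = 1, j = 0; i < m; i++) {
    while (j > 0 && needle[i] != needle[j]) {
      j = table[j - 1];
    }
    if (needle[i] == needle[j]) {
      j++;
    }
    table[i] = j;
  }
  // Same scan as find_in, but a full match is recorded instead of returned.
  for (int i = 0, j = 0; i < static_cast<int>(haystack.size()); i++) {
    while (j > 0 && needle[j] != haystack[i]) {
      j = table[j - 1];
    }
    if (needle[j] == haystack[i]) {
      j++;
    }
    if (j == m) {
      matches.push_back(static_cast<std::size_t>(i + 1 - m));
      j = table[j - 1];  // resume from the longest border of the full needle
    }
  }
  return matches;
}
```

For example, find_all("aba", "abababa") returns {0, 2, 4}, including the overlapping occurrences, whereas find_in would stop after reporting 0.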

Implementation

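#include <cstddef>
#include <string>
#include <vector>
using std::string;

class KMP {
  string needle;
  // table[i] = length of the longest proper border of needle[0..i].
  std::vector<int> table;

 public:
  explicit KMP(const string &needle) : needle(needle), table(needle.size()) {
    // j is the length of the border currently being extended.
    for (int i = 1, j = 0; i < static_cast<int>(needle.size()); i++) {
      // On a mismatch, fall back to the next-shorter border.
      while (j > 0 && needle[i] != needle[j]) {
        j = table[j - 1];
      }
      if (needle[i] == needle[j]) {
        j++;
      }
      table[i] = j;
    }
  }

  std::size_t find_in(const string &haystack) const {
    int m = static_cast<int>(needle.size());
    if (m == 0) {
      return 0;  // an empty needle matches at position 0, like std::string::find
    }
    // j is the number of needle characters currently matched.
    for (int i = 0, j = 0; i < static_cast<int>(haystack.size()); i++) {
      // On a mismatch, keep the longest border that can still be extended.
      while (j > 0 && needle[j] != haystack[i]) {
        j = table[j - 1];
      }
      if (needle[j] == haystack[i]) {
        j++;
      }
      if (j == m) {
        return i + 1 - m;  // the match ends at haystack[i]
      }
    }
    return string::npos;
  }
};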

Example Usage

#include <cassert>

int main() {
  assert(15 == KMP("ABCDABD").find_in("ABC ABCDAB ABCDABCDABDE"));
  return 0;
}
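The linear-time claim in the description rests on an amortized argument: each fallback j = table[j - 1] strictly decreases j, while j grows by at most one per haystack character, so fallbacks never outnumber haystack characters. The hypothetical function count_fallbacks below is an instrumented copy of the find_in scan, assumed here to keep scanning past full matches as in the all-matches modification, that counts those fallbacks so the bound can be checked on a worst case for naive search.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical instrumented variant of the find_in scan: it scans the whole
// haystack and returns how many times the while loop's fallback runs.
long long count_fallbacks(const std::string &needle,
                          const std::string &haystack) {
  const int m = static_cast<int>(needle.size());
  if (m == 0) {
    return 0;
  }
  // Partial match table, built as in the KMP constructor.
  std::vector<int> table(m, 0);
  for (int i = 1, j = 0; i < m; i++) {
    while (j > 0 && needle[i] != needle[j]) {
      j = table[j - 1];
    }
    if (needle[i] == needle[j]) {
      j++;
    }
    table[i] = j;
  }
  long long fallbacks = 0;
  for (int i = 0, j = 0; i < static_cast<int>(haystack.size()); i++) {
    // Each fallback strictly decreases j: a proper border is shorter than
    // the prefix it belongs to.
    while (j > 0 && needle[j] != haystack[i]) {
      j = table[j - 1];
      fallbacks++;
    }
    // j grows by at most 1 per haystack character, so the total number of
    // decreases, and hence of fallbacks, is at most haystack.size().
    if (needle[j] == haystack[i]) {
      j++;
    }
    if (j == m) {
      j = table[j - 1];  // continue past a full match
    }
  }
  return fallbacks;
}
```

With needle "aaaab" and a haystack of 1000 'a' characters, count_fallbacks returns 996: one fallback per haystack position from index 4 onward. A search that restarts at every starting position would instead compare up to five characters at each of those positions.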