Alex's Anthology of Algorithms: Common Code for Contests in Concise C++
Data Structures / Dictionaries and Ordered Sets

2.3.9 Skip List

2-Data-Structures/2.3.9_Skip_List.cpp

View on GitHub

Maintain an ordered map, that is, an ordered collection of key-value pairs in which each possible key appears at most once. This implementation requires operators < and == to be defined on the key type, and requires both the key and value types to be default-constructible (the head sentinel node is built from K() and V()). A skip list maintains a linked hierarchy of sorted subsequences, with each successive subsequence skipping over fewer elements than the previous one. Each new node joins a number of levels decided by repeated coin flips. A search starts at the sparsest level, moves forward until the next key would overshoot, and then drops down a level. As a result, insertion, erasure, and lookup take $O(\log n)$ time with high probability.

  • SkipList<K, V>() constructs an empty map.
  • size() returns the size of the map.
  • empty() returns whether the map is empty.
  • insert(k, v) adds an entry with key k and value v to the map, returning true if a new entry was added or false if the key already exists (in which case the map is unchanged and the old value associated with the key is preserved).
  • erase(k) removes the entry with key k from the map, returning true if the removal was successful or false if the key to be removed was not found.
  • find(k) returns a pointer to the value associated with key k, or nullptr if the key was not found.
  • operator[](k) returns a reference to the value associated with key k (which may be modified). If k is not already present, a new entry with a default-constructed value is inserted first, and a reference to that value is returned.
  • entries() returns all key-value entries in ascending order of keys.

Implementation

#include <random>
#include <utility>
#include <vector>

template<class K, class V>
class SkipList {
  // Maximum height of any node (and the fixed height of the head sentinel).
  static constexpr int MAX_LEVELS = 32;  // log2(max possible keys)

  struct Node {
    K key;
    V value;
    // next[i] is this node's successor on level i; next.size() is the node's height.
    std::vector<Node *> next;

    Node(const K &k, const V &v, int levels) : key(k), value(v), next(levels, nullptr) {}
  } *head;  // sentinel of height MAX_LEVELS; its key and value are never read

  int num_nodes;

  // Returns a node height drawn by repeated fair coin flips: height h >= 1 has
  // probability 2^-h, capped at MAX_LEVELS.
  static int random_level() {
    static std::mt19937 rng(std::random_device{}());
    static std::uniform_int_distribution<int> coin(0, 1);
    int level = 1;
    while (coin(rng) && level < MAX_LEVELS) {
      level++;
    }
    return level;
  }

  // Returns the first node whose key is not less than k, or nullptr if every key
  // is less than k. If update is non-null it must point to MAX_LEVELS slots.
  // update[i] receives the last node on level i whose key is less than k (head if
  // none), i.e. the node whose next[i] must be rewired to insert or erase k.
  Node *lower_bound_node(const K &k, Node **update) const {
    Node *n = head;
    // Always start from the fixed top level: empty levels cost one null check
    // each, and no index can exceed the MAX_LEVELS slots of head->next.
    for (int i = MAX_LEVELS; i-- > 0;) {
      while (n->next[i] != nullptr && n->next[i]->key < k) {
        n = n->next[i];
      }
      if (update != nullptr) {
        update[i] = n;
      }
    }
    return n->next[0];
  }

  // Allocates a node for (k, v) with a random height and splices it in right
  // after the predecessors in update (as produced by lower_bound_node(k, ...)).
  // Returns the new node.
  Node *link_new_node(const K &k, const V &v, Node *const *update) {
    const int levels = random_level();
    Node *n = new Node(k, v, levels);
    for (int i = 0; i < levels; i++) {
      n->next[i] = update[i]->next[i];
      update[i]->next[i] = n;
    }
    num_nodes++;
    return n;
  }

 public:
  SkipList() : head(new Node(K(), V(), MAX_LEVELS)), num_nodes(0) {}

  // Frees the sentinel and every stored node. All of them are chained on level 0.
  ~SkipList() {
    Node *n = head;
    while (n != nullptr) {
      Node *succ = n->next[0];
      delete n;
      n = succ;
    }
  }

  SkipList(const SkipList &) = delete;
  SkipList &operator=(const SkipList &) = delete;
  int size() const { return num_nodes; }
  bool empty() const { return num_nodes == 0; }

  // Adds (k, v) and returns true, or returns false and leaves the map unchanged
  // if k is already present.
  bool insert(const K &k, const V &v) {
    Node *update[MAX_LEVELS];
    Node *n = lower_bound_node(k, update);
    if (n != nullptr && n->key == k) {
      return false;
    }
    link_new_node(k, v, update);
    return true;
  }

  // Removes the entry with key k. Returns false if k was not present.
  bool erase(const K &k) {
    Node *update[MAX_LEVELS];
    Node *n = lower_bound_node(k, update);
    if (n == nullptr || !(n->key == k)) {
      return false;
    }
    // n is linked on levels [0, height), and on each of those levels update[i]
    // is its predecessor.
    for (int i = 0, height = static_cast<int>(n->next.size()); i < height; i++) {
      update[i]->next[i] = n->next[i];
    }
    delete n;
    num_nodes--;
    return true;
  }

  // Returns a pointer to k's value, or nullptr if k is not present.
  V *find(const K &k) const {
    Node *n = lower_bound_node(k, nullptr);
    return (n != nullptr && n->key == k) ? &(n->value) : nullptr;
  }

  // Returns a reference to k's value, first inserting a default-constructed
  // value if k is absent. Uses a single traversal.
  V &operator[](const K &k) {
    Node *update[MAX_LEVELS];
    Node *n = lower_bound_node(k, update);
    if (n == nullptr || !(n->key == k)) {
      n = link_new_node(k, V(), update);
    }
    return n->value;
  }

  // Returns all entries in ascending key order.
  std::vector<std::pair<K, V>> entries() const {
    std::vector<std::pair<K, V>> res;
    res.reserve(num_nodes);
    for (Node *n = head->next[0]; n != nullptr; n = n->next[0]) {
      res.emplace_back(n->key, n->value);
    }
    return res;
  }
};

Example Usage

#include <cassert>
#include <iostream>
using namespace std;

int main() {
  SkipList<int, char> l;
  l.insert(2, 'b');
  l.insert(1, 'a');
  l.insert(3, 'c');
  l.insert(5, 'e');
  assert(l.insert(4, 'd'));
  assert(*l.find(4) == 'd');
  assert(!l.insert(4, 'd'));
  for (const auto &[k, v] : l.entries()) {
    cout << v;
  }
  cout << endl;
  assert(l.erase(1));
  assert(!l.erase(1));
  assert(l.find(1) == nullptr);
  for (const auto &[k, v] : l.entries()) {
    cout << v;
  }
  cout << endl;
  return 0;
}

Example Output

abcde
bcde