Alex's Anthology of Algorithms Common Code for Contests in Concise C++
Data Structures / Dictionaries and Ordered Sets

2.3.6 Interval Treap

2-Data-Structures/2.3.6_Interval_Treap.cpp

View on GitHub

Maintain an ordered map from closed, one-dimensional intervals to values while supporting efficient reporting of any or all entries that intersect a given query interval. The entries are stored in a treap ordered by key, where keys are compared lexicographically as (lo, hi) pairs. Each node is additionally augmented with the maximum upper endpoint among the intervals in its subtree, letting intersection queries skip any subtree whose maximum falls below the query interval's lower bound.

This implementation uses std::pair to represent intervals, requiring operators < and == to be defined on the numeric key type K.

  • IntervalTreap<K, V>() constructs an empty map.
  • size() returns the size of the map.
  • empty() returns whether the map is empty.
  • insert(lo, hi, v) adds an entry with key [lo, hi] and value v to the map, returning true if a new interval was added or false if the interval already exists (in which case the map is unchanged and the old value associated with the key is preserved).
  • erase(lo, hi) removes the entry with key [lo, hi] from the map, returning true if the removal was successful or false if the interval was not found.
  • find_key(lo, hi) returns a pointer to a const std::pair representing the key of some interval in the map which intersects with [lo, hi], or nullptr if no such entry was found.
  • find_value(lo, hi) returns a pointer to a const value of some entry in the map with a key that intersects with [lo, hi], or nullptr if no such entry was found.
  • find_all(lo, hi, f) calls the function f(lo, hi, v) on each entry in the map that overlaps with [lo, hi], in lexicographically ascending order of intervals.
  • entries() returns all intervals and their values in lexicographically ascending order of intervals.

Implementation

#include <cstdint>
#include <tuple>
#include <utility>
#include <vector>

// Ordered map from closed intervals [lo, hi] to values, supporting queries for
// entries whose key intersects a given interval.
//
// Entries live in a treap: a binary search tree on the lexicographic order of
// (lo, hi) that is simultaneously a min-heap on a pseudo-random priority. Every
// node also caches `max`, the largest upper endpoint stored in its subtree, so
// that queries can skip subtrees lying entirely below the query's lower bound.
//
// K only needs operator< (all comparisons below are expressed through it);
// K and V must be copy-constructible.
//
// The map owns its nodes. It is movable but not copyable.
template<class K, class V>
class IntervalTreap {
  using Interval = std::pair<K, K>;

  struct Node {
    // Xorshift generator (shifts 13/17/5) with a fixed seed, so the sequence of
    // priorities -- and therefore the tree shape -- is reproducible across runs.
    static uint32_t rand32() {
      static uint32_t x = 123456789;
      x ^= x << 13;
      x ^= x >> 17;
      x ^= x << 5;
      return x;
    }

    Interval interval;   // Key: closed interval [first, second].
    V value;             // Mapped value.
    K max;               // Largest upper endpoint within this node's subtree.
    uint32_t priority;   // Heap key; a smaller priority sits closer to the root.
    Node *left, *right;

    Node(const Interval &i, const V &v)
        : interval(i), value(v), max(i.second), priority(rand32()), left(nullptr), right(nullptr) {}

    // Recomputes `max` from this node's own interval and its children's cached
    // maxima. The children must already be up to date.
    void update() {
      max = interval.second;
      if (left != nullptr && max < left->max) {
        max = left->max;
      }
      if (right != nullptr && max < right->max) {
        max = right->max;
      }
    }
  } *root;

  int num_nodes;

  // Returns true if the closed intervals a and b share at least one point.
  static bool overlaps(const Interval &a, const Interval &b) {
    return !(b.second < a.first) && !(a.second < b.first);
  }

  // Rotates n's right child into n's place. Only the demoted node's `max` is
  // refreshed here; the caller is responsible for calling update() on the new
  // subtree root (n) afterwards.
  static void rotate_left(Node *&n) {
    Node *tmp = n;
    n = n->right;
    tmp->right = n->left;
    n->left = tmp;
    tmp->update();
  }

  // Mirror image of rotate_left: rotates n's left child into n's place.
  static void rotate_right(Node *&n) {
    Node *tmp = n;
    n = n->left;
    tmp->left = n->right;
    n->right = tmp;
    tmp->update();
  }

  // Inserts (i, v) into the subtree rooted at n. Returns false, leaving the
  // subtree untouched, if the key i is already present.
  bool insert(Node *&n, const Interval &i, const V &v) {
    if (n == nullptr) {
      n = new Node(i, v);
      num_nodes++;
      return true;
    }
    if (i < n->interval) {
      if (!insert(n->left, i, v)) {
        return false;
      }
      // Restore the heap property if the new descendant bubbled up to n->left.
      if (n->left->priority < n->priority) {
        rotate_right(n);
      }
    } else if (n->interval < i) {
      if (!insert(n->right, i, v)) {
        return false;
      }
      if (n->right->priority < n->priority) {
        rotate_left(n);
      }
    } else {
      return false;  // Duplicate key.
    }
    n->update();
    return true;
  }

  // Removes the entry with key i from the subtree rooted at n. Returns whether
  // an entry was removed.
  bool erase(Node *&n, const Interval &i) {
    if (n == nullptr) {
      return false;
    }
    bool res;
    if (i < n->interval) {
      res = erase(n->left, i);
    } else if (n->interval < i) {
      res = erase(n->right, i);
    } else if (n->left != nullptr && n->right != nullptr) {
      // Two children: rotate the child with the smaller priority above the
      // target, then keep pushing the target down until it has a free slot.
      if (n->left->priority < n->right->priority) {
        rotate_right(n);
        res = erase(n->right, i);
      } else {
        rotate_left(n);
        res = erase(n->left, i);
      }
    } else {
      // At most one child: splice the node out. The replacement subtree's
      // cached max is already correct, so there is nothing to refresh here.
      Node *tmp = (n->left != nullptr) ? n->left : n->right;
      delete n;
      n = tmp;
      num_nodes--;
      return true;
    }
    // Refresh the cached maximum on every node along the search path, not just
    // the rotated ones. find_any() relies on left->max being exact to decide
    // which branch to follow; a stale (too large) value would send it into the
    // wrong subtree and make it miss an existing intersection.
    n->update();
    return res;
  }

  // Returns some node whose interval intersects i, or nullptr if none exists.
  static Node *find_any(Node *n, const Interval &i) {
    while (n != nullptr && !overlaps(n->interval, i)) {
      // If the left subtree holds an interval ending at or after i.first, then
      // either it contains an intersection or every interval to the right
      // starts beyond i.second; in both cases going left is sufficient.
      // Otherwise nothing on the left can reach i, so go right.
      const bool go_left = n->left != nullptr && !(n->left->max < i.first);
      n = go_left ? n->left : n->right;
    }
    return n;
  }

  // Calls f(lo, hi, value) for every entry in n's subtree that intersects i,
  // in ascending key order. f is taken by reference so that a single functor
  // instance is used throughout the traversal.
  template<class Fn>
  static void find_all(Node *n, const Interval &i, Fn &f) {
    if (n == nullptr || n->max < i.first) {
      return;  // Everything in this subtree ends before the query begins.
    }
    find_all(n->left, i, f);
    if (i.second < n->interval.first) {
      // This node, and hence everything to its right, starts after the query.
      return;
    }
    if (!(n->interval.second < i.first)) {
      f(n->interval.first, n->interval.second, n->value);
    }
    find_all(n->right, i, f);
  }

  // Appends every entry of n's subtree to res via an in-order traversal.
  static void collect_entries(Node *n, std::vector<std::tuple<K, K, V>> &res) {
    if (n != nullptr) {
      collect_entries(n->left, res);
      res.emplace_back(n->interval.first, n->interval.second, n->value);
      collect_entries(n->right, res);
    }
  }

  // Frees every node in n's subtree.
  static void clean_up(Node *n) {
    if (n != nullptr) {
      clean_up(n->left);
      clean_up(n->right);
      delete n;
    }
  }

 public:
  // Constructs an empty map.
  IntervalTreap() : root(nullptr), num_nodes(0) {}

  ~IntervalTreap() { clean_up(root); }

  // Copying is disabled because the nodes are uniquely owned.
  IntervalTreap(const IntervalTreap &) = delete;
  IntervalTreap &operator=(const IntervalTreap &) = delete;

  // Moving transfers ownership of all nodes and leaves `other` empty.
  IntervalTreap(IntervalTreap &&other) noexcept
      : root(std::exchange(other.root, nullptr)), num_nodes(std::exchange(other.num_nodes, 0)) {}

  IntervalTreap &operator=(IntervalTreap &&other) noexcept {
    if (this != &other) {
      clean_up(root);
      root = std::exchange(other.root, nullptr);
      num_nodes = std::exchange(other.num_nodes, 0);
    }
    return *this;
  }

  // Returns the number of entries in the map.
  int size() const { return num_nodes; }

  // Returns whether the map has no entries.
  bool empty() const { return root == nullptr; }

  // Adds an entry with key [lo, hi] and value v. Returns true if it was added,
  // or false if the key already exists (the existing value is preserved).
  bool insert(const K &lo, const K &hi, const V &v) { return insert(root, Interval{lo, hi}, v); }

  // Removes the entry with key exactly [lo, hi]. Returns true if one was removed.
  bool erase(const K &lo, const K &hi) { return erase(root, Interval{lo, hi}); }

  // Returns a pointer to the key of some entry intersecting [lo, hi], or
  // nullptr if there is none. The pointer refers to storage inside the map.
  const Interval *find_key(const K &lo, const K &hi) const {
    Node *n = find_any(root, Interval{lo, hi});
    return (n == nullptr) ? nullptr : &(n->interval);
  }

  // Returns a pointer to the value of some entry intersecting [lo, hi], or
  // nullptr if there is none. The pointer refers to storage inside the map.
  const V *find_value(const K &lo, const K &hi) const {
    Node *n = find_any(root, Interval{lo, hi});
    return (n == nullptr) ? nullptr : &(n->value);
  }

  // Calls f(lo, hi, v) on every entry intersecting [lo, hi], in ascending
  // lexicographic order of the keys.
  template<class Fn>
  void find_all(const K &lo, const K &hi, Fn f) const {
    find_all(root, Interval{lo, hi}, f);
  }

  // Returns all entries as (lo, hi, value) tuples in ascending key order.
  std::vector<std::tuple<K, K, V>> entries() const {
    std::vector<std::tuple<K, K, V>> res;
    res.reserve(num_nodes);
    collect_entries(root, res);
    return res;
  }
};

Example Usage

#include <cassert>
#include <iostream>
using namespace std;

int main() {
  IntervalTreap<int, char> t;
  t.insert(15, 20, 'a');
  t.insert(10, 30, 'b');
  t.insert(17, 19, 'c');
  t.insert(5, 20, 'd');
  t.insert(12, 15, 'e');
  t.insert(10, 40, 'f');
  assert(t.size() == 6);
  assert(!t.insert(5, 20, 'x'));
  t.erase(17, 19);
  assert(t.size() == 5);
  assert(*t.find_key(3, 9) == (pair<int, int>{5, 20}));
  assert(*t.find_value(3, 9) == 'd');
  cout << "Intervals intersecting [16, 20]:";
  auto print = [](int lo, int hi, char v) { cout << " [" << lo << ", " << hi << "]"; };
  t.find_all(16, 20, print);
  cout << "\nAll intervals:";
  for (const auto &[lo, hi, v] : t.entries()) {
    print(lo, hi, v);
  }
  cout << endl;
  return 0;
}

Example Output

Intervals intersecting [16, 20]: [5, 20] [10, 30] [10, 40] [15, 20]
All intervals: [5, 20] [10, 30] [10, 40] [12, 15] [15, 20]