Alex's Anthology of Algorithms: Common Code for Contests in Concise C++

4.4.1 Minimum Spanning Tree (Prim)

4-Graphs/4.4.1_Minimum_Spanning_Tree_(Prim).cpp

Given a connected, undirected, weighted graph (weights may be negative), a minimum spanning tree (MST) is a subgraph that is a tree, connects all nodes, and has the smallest possible total edge weight. If the input graph is not connected, this implementation finds a minimum spanning forest instead, that is, one MST per connected component.
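
One consequence of the forest case is worth spelling out: a spanning forest of a graph with n nodes and c connected components has exactly n - c edges, so the number of edge IDs reported tells you whether the input was connected. The following is a minimal sketch of that check, assuming the (neighbor, weight, edge_id) adjacency layout used on this page. The names adj_list, count_components, and expected_forest_edges are hypothetical helpers introduced for illustration, not part of the original code.

```cpp
#include <tuple>
#include <vector>

// Assumed layout: g[u] = {(v, weight, edge_id), ...}, stored in both directions.
using adj_list = std::vector<std::vector<std::tuple<int, int, int>>>;

// Hypothetical helper: counts connected components with an iterative DFS.
int count_components(const adj_list &g) {
  int n = static_cast<int>(g.size());
  std::vector<bool> seen(n);
  int components = 0;
  for (int s = 0; s < n; s++) {
    if (seen[s]) {
      continue;
    }
    components++;
    seen[s] = true;
    std::vector<int> stack = {s};
    while (!stack.empty()) {
      int u = stack.back();
      stack.pop_back();
      for (const auto &edge : g[u]) {
        int v = std::get<0>(edge);  // Only the neighbor matters here.
        if (!seen[v]) {
          seen[v] = true;
          stack.push_back(v);
        }
      }
    }
  }
  return components;
}

// Hypothetical helper: a spanning forest has n - c edges.
int expected_forest_edges(const adj_list &g) {
  return static_cast<int>(g.size()) - count_components(g);
}
```

If the number of edge IDs collected by an MST routine equals n - 1, the graph was connected; any smaller count means a forest was returned.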

Prim's algorithm grows the tree from an arbitrary start node, repeatedly adding the minimum-weight edge that joins a new node to the current tree, with a priority queue supplying the cheapest such edge at each step.

  • prim_mst() populates mst with the IDs of the edges in the minimum spanning tree and returns the total MST weight. It operates on a global adjacency list adj whose nodes are numbered [0, n), where n is adj.size(). Each entry of adj[u] is a tuple (neighbor, weight, edge_id), and every undirected edge must already be stored in both directions, under the same edge_id, before the call.

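The priority queue stores candidate edges as (weight, from, to, edge_id) and uses std::greater to turn it into a min-heap, so the lightest candidate is always on top. To find a maximum spanning tree instead, drop the std::greater comparator so that the priority queue falls back to its default max-heap ordering.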
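
The comparator swap for a maximum spanning tree can be sketched as follows. This is an illustrative variant rather than the page's own code: it takes the adjacency list as a parameter instead of using the global adj, drops the unused "from" field from the queue entries, and the names adj_list and prim_max_st are hypothetical.

```cpp
#include <queue>
#include <tuple>
#include <vector>

// Assumed layout: g[u] = {(v, weight, edge_id), ...}, stored in both directions.
using adj_list = std::vector<std::vector<std::tuple<int, int, int>>>;

// Hypothetical helper: returns the total weight of a maximum spanning forest
// and writes the IDs of the chosen edges into out_edges.
int prim_max_st(const adj_list &g, std::vector<int> &out_edges) {
  int n = static_cast<int>(g.size());
  out_edges.clear();
  std::vector<bool> visit(n);
  int total = 0;
  using qnode = std::tuple<int, int, int>;  // (weight, to, edge_id)
  for (int i = 0; i < n; i++) {
    if (visit[i]) {
      continue;
    }
    visit[i] = true;
    // No std::greater here: the default std::less keeps the heaviest edge on top.
    std::priority_queue<qnode> pq;
    for (const auto &[v, w, id] : g[i]) {
      pq.emplace(w, v, id);
    }
    while (!pq.empty()) {
      auto [w, v, id] = pq.top();
      pq.pop();
      if (visit[v]) {
        continue;  // Stale entry: v was reached through a heavier edge.
      }
      visit[v] = true;
      out_edges.push_back(id);
      total += w;
      for (const auto &[nb, ew, eid] : g[v]) {
        pq.emplace(ew, nb, eid);
      }
    }
  }
  return total;
}
```

The only change that matters for correctness is the heap ordering; everything else mirrors the minimum-weight case.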

Implementation

#include <functional>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

std::vector<std::vector<std::tuple<int, int, int>>> adj;  // adj[u] = {(v, weight, edge_id), ...}
std::vector<int> mst;

int prim_mst() {
  int n = static_cast<int>(adj.size());
  mst.clear();
  std::vector<bool> visit(n);
  int total_dist = 0;
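  for (int i = 0; i < n; i++) {
    if (visit[i]) {
      continue;  // Already covered by a tree grown from an earlier root.
    }
    visit[i] = true;  // Node i becomes the root of a new tree in the forest.
    using qnode = std::tuple<int, int, int, int>;  // (weight, from, to, edge_id)
    std::priority_queue<qnode, std::vector<qnode>, std::greater<qnode>> pq;
    for (auto &[v, w, id] : adj[i]) {
      pq.emplace(w, i, v, id);
    }
    while (!pq.empty()) {
      auto [w, u, v, id] = pq.top();  // Lightest candidate edge seen so far.
      pq.pop();
      // Stale entries whose far endpoint is already in the tree are skipped.
      if (visit[u] && !visit[v]) {
        visit[v] = true;
        if (v != i) {
          mst.push_back(id);
          total_dist += w;
        }
        // Every edge leaving the newly added node becomes a candidate.
        for (auto &[nb, ew, eid] : adj[v]) {
          pq.emplace(ew, v, nb, eid);
        }
      }
    }
  }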
  return total_dist;
}

Example Usage

#include <iostream>
using namespace std;

vector<pair<int, int>> edge_endpoints;

void add_edge(int u, int v, int w) {
  int id = static_cast<int>(edge_endpoints.size());
  edge_endpoints.push_back({u, v});
  adj[u].push_back({v, w, id});
  adj[v].push_back({u, w, id});
}

int main() {
  int nodes = 7;
  adj.assign(nodes, {});
  add_edge(0, 1, 4);
  add_edge(1, 2, 6);
  add_edge(2, 0, 3);
  add_edge(3, 4, 1);
  add_edge(4, 5, 2);
  add_edge(5, 6, 3);
  add_edge(6, 4, 4);
  cout << "Total distance: " << prim_mst() << endl;
  for (int id : mst) {
    auto &[u, v] = edge_endpoints[id];
    cout << u << " <-> " << v << endl;
  }
  return 0;
}

Example Output

Total distance: 13
2 <-> 0
0 <-> 1
3 <-> 4
4 <-> 5
5 <-> 6