Problems

Master the Union-Find (Disjoint Set Union) pattern by solving these carefully selected problems. Each problem demonstrates a different aspect of connectivity, from basic component counting to cycle detection and minimum spanning trees.

For a theoretical overview, check out the Union-Find Guide or the Code Templates.

Easy & Medium Problems

1. Number of Provinces

LeetCode 547

  • Brief: Given an adjacency matrix representing direct connections between cities, return the total number of provinces (connected components).
  • Why this pattern?: Union-Find is the most direct way to merge elements into disjoint sets and count the number of unique representatives.
  • Key Insight: For every isConnected[i][j] == 1, perform a union(i, j). The final answer is the number of distinct roots.
  • Visual:
      graph TD
        subgraph "Input Matrix"
        M[1 1 0 / 1 1 0 / 0 0 1]
        end
        subgraph "Resulting Components"
        A[City 0] --- B[City 1]
        C[City 2]
        end
    

JavaScript:

/**
 * Counts the provinces (connected components) described by an adjacency matrix.
 *
 * @param {number[][]} isConnected - Matrix where a 1 at [i][j] links city i and city j.
 * @return {number} Number of connected components.
 */
var findCircleNum = function(isConnected) {
    const size = isConnected.length;
    const leader = [];
    for (let city = 0; city < size; city++) {
        leader.push(city);
    }

    // Walk up to the representative, then point every visited node straight at it.
    const rootOf = (node) => {
        let top = node;
        while (leader[top] !== top) {
            top = leader[top];
        }
        let cursor = node;
        while (leader[cursor] !== top) {
            const next = leader[cursor];
            leader[cursor] = top;
            cursor = next;
        }
        return top;
    };

    // Every city starts as its own province; each successful merge removes one.
    let provinces = size;
    isConnected.forEach((row, a) => {
        // Only the upper triangle is scanned, as in a symmetric matrix.
        for (let b = a + 1; b < size; b++) {
            if (row[b] !== 1) continue;
            const rootA = rootOf(a);
            const rootB = rootOf(b);
            if (rootA === rootB) continue;
            leader[rootA] = rootB;
            provinces -= 1;
        }
    });
    return provinces;
};

Python:

class Solution:
    def findCircleNum(self, isConnected: List[List[int]]) -> int:
        """Return the number of provinces (connected components).

        ``isConnected[i][j] == 1`` means city ``i`` and city ``j`` are directly
        linked. Only entries above the diagonal are inspected.
        """
        size = len(isConnected)
        leader = list(range(size))

        def root_of(node):
            # First pass: locate the representative.
            top = node
            while leader[top] != top:
                top = leader[top]
            # Second pass: compress the path just walked.
            while leader[node] != top:
                leader[node], node = top, leader[node]
            return top

        merges = 0
        for a, row in enumerate(isConnected):
            for b in range(a + 1, size):
                if row[b] != 1:
                    continue
                root_a, root_b = root_of(a), root_of(b)
                if root_a != root_b:
                    leader[root_a] = root_b
                    merges += 1

        # Each successful merge removes exactly one component.
        return size - merges

Java:

class Solution {
    /** leader[x] is the parent of x in the disjoint-set forest. */
    private int[] leader;
    /** Number of components still separate. */
    private int provinces;

    /**
     * Counts connected components in the adjacency matrix.
     *
     * @param isConnected matrix where a 1 at [i][j] links city i and city j
     * @return number of provinces
     */
    public int findCircleNum(int[][] isConnected) {
        final int size = isConnected.length;
        leader = new int[size];
        for (int city = 0; city < size; city++) {
            leader[city] = city;
        }
        provinces = size;

        for (int a = 0; a < size; a++) {
            int[] row = isConnected[a];
            // Only the upper triangle is scanned.
            for (int b = a + 1; b < size; b++) {
                if (row[b] == 1) {
                    link(a, b);
                }
            }
        }
        return provinces;
    }

    /** Returns the representative of node, compressing the path on the way. */
    private int rootOf(int node) {
        int top = node;
        while (leader[top] != top) {
            top = leader[top];
        }
        while (leader[node] != top) {
            int next = leader[node];
            leader[node] = top;
            node = next;
        }
        return top;
    }

    /** Merges the sets of a and b; decrements the component count if they differed. */
    private void link(int a, int b) {
        int rootA = rootOf(a);
        int rootB = rootOf(b);
        if (rootA == rootB) {
            return;
        }
        leader[rootA] = rootB;
        provinces--;
    }
}

C++:

class Solution {
public:
    // Counts connected components in the adjacency matrix `isConnected`,
    // where a 1 at [i][j] links city i and city j.
    int findCircleNum(vector<vector<int>>& isConnected) {
        const int size = isConnected.size();
        vector<int> leader(size);
        for (int city = 0; city < size; ++city) {
            leader[city] = city;
        }

        // Two-pass lookup: find the representative, then compress the path.
        auto rootOf = [&leader](int node) {
            int top = node;
            while (leader[top] != top) {
                top = leader[top];
            }
            while (leader[node] != top) {
                const int next = leader[node];
                leader[node] = top;
                node = next;
            }
            return top;
        };

        int provinces = size;
        for (int a = 0; a < size; ++a) {
            const vector<int>& row = isConnected[a];
            for (int b = a + 1; b < size; ++b) {
                if (row[b] != 1) continue;
                const int rootA = rootOf(a);
                const int rootB = rootOf(b);
                if (rootA == rootB) continue;
                leader[rootA] = rootB;
                --provinces;
            }
        }
        return provinces;
    }
};

Go:

// findCircleNum counts the connected components described by the adjacency
// matrix isConnected, where a 1 at [i][j] links city i and city j.
func findCircleNum(isConnected [][]int) int {
    size := len(isConnected)
    leader := make([]int, size)
    for city := 0; city < size; city++ {
        leader[city] = city
    }

    // rootOf finds the representative, then compresses the path it walked.
    rootOf := func(node int) int {
        top := node
        for leader[top] != top {
            top = leader[top]
        }
        for leader[node] != top {
            next := leader[node]
            leader[node] = top
            node = next
        }
        return top
    }

    provinces := size
    for a, row := range isConnected {
        for b := a + 1; b < size; b++ {
            if row[b] != 1 {
                continue
            }
            rootA, rootB := rootOf(a), rootOf(b)
            if rootA == rootB {
                continue
            }
            leader[rootA] = rootB
            provinces--
        }
    }
    return provinces
}

Ruby:

# Counts the connected components described by the adjacency matrix
# +is_connected+, where a 1 at [i][j] links city i and city j.
def find_circle_num(is_connected)
  size = is_connected.length
  leader = Array.new(size) { |city| city }

  # Finds the representative, then compresses the path that was walked.
  root_of = lambda do |node|
    top = node
    top = leader[top] until leader[top] == top
    until leader[node] == top
      nxt = leader[node]
      leader[node] = top
      node = nxt
    end
    top
  end

  provinces = size
  is_connected.each_with_index do |row, a|
    ((a + 1)...size).each do |b|
      next unless row[b] == 1

      root_a = root_of.call(a)
      root_b = root_of.call(b)
      next if root_a == root_b

      leader[root_a] = root_b
      provinces -= 1
    end
  end
  provinces
end

2. Redundant Connection

LeetCode 684

  • Brief: Find an edge that, if removed, would turn a graph with one extra edge back into a tree.
  • Why this pattern?: Union-Find is the standard way to detect cycles in an undirected graph as we add edges.
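  • Key Insight: If we attempt to union two nodes that already share the same root, the current edge forms a cycle and is the “redundant” one. Because the graph is a tree plus exactly one extra edge, it contains a single cycle, so the first edge that closes it while scanning in input order is also the last cycle edge in the input, which is the answer LeetCode asks for.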
  • Visual:
      graph LR
        1(Node 1) --- 2(Node 2)
        2 --- 3(Node 3)
        3 == Redundant Edge ==> 1
    

JavaScript:

/**
 * Finds the edge that closes a cycle when the edges are added in input order.
 *
 * @param {number[][]} edges - List of [u, v] pairs; node labels must be valid
 *     indices into an array of length edges.length + 1.
 * @return {number[]} The first edge whose endpoints are already connected, or
 *     an empty array when no edge closes a cycle.
 */
var findRedundantConnection = function(edges) {
    const parent = Array.from({length: edges.length + 1}, (_, i) => i);

    // Recursive lookup with path compression.
    function find(i) {
        if (parent[i] === i) return i;
        parent[i] = find(parent[i]);
        return parent[i];
    }

    for (const [u, v] of edges) {
        const rootU = find(u);
        const rootV = find(v);
        // Endpoints already connected: this edge would create the cycle.
        if (rootU === rootV) return [u, v];
        parent[rootU] = rootV;
    }

    // Acyclic input: return an explicit empty result instead of undefined,
    // matching the Java and C++ versions of this solution.
    return [];
};

Python:

class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        """Return the edge that closes a cycle when edges are added in order.

        Node labels must be valid indices into a list of length
        ``len(edges) + 1``. Returns the first edge whose endpoints are already
        connected, or an empty list when no edge closes a cycle.
        """
        parent = list(range(len(edges) + 1))

        def find(i):
            # Recursive lookup with path compression.
            if parent[i] == i:
                return i
            parent[i] = find(parent[i])
            return parent[i]

        for u, v in edges:
            root_u, root_v = find(u), find(v)
            if root_u == root_v:
                # Endpoints already connected: this edge would create the cycle.
                return [u, v]
            parent[root_u] = root_v

        # Acyclic input: honour the List[int] annotation instead of returning None.
        return []

Java:

class Solution {
    /**
     * Finds the edge that closes a cycle when edges are added in input order.
     *
     * @param edges list of {u, v} pairs; labels must be valid indices into an
     *              array of length edges.length + 1
     * @return the first edge whose endpoints are already connected, or an
     *         empty array when no edge closes a cycle
     */
    public int[] findRedundantConnection(int[][] edges) {
        final int slots = edges.length + 1;
        int[] leader = new int[slots];
        for (int node = 0; node < slots; node++) {
            leader[node] = node;
        }

        for (int k = 0; k < edges.length; k++) {
            int[] candidate = edges[k];
            int rootA = rootOf(leader, candidate[0]);
            int rootB = rootOf(leader, candidate[1]);
            if (rootA != rootB) {
                leader[rootA] = rootB;
                continue;
            }
            // Endpoints already share a representative: this edge makes the cycle.
            return candidate;
        }
        return new int[0];
    }

    /** Returns the representative of node, compressing the path on the way. */
    private int rootOf(int[] leader, int node) {
        int top = node;
        while (leader[top] != top) {
            top = leader[top];
        }
        while (leader[node] != top) {
            int next = leader[node];
            leader[node] = top;
            node = next;
        }
        return top;
    }
}

C++:

class Solution {
public:
    // Returns the first edge (in input order) whose endpoints are already
    // connected, or an empty vector when no edge closes a cycle. Node labels
    // must be valid indices into a vector of size edges.size() + 1.
    vector<int> findRedundantConnection(vector<vector<int>>& edges) {
        vector<int> leader(edges.size() + 1);
        for (size_t node = 0; node < leader.size(); ++node) {
            leader[node] = static_cast<int>(node);
        }

        // Two-pass lookup: find the representative, then compress the path.
        auto rootOf = [&leader](int node) {
            int top = node;
            while (leader[top] != top) {
                top = leader[top];
            }
            while (leader[node] != top) {
                const int next = leader[node];
                leader[node] = top;
                node = next;
            }
            return top;
        };

        for (const auto& candidate : edges) {
            const int rootA = rootOf(candidate[0]);
            const int rootB = rootOf(candidate[1]);
            if (rootA != rootB) {
                leader[rootA] = rootB;
                continue;
            }
            return candidate;
        }
        return {};
    }
};

Go:

// findRedundantConnection returns the first edge (in input order) whose
// endpoints are already connected, or nil when no edge closes a cycle.
// Node labels must be valid indices into a slice of length len(edges)+1.
func findRedundantConnection(edges [][]int) []int {
    leader := make([]int, len(edges)+1)
    for node := 0; node < len(leader); node++ {
        leader[node] = node
    }

    // rootOf finds the representative, then compresses the path it walked.
    rootOf := func(node int) int {
        top := node
        for leader[top] != top {
            top = leader[top]
        }
        for leader[node] != top {
            next := leader[node]
            leader[node] = top
            node = next
        }
        return top
    }

    for k := 0; k < len(edges); k++ {
        candidate := edges[k]
        rootA, rootB := rootOf(candidate[0]), rootOf(candidate[1])
        if rootA != rootB {
            leader[rootA] = rootB
            continue
        }
        return candidate
    }
    return nil
}

Ruby:

# Finds the edge that closes a cycle when the edges are added in input order.
#
# edges - Array of [u, v] pairs; node labels must be valid indices into an
#         array of length edges.length + 1.
#
# Returns the first edge whose endpoints are already connected, or an empty
# array when no edge closes a cycle.
def find_redundant_connection(edges)
  parent = (0..edges.length).to_a

  # Recursive lookup with path compression.
  find = ->(i) {
    return i if parent[i] == i
    parent[i] = find.call(parent[i])
  }

  edges.each do |u, v|
    root_u = find.call(u)
    root_v = find.call(v)
    # Endpoints already connected: this edge would create the cycle.
    return [u, v] if root_u == root_v
    parent[root_u] = root_v
  end

  # Without this, the method would return the value of `edges.each`, i.e. the
  # entire input list, which a caller could mistake for an edge.
  []
end

3. Number of Islands

LeetCode 200

  • Brief: Count the number of islands in a 2D binary grid.
  • Why this pattern?: While often solved with DFS, Union-Find is excellent for counting components, especially in dynamic or distributed scenarios.
  • Key Insight: Map each 2D cell (r, c) to a 1D index r * cols + c. Union current land cells with their right and downward neighbors.
  • Visual:
      graph TD
        R1C1[1] --- R1C2[1]
        R1C1 --- R2C1[1]
        R3C3[1]
        R1C3[0]
        R2C2[0]
        subgraph "Island 1"
        R1C1
        R1C2
        R2C1
        end
        subgraph "Island 2"
        R3C3
        end
    

JavaScript:

/**
 * Counts islands (groups of '1' cells joined horizontally or vertically).
 *
 * @param {string[][]} grid - Rows of '1' (land) and '0' (water) cells.
 * @return {number} Number of islands.
 */
var numIslands = function(grid) {
    if (!grid.length) return 0;

    const height = grid.length;
    const width = grid[0].length;
    // Cell (r, c) is stored at flat index r * width + c.
    const leader = Array.from({length: height * width}, (_, k) => k);

    // Walk up to the representative, then point every visited node straight at it.
    const rootOf = (node) => {
        let top = node;
        while (leader[top] !== top) {
            top = leader[top];
        }
        let cursor = node;
        while (leader[cursor] !== top) {
            const next = leader[cursor];
            leader[cursor] = top;
            cursor = next;
        }
        return top;
    };

    // Returns true only when two previously separate sets were joined.
    const link = (a, b) => {
        const rootA = rootOf(a);
        const rootB = rootOf(b);
        if (rootA === rootB) return false;
        leader[rootA] = rootB;
        return true;
    };

    let land = 0;
    let merges = 0;
    for (let r = 0; r < height; r++) {
        for (let c = 0; c < width; c++) {
            const cell = grid[r][c];
            if (cell === '1') land += 1;
            if (cell === '0') continue;

            const here = r * width + c;
            // Down and right neighbours suffice: up/left pairs were handled earlier.
            if (r + 1 < height && grid[r + 1][c] === '1' && link(here, here + width)) {
                merges += 1;
            }
            if (c + 1 < width && grid[r][c + 1] === '1' && link(here, here + 1)) {
                merges += 1;
            }
        }
    }
    // Each land cell starts as its own island; every merge removes one.
    return land - merges;
};

Python:

class Solution:
    def numIslands(self, grid: List[List[str]]) -> int:
        """Count islands: groups of '1' cells joined horizontally or vertically.

        Cell ``(r, c)`` is mapped to the flat index ``r * width + c``.
        """
        if not grid:
            return 0

        height, width = len(grid), len(grid[0])
        leader = list(range(height * width))

        # Every land cell starts out as its own island.
        land = 0
        for row in grid:
            land += row.count('1')

        def root_of(node):
            # First pass: locate the representative.
            top = node
            while leader[top] != top:
                top = leader[top]
            # Second pass: compress the path just walked.
            while leader[node] != top:
                leader[node], node = top, leader[node]
            return top

        merges = 0
        for r in range(height):
            for c in range(width):
                if grid[r][c] == '0':
                    continue
                here = r * width + c
                # Down first, then right; up/left pairs were handled earlier.
                for nr, nc in ((r + 1, c), (r, c + 1)):
                    if nr < height and nc < width and grid[nr][nc] == '1':
                        root_a, root_b = root_of(here), root_of(nr * width + nc)
                        if root_a != root_b:
                            leader[root_a] = root_b
                            merges += 1

        return land - merges

Java:

class Solution {
    /** leader[x] is the parent of flat cell index x; only land cells are initialised. */
    private int[] leader;
    /** Number of islands still separate. */
    private int islands;

    /**
     * Counts islands: groups of '1' cells joined horizontally or vertically.
     *
     * @param grid rows of '1' (land) and '0' (water) cells
     * @return number of islands
     */
    public int numIslands(char[][] grid) {
        if (grid.length == 0) {
            return 0;
        }
        final int height = grid.length;
        final int width = grid[0].length;
        leader = new int[height * width];
        islands = 0;

        // First pass: make every land cell its own island.
        // This must finish before linking, because linking looks ahead to
        // cells below and to the right.
        for (int r = 0; r < height; r++) {
            char[] row = grid[r];
            for (int c = 0; c < width; c++) {
                if (row[c] != '1') {
                    continue;
                }
                int cell = r * width + c;
                leader[cell] = cell;
                islands++;
            }
        }

        // Second pass: join each non-water cell with its down and right land neighbours.
        for (int r = 0; r < height; r++) {
            for (int c = 0; c < width; c++) {
                if (grid[r][c] == '0') {
                    continue;
                }
                int here = r * width + c;
                if (r + 1 < height && grid[r + 1][c] == '1') {
                    link(here, here + width);
                }
                if (c + 1 < width && grid[r][c + 1] == '1') {
                    link(here, here + 1);
                }
            }
        }
        return islands;
    }

    /** Returns the representative of node, compressing the path on the way. */
    private int rootOf(int node) {
        int top = node;
        while (leader[top] != top) {
            top = leader[top];
        }
        while (leader[node] != top) {
            int next = leader[node];
            leader[node] = top;
            node = next;
        }
        return top;
    }

    /** Merges the sets of a and b; decrements the island count if they differed. */
    private void link(int a, int b) {
        int rootA = rootOf(a);
        int rootB = rootOf(b);
        if (rootA == rootB) {
            return;
        }
        leader[rootA] = rootB;
        islands--;
    }
}

C++:

class Solution {
public:
    // Counts islands: groups of '1' cells joined horizontally or vertically.
    // Cell (r, c) is mapped to the flat index r * cols + c.
    // Returns 0 for an empty grid.
    int numIslands(vector<vector<char>>& grid) {
        // Guard first: indexing grid[0] on an empty vector is undefined behaviour.
        if (grid.empty() || grid[0].empty()) return 0;

        const int rows = grid.size(), cols = grid[0].size();
        vector<int> parent(rows * cols);
        // Identity-initialise every slot so no cell ever aliases index 0.
        iota(parent.begin(), parent.end(), 0);

        // Every land cell starts out as its own island.
        int count = 0;
        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) {
                if (grid[r][c] == '1') count++;
            }
        }

        // Recursive lookup with path compression.
        function<int(int)> find = [&](int i) {
            return parent[i] == i ? i : parent[i] = find(parent[i]);
        };

        // Joins two cells; one island disappears whenever two sets merge.
        auto unite = [&](int a, int b) {
            int rootA = find(a), rootB = find(b);
            if (rootA != rootB) {
                parent[rootA] = rootB;
                count--;
            }
        };

        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) {
                if (grid[r][c] == '0') continue;
                int idx = r * cols + c;
                // Down and right neighbours suffice: up/left pairs were handled earlier.
                if (r + 1 < rows && grid[r + 1][c] == '1') unite(idx, (r + 1) * cols + c);
                if (c + 1 < cols && grid[r][c + 1] == '1') unite(idx, r * cols + (c + 1));
            }
        }
        return count;
    }
};

Go:

// numIslands counts groups of '1' cells joined horizontally or vertically.
// Cell (r, c) is mapped to the flat index r*width + c.
func numIslands(grid [][]byte) int {
    if len(grid) == 0 {
        return 0
    }
    height, width := len(grid), len(grid[0])
    leader := make([]int, height*width)
    islands := 0

    // First pass: make every land cell its own island. This must finish before
    // linking, because linking looks ahead to cells below and to the right.
    for r, row := range grid {
        for c := 0; c < width; c++ {
            if row[c] != '1' {
                continue
            }
            cell := r*width + c
            leader[cell] = cell
            islands++
        }
    }

    // rootOf finds the representative, then compresses the path it walked.
    rootOf := func(node int) int {
        top := node
        for leader[top] != top {
            top = leader[top]
        }
        for leader[node] != top {
            next := leader[node]
            leader[node] = top
            node = next
        }
        return top
    }

    // link merges two sets; one island disappears whenever they differed.
    link := func(a, b int) {
        rootA, rootB := rootOf(a), rootOf(b)
        if rootA == rootB {
            return
        }
        leader[rootA] = rootB
        islands--
    }

    // Second pass: join each non-water cell with its down and right land neighbours.
    for r := 0; r < height; r++ {
        for c := 0; c < width; c++ {
            if grid[r][c] == '0' {
                continue
            }
            here := r*width + c
            if r+1 < height && grid[r+1][c] == '1' {
                link(here, here+width)
            }
            if c+1 < width && grid[r][c+1] == '1' {
                link(here, here+1)
            }
        }
    }
    return islands
}

Ruby:

# Counts islands: groups of '1' cells joined horizontally or vertically.
# Cell (r, c) is mapped to the flat index r * width + c.
def num_islands(grid)
  return 0 if grid.empty?

  height = grid.length
  width = grid[0].length
  leader = Array.new(height * width) { |k| k }

  # Finds the representative, then compresses the path that was walked.
  root_of = lambda do |node|
    top = node
    top = leader[top] until leader[top] == top
    until leader[node] == top
      nxt = leader[node]
      leader[node] = top
      node = nxt
    end
    top
  end

  # Returns true only when two previously separate sets were joined.
  link = lambda do |a, b|
    root_a = root_of.call(a)
    root_b = root_of.call(b)
    next false if root_a == root_b

    leader[root_a] = root_b
    true
  end

  land = 0
  merges = 0
  (0...height).each do |r|
    (0...width).each do |c|
      cell = grid[r][c]
      land += 1 if cell == '1'
      next if cell == '0'

      here = r * width + c
      # Down and right neighbours suffice: up/left pairs were handled earlier.
      merges += 1 if r + 1 < height && grid[r + 1][c] == '1' && link.call(here, here + width)
      merges += 1 if c + 1 < width && grid[r][c + 1] == '1' && link.call(here, here + 1)
    end
  end

  # Each land cell starts as its own island; every merge removes one.
  land - merges
end

Hard Problems

4. Min Cost to Connect All Points

LeetCode 1584

  • Brief: Find the minimum cost to connect all points such that there is a path between any two points.
  • Why this pattern?: This is a classic Minimum Spanning Tree (MST) problem. Kruskal’s algorithm uses Union-Find to efficiently build the MST.
  • Key Insight: Create all unique edges between points, sort them by Manhattan distance, and use Union-Find to pick the smallest edges that don’t create a cycle.
  • Visual:
      graph LR
        P1((0,0)) == cost:2 ==> P2((0,2))
        P1 == cost:4 ==> P3((4,0))
        P2 -- cost:6 --- P3
    

JavaScript:

/**
 * Computes the minimum total Manhattan distance needed to connect all points,
 * using Kruskal's algorithm over every pair of points.
 *
 * @param {number[][]} points - List of [x, y] coordinates.
 * @return {number} Total cost of the minimum spanning tree.
 */
var minCostConnectPoints = function(points) {
    const total = points.length;

    // One candidate link per unordered pair of points.
    const links = [];
    points.forEach(([x1, y1], a) => {
        for (let b = a + 1; b < total; b++) {
            const [x2, y2] = points[b];
            links.push({ cost: Math.abs(x1 - x2) + Math.abs(y1 - y2), a, b });
        }
    });
    links.sort((p, q) => p.cost - q.cost);

    const leader = points.map((_, k) => k);

    // Walk up to the representative, then point every visited node straight at it.
    const rootOf = (node) => {
        let top = node;
        while (leader[top] !== top) {
            top = leader[top];
        }
        let cursor = node;
        while (leader[cursor] !== top) {
            const next = leader[cursor];
            leader[cursor] = top;
            cursor = next;
        }
        return top;
    };

    let spent = 0;
    // A spanning tree over `total` points needs exactly total - 1 links.
    let needed = total - 1;
    for (const { cost, a, b } of links) {
        const rootA = rootOf(a);
        const rootB = rootOf(b);
        if (rootA === rootB) continue;
        leader[rootA] = rootB;
        spent += cost;
        needed -= 1;
        if (needed === 0) break;
    }
    return spent;
};

Python:

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        """Return the minimum total Manhattan distance that connects all points.

        Runs Kruskal's algorithm over every pair of points.
        """
        total = len(points)

        # One (cost, a, b) candidate per unordered pair, cheapest first.
        links = sorted(
            (abs(points[a][0] - points[b][0]) + abs(points[a][1] - points[b][1]), a, b)
            for a in range(total)
            for b in range(a + 1, total)
        )

        leader = list(range(total))

        def root_of(node):
            # First pass: locate the representative.
            top = node
            while leader[top] != top:
                top = leader[top]
            # Second pass: compress the path just walked.
            while leader[node] != top:
                leader[node], node = top, leader[node]
            return top

        spent = 0
        # A spanning tree over `total` points needs exactly total - 1 links.
        needed = total - 1
        for cost, a, b in links:
            root_a, root_b = root_of(a), root_of(b)
            if root_a == root_b:
                continue
            leader[root_a] = root_b
            spent += cost
            needed -= 1
            if needed == 0:
                break
        return spent

Java:

class Solution {
    /**
     * Computes the minimum total Manhattan distance needed to connect all
     * points, using Kruskal's algorithm over every pair of points.
     *
     * @param points list of {x, y} coordinates
     * @return total cost of the minimum spanning tree
     */
    public int minCostConnectPoints(int[][] points) {
        final int total = points.length;

        // Min-heap of {cost, a, b} candidates, one per unordered pair.
        PriorityQueue<int[]> cheapestFirst =
                new PriorityQueue<>((x, y) -> Integer.compare(x[0], y[0]));
        for (int a = 0; a < total; a++) {
            int[] p = points[a];
            for (int b = a + 1; b < total; b++) {
                int[] q = points[b];
                int cost = Math.abs(p[0] - q[0]) + Math.abs(p[1] - q[1]);
                cheapestFirst.offer(new int[]{cost, a, b});
            }
        }

        int[] leader = new int[total];
        for (int node = 0; node < total; node++) {
            leader[node] = node;
        }

        int spent = 0;
        // A spanning tree over `total` points needs exactly total - 1 links.
        for (int needed = total - 1; needed > 0; ) {
            int[] link = cheapestFirst.poll();
            int rootA = rootOf(leader, link[1]);
            int rootB = rootOf(leader, link[2]);
            if (rootA == rootB) {
                continue;
            }
            leader[rootA] = rootB;
            spent += link[0];
            needed--;
        }
        return spent;
    }

    /** Returns the representative of node, compressing the path on the way. */
    private int rootOf(int[] leader, int node) {
        int top = node;
        while (leader[top] != top) {
            top = leader[top];
        }
        while (leader[node] != top) {
            int next = leader[node];
            leader[node] = top;
            node = next;
        }
        return top;
    }
}

C++:

class Solution {
public:
    // Computes the minimum total Manhattan distance needed to connect all
    // points, using Kruskal's algorithm over every pair of points.
    int minCostConnectPoints(vector<vector<int>>& points) {
        const int total = points.size();

        // One candidate link per unordered pair of points.
        struct Link {
            int cost;
            int a;
            int b;
        };
        vector<Link> links;
        for (int a = 0; a < total; ++a) {
            const vector<int>& p = points[a];
            for (int b = a + 1; b < total; ++b) {
                const vector<int>& q = points[b];
                links.push_back(Link{abs(p[0] - q[0]) + abs(p[1] - q[1]), a, b});
            }
        }
        sort(links.begin(), links.end(),
             [](const Link& lhs, const Link& rhs) { return lhs.cost < rhs.cost; });

        vector<int> leader(total);
        for (int node = 0; node < total; ++node) {
            leader[node] = node;
        }

        // Two-pass lookup: find the representative, then compress the path.
        auto rootOf = [&leader](int node) {
            int top = node;
            while (leader[top] != top) {
                top = leader[top];
            }
            while (leader[node] != top) {
                const int next = leader[node];
                leader[node] = top;
                node = next;
            }
            return top;
        };

        int spent = 0;
        // A spanning tree over `total` points needs exactly total - 1 links.
        int needed = total - 1;
        for (const Link& link : links) {
            const int rootA = rootOf(link.a);
            const int rootB = rootOf(link.b);
            if (rootA == rootB) continue;
            leader[rootA] = rootB;
            spent += link.cost;
            if (--needed == 0) break;
        }
        return spent;
    }
};

Go:

// minCostConnectPoints computes the minimum total Manhattan distance needed to
// connect all points, using Kruskal's algorithm over every pair of points.
func minCostConnectPoints(points [][]int) int {
    total := len(points)

    // One candidate link per unordered pair of points.
    type link struct {
        cost int
        a, b int
    }
    var links []link
    for a := 0; a < total; a++ {
        p := points[a]
        for b := a + 1; b < total; b++ {
            q := points[b]
            links = append(links, link{cost: abs(p[0]-q[0]) + abs(p[1]-q[1]), a: a, b: b})
        }
    }
    sort.Slice(links, func(x, y int) bool {
        return links[x].cost < links[y].cost
    })

    leader := make([]int, total)
    for node := 0; node < total; node++ {
        leader[node] = node
    }

    // rootOf finds the representative, then compresses the path it walked.
    rootOf := func(node int) int {
        top := node
        for leader[top] != top {
            top = leader[top]
        }
        for leader[node] != top {
            next := leader[node]
            leader[node] = top
            node = next
        }
        return top
    }

    spent := 0
    // A spanning tree over `total` points needs exactly total - 1 links.
    needed := total - 1
    for _, l := range links {
        rootA, rootB := rootOf(l.a), rootOf(l.b)
        if rootA == rootB {
            continue
        }
        leader[rootA] = rootB
        spent += l.cost
        needed--
        if needed == 0 {
            break
        }
    }
    return spent
}

// abs returns the absolute value of x.
func abs(x int) int {
    if x >= 0 {
        return x
    }
    return -x
}

Ruby:

# Computes the minimum total Manhattan distance needed to connect all points,
# using Kruskal's algorithm over every pair of points.
def min_cost_connect_points(points)
  total = points.length

  # One [cost, a, b] candidate per unordered pair of point indices.
  links = (0...total).to_a.combination(2).map do |a, b|
    [(points[a][0] - points[b][0]).abs + (points[a][1] - points[b][1]).abs, a, b]
  end
  links = links.sort_by(&:first)

  leader = Array.new(total) { |k| k }

  # Finds the representative, then compresses the path that was walked.
  root_of = lambda do |node|
    top = node
    top = leader[top] until leader[top] == top
    until leader[node] == top
      nxt = leader[node]
      leader[node] = top
      node = nxt
    end
    top
  end

  spent = 0
  # A spanning tree over `total` points needs exactly total - 1 links.
  needed = total - 1
  links.each do |cost, a, b|
    root_a = root_of.call(a)
    root_b = root_of.call(b)
    next if root_a == root_b

    leader[root_a] = root_b
    spent += cost
    needed -= 1
    break if needed.zero?
  end
  spent
end

Recommended Study Order

  1. Basics: Start with Number of Provinces (LC 547) to understand basic connectivity and component counting.
  2. Core Application: Solve Redundant Connection (LC 684) to master cycle detection.
  3. Grid Modeling: Tackle Number of Islands (LC 200). Learning to model a 2D grid as a graph is critical for interview success.
  4. Advanced: Attempt Min Cost to Connect All Points (LC 1584) to see how Union-Find powers Kruskal’s MST algorithm.
Pro Tip: For almost all Union-Find problems, the “trick” is in how you model the components. Decide whether to represent grid cells, hash table keys, or array indices as your nodes. The Union-Find logic itself remains constant.