Problems
Master the Union-Find (Disjoint Set Union) pattern by solving these carefully selected problems. Each problem demonstrates a different aspect of connectivity, from basic component counting to cycle detection and minimum spanning trees.
For a theoretical overview, check out the Union-Find Guide or the Code Templates.
Easy & Medium Problems
1. Number of Provinces
LeetCode 547
- Brief: Given an adjacency matrix representing direct connections between cities, return the total number of provinces (connected components).
- Why this pattern?: Union-Find is the most direct way to merge elements into disjoint sets and count the number of unique representatives.
- Key Insight: For every `isConnected[i][j] == 1`, perform a `union(i, j)`. The final answer is the number of distinct roots.
- Visual:
graph TD subgraph "Input Matrix" M[1 1 0 / 1 1 0 / 0 0 1] end subgraph "Resulting Components" A[City 0] --- B[City 1] C[City 2] end
JavaScript:
/**
 * Counts provinces (connected components) described by an adjacency matrix.
 *
 * @param {number[][]} isConnected - isConnected[a][b] === 1 marks a direct link.
 * @return {number} number of connected components.
 */
var findCircleNum = function(isConnected) {
    const size = isConnected.length;
    const parent = [];
    for (let city = 0; city < size; city++) parent.push(city);

    // Walk up to the representative, then re-point every visited node at it.
    const find = (node) => {
        let root = node;
        while (parent[root] !== root) root = parent[root];
        while (parent[node] !== root) {
            const next = parent[node];
            parent[node] = root;
            node = next;
        }
        return root;
    };

    // Start with every city isolated; each successful merge removes one component.
    let components = size;
    for (let a = 0; a < size; a++) {
        // Only the upper triangle (b > a) is scanned.
        for (let b = a + 1; b < size; b++) {
            if (isConnected[a][b] !== 1) continue;
            const rootA = find(a);
            const rootB = find(b);
            if (rootA === rootB) continue;
            parent[rootA] = rootB;
            components -= 1;
        }
    }
    return components;
};
// Python:
class Solution:
    def findCircleNum(self, isConnected: List[List[int]]) -> int:
        """Return the number of provinces (connected components).

        Args:
            isConnected: adjacency matrix; ``isConnected[a][b] == 1`` marks a
                direct connection between cities ``a`` and ``b``.

        Returns:
            The count of disjoint groups left after merging every connected pair.
        """
        size = len(isConnected)
        parent = list(range(size))

        def find(node):
            # First pass locates the representative, second pass compresses the path.
            root = node
            while parent[root] != root:
                root = parent[root]
            while parent[node] != root:
                nxt = parent[node]
                parent[node] = root
                node = nxt
            return root

        # Every city starts as its own province; each real merge removes one.
        provinces = size
        for a in range(size):
            # Only the upper triangle (b > a) is scanned.
            for b in range(a + 1, size):
                if isConnected[a][b] != 1:
                    continue
                root_a, root_b = find(a), find(b)
                if root_a != root_b:
                    parent[root_a] = root_b
                    provinces -= 1
        return provinces
# Java:
class Solution {
    /**
     * Counts provinces (connected components) described by an adjacency matrix.
     *
     * @param isConnected isConnected[a][b] == 1 marks a direct connection
     * @return number of disjoint groups after merging every connected pair
     */
    public int findCircleNum(int[][] isConnected) {
        int size = isConnected.length;
        int[] parent = new int[size];
        for (int city = 0; city < size; city++) {
            parent[city] = city;
        }

        // Every city starts as its own province; each real merge removes one.
        int provinces = size;
        for (int a = 0; a < size; a++) {
            // Only the upper triangle (b > a) is scanned.
            for (int b = a + 1; b < size; b++) {
                if (isConnected[a][b] != 1) continue;
                int rootA = findRoot(parent, a);
                int rootB = findRoot(parent, b);
                if (rootA == rootB) continue;
                parent[rootA] = rootB;
                provinces--;
            }
        }
        return provinces;
    }

    /** Returns the representative of node's set and compresses the walked path. */
    private static int findRoot(int[] parent, int node) {
        int root = node;
        while (parent[root] != root) {
            root = parent[root];
        }
        while (parent[node] != root) {
            int next = parent[node];
            parent[node] = root;
            node = next;
        }
        return root;
    }
}
// C++:
class Solution {
public:
    // Counts provinces (connected components) described by an adjacency matrix.
    // isConnected[a][b] == 1 marks a direct connection between cities a and b.
    int findCircleNum(vector<vector<int>>& isConnected) {
        const int size = static_cast<int>(isConnected.size());
        vector<int> parent(size);
        for (int city = 0; city < size; ++city) {
            parent[city] = city;
        }

        // Locate the representative, then point every visited node straight at it.
        auto find = [&parent](int node) {
            int root = node;
            while (parent[root] != root) {
                root = parent[root];
            }
            while (parent[node] != root) {
                const int next = parent[node];
                parent[node] = root;
                node = next;
            }
            return root;
        };

        // Every city starts as its own province; each real merge removes one.
        int provinces = size;
        for (int a = 0; a < size; ++a) {
            // Only the upper triangle (b > a) is scanned.
            for (int b = a + 1; b < size; ++b) {
                if (isConnected[a][b] != 1) continue;
                const int rootA = find(a);
                const int rootB = find(b);
                if (rootA == rootB) continue;
                parent[rootA] = rootB;
                --provinces;
            }
        }
        return provinces;
    }
};
// Go:
// findCircleNum counts provinces (connected components) described by an
// adjacency matrix, where isConnected[a][b] == 1 marks a direct connection.
func findCircleNum(isConnected [][]int) int {
	size := len(isConnected)
	parent := make([]int, size)
	for city := range parent {
		parent[city] = city
	}

	// find locates the representative, then points every visited node at it.
	find := func(node int) int {
		root := node
		for parent[root] != root {
			root = parent[root]
		}
		for parent[node] != root {
			next := parent[node]
			parent[node] = root
			node = next
		}
		return root
	}

	// Every city starts as its own province; each real merge removes one.
	provinces := size
	for a := 0; a < size; a++ {
		// Only the upper triangle (b > a) is scanned.
		for b := a + 1; b < size; b++ {
			if isConnected[a][b] != 1 {
				continue
			}
			rootA, rootB := find(a), find(b)
			if rootA == rootB {
				continue
			}
			parent[rootA] = rootB
			provinces--
		}
	}
	return provinces
}
// Ruby:
# Counts provinces (connected components) described by an adjacency matrix,
# where is_connected[a][b] == 1 marks a direct connection.
def find_circle_num(is_connected)
  size = is_connected.length
  parent = Array.new(size) { |city| city }

  # Locate the representative, then point every visited node straight at it.
  find = lambda do |node|
    root = node
    root = parent[root] until parent[root] == root
    until parent[node] == root
      nxt = parent[node]
      parent[node] = root
      node = nxt
    end
    root
  end

  # Every city starts as its own province; each real merge removes one.
  provinces = size
  size.times do |a|
    # Only the upper triangle (b > a) is scanned.
    ((a + 1)...size).each do |b|
      next unless is_connected[a][b] == 1
      root_a = find.call(a)
      root_b = find.call(b)
      next if root_a == root_b
      parent[root_a] = root_b
      provinces -= 1
    end
  end
  provinces
end
# 2. Redundant Connection
LeetCode 684
- Brief: Find an edge that, if removed, would turn a graph with one extra edge back into a tree.
- Why this pattern?: Union-Find is the standard way to detect cycles in an undirected graph as we add edges.
- Key Insight: If we attempt to union two nodes that already share the same root, the current edge forms a cycle and is the “redundant” one.
- Visual:
graph LR 1(Node 1) --- 2(Node 2) 2 --- 3(Node 3) 3 == Redundant Edge ==> 1
JavaScript:
/**
 * Finds the edge that closes a cycle when the edges are added in order.
 *
 * @param {number[][]} edges - list of [u, v] pairs; node labels are used as
 *     indices into an array of length edges.length + 1.
 * @return {number[]} the first edge whose endpoints were already connected,
 *     or an empty array when no edge closes a cycle (including empty input).
 */
var findRedundantConnection = function(edges) {
    const parent = Array.from({length: edges.length + 1}, (_, i) => i);

    // Locate the representative, then point every visited node straight at it.
    function find(i) {
        let root = i;
        while (parent[root] !== root) root = parent[root];
        while (parent[i] !== root) {
            const next = parent[i];
            parent[i] = root;
            i = next;
        }
        return root;
    }

    for (const [u, v] of edges) {
        const rootU = find(u);
        const rootV = find(v);
        // Same representative means u and v are already connected: this edge forms the cycle.
        if (rootU === rootV) return [u, v];
        parent[rootU] = rootV;
    }
    // Explicit empty result instead of an implicit `undefined`, matching the
    // Java/C++/Go versions of this solution.
    return [];
};
// Python:
class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        """Return the edge that closes a cycle when edges are added in order.

        Args:
            edges: list of ``[u, v]`` pairs; node labels are used as indices
                into a list of length ``len(edges) + 1``.

        Returns:
            The first edge whose endpoints were already connected, or an empty
            list when no edge closes a cycle (including empty input).
        """
        parent = list(range(len(edges) + 1))

        def find(i):
            # First pass locates the representative, second pass compresses the path.
            root = i
            while parent[root] != root:
                root = parent[root]
            while parent[i] != root:
                nxt = parent[i]
                parent[i] = root
                i = nxt
            return root

        for u, v in edges:
            root_u, root_v = find(u), find(v)
            # Same representative means u and v are already connected.
            if root_u == root_v:
                return [u, v]
            parent[root_u] = root_v
        # Honour the List[int] return annotation instead of falling through to None.
        return []
# Java:
class Solution {
public int[] findRedundantConnection(int[][] edges) {
int[] parent = new int[edges.length + 1];
for (int i = 0; i < parent.length; i++) parent[i] = i;
for (int[] edge : edges) {
int rootU = find(parent, edge[0]);
int rootV = find(parent, edge[1]);
if (rootU == rootV) return edge;
parent[rootU] = rootV;
}
return new int[0];
}
private int find(int[] parent, int i) {
if (parent[i] == i) return i;
return parent[i] = find(parent, parent[i]);
}
}C++:
class Solution {
public:
    // Finds the edge that closes a cycle when the edges are added in order.
    // Node labels index an array of size edges.size() + 1.
    // Returns an empty vector when no edge closes a cycle.
    vector<int> findRedundantConnection(vector<vector<int>>& edges) {
        vector<int> parent(edges.size() + 1);
        for (size_t node = 0; node < parent.size(); ++node) {
            parent[node] = static_cast<int>(node);
        }

        // Locate the representative, then point every visited node straight at it.
        auto find = [&parent](int node) {
            int root = node;
            while (parent[root] != root) {
                root = parent[root];
            }
            while (parent[node] != root) {
                const int next = parent[node];
                parent[node] = root;
                node = next;
            }
            return root;
        };

        for (const auto& link : edges) {
            const int rootA = find(link[0]);
            const int rootB = find(link[1]);
            // Same representative means both endpoints are already connected.
            if (rootA == rootB) {
                return link;
            }
            parent[rootA] = rootB;
        }
        return {};
    }
};
// Go:
// findRedundantConnection returns the first edge whose endpoints were already
// connected when the edges are added in order, or nil when no edge closes a
// cycle. Node labels index a slice of length len(edges)+1.
func findRedundantConnection(edges [][]int) []int {
	parent := make([]int, len(edges)+1)
	for node := range parent {
		parent[node] = node
	}

	// find locates the representative, then points every visited node at it.
	find := func(node int) int {
		root := node
		for parent[root] != root {
			root = parent[root]
		}
		for parent[node] != root {
			next := parent[node]
			parent[node] = root
			node = next
		}
		return root
	}

	for _, link := range edges {
		rootA, rootB := find(link[0]), find(link[1])
		// Same representative means both endpoints are already connected.
		if rootA == rootB {
			return link
		}
		parent[rootA] = rootB
	}
	return nil
}
// Ruby:
# Finds the edge that closes a cycle when the edges are added in order.
#
# edges - Array of [u, v] pairs; node labels index an array of length
#         edges.length + 1.
#
# Returns the first edge whose endpoints were already connected, or an empty
# Array when no edge closes a cycle.
def find_redundant_connection(edges)
  parent = (0..edges.length).to_a

  # Locate the representative, then point every visited node straight at it.
  find = lambda do |i|
    root = i
    root = parent[root] until parent[root] == root
    until parent[i] == root
      nxt = parent[i]
      parent[i] = root
      i = nxt
    end
    root
  end

  edges.each do |u, v|
    root_u = find.call(u)
    root_v = find.call(v)
    # Same representative means u and v are already connected.
    return [u, v] if root_u == root_v
    parent[root_u] = root_v
  end

  # Without this, the method's value would be that of `edges.each`, i.e. the
  # whole input list, which is not a valid answer.
  []
end
# 3. Number of Islands
LeetCode 200
- Brief: Count the number of islands in a 2D binary grid.
- Why this pattern?: While often solved with DFS, Union-Find is excellent for counting components, especially in dynamic or distributed scenarios.
- Key Insight: Map each 2D cell `(r, c)` to a 1D index `r * cols + c`. Union current land cells with their right and downward neighbors.
- Visual:
graph TD R1C1[1] --- R1C2[1] R1C1 --- R2C1[1] R3C3[1] R1C3[0] R2C2[0] subgraph "Island 1" R1C1 R1C2 R2C1 end subgraph "Island 2" R3C3 end
JavaScript:
/**
 * Counts islands (groups of '1' cells joined vertically or horizontally).
 *
 * @param {string[][]} grid - 2D grid whose cells are compared against '1' and '0'.
 * @return {number} number of islands; 0 for an empty grid.
 */
var numIslands = function(grid) {
    if (!grid.length) return 0;

    const height = grid.length;
    const width = grid[0].length;
    // Cell (r, c) is flattened to index r * width + c.
    const parent = Array.from({length: height * width}, (_, cell) => cell);

    // Every land cell starts as its own island.
    let islands = 0;
    for (const row of grid) {
        for (let c = 0; c < width; c++) {
            if (row[c] === '1') islands += 1;
        }
    }

    // Locate the representative, then point every visited node straight at it.
    const find = (node) => {
        let root = node;
        while (parent[root] !== root) root = parent[root];
        while (parent[node] !== root) {
            const next = parent[node];
            parent[node] = root;
            node = next;
        }
        return root;
    };

    // Merge two cells; a real merge means one island fewer.
    const link = (a, b) => {
        const rootA = find(a);
        const rootB = find(b);
        if (rootA === rootB) return;
        parent[rootA] = rootB;
        islands -= 1;
    };

    for (let r = 0; r < height; r++) {
        for (let c = 0; c < width; c++) {
            if (grid[r][c] === '0') continue;
            const here = r * width + c;
            // Looking only down and right visits each adjacent pair once.
            if (r + 1 < height && grid[r + 1][c] === '1') link(here, here + width);
            if (c + 1 < width && grid[r][c + 1] === '1') link(here, here + 1);
        }
    }
    return islands;
};
// Python:
class Solution:
    def numIslands(self, grid: List[List[str]]) -> int:
        """Count islands (groups of '1' cells joined vertically or horizontally).

        Args:
            grid: 2D grid whose cells are compared against ``'1'`` and ``'0'``.

        Returns:
            The number of islands; 0 for an empty grid.
        """
        if not grid:
            return 0

        height = len(grid)
        width = len(grid[0])
        # Cell (r, c) is flattened to index r * width + c.
        parent = list(range(height * width))

        # Every land cell starts as its own island.
        islands = 0
        for row in grid:
            islands += row.count('1')

        def find(node):
            # First pass locates the representative, second pass compresses the path.
            root = node
            while parent[root] != root:
                root = parent[root]
            while parent[node] != root:
                nxt = parent[node]
                parent[node] = root
                node = nxt
            return root

        def link(a, b):
            # Merge two cells; a real merge means one island fewer.
            nonlocal islands
            root_a, root_b = find(a), find(b)
            if root_a == root_b:
                return
            parent[root_a] = root_b
            islands -= 1

        for r in range(height):
            for c in range(width):
                if grid[r][c] == '0':
                    continue
                here = r * width + c
                # Looking only down and right visits each adjacent pair once.
                if r + 1 < height and grid[r + 1][c] == '1':
                    link(here, here + width)
                if c + 1 < width and grid[r][c + 1] == '1':
                    link(here, here + 1)
        return islands
# Java:
class Solution {
    /**
     * Counts islands (groups of '1' cells joined vertically or horizontally).
     *
     * @param grid 2D grid whose cells are compared against '1' and '0'
     * @return number of islands; 0 for an empty grid
     */
    public int numIslands(char[][] grid) {
        if (grid.length == 0) return 0;

        final int height = grid.length;
        final int width = grid[0].length;
        // Cell (r, c) is flattened to index r * width + c.
        int[] parent = new int[height * width];

        // Every land cell starts as its own island; only land cells are initialised.
        int islands = 0;
        for (int r = 0; r < height; r++) {
            for (int c = 0; c < width; c++) {
                if (grid[r][c] != '1') continue;
                int cell = r * width + c;
                parent[cell] = cell;
                islands++;
            }
        }

        for (int r = 0; r < height; r++) {
            for (int c = 0; c < width; c++) {
                if (grid[r][c] == '0') continue;
                int here = r * width + c;
                // Looking only down and right visits each adjacent pair once.
                if (r + 1 < height && grid[r + 1][c] == '1' && link(parent, here, here + width)) {
                    islands--;
                }
                if (c + 1 < width && grid[r][c + 1] == '1' && link(parent, here, here + 1)) {
                    islands--;
                }
            }
        }
        return islands;
    }

    /** Merges the sets of a and b; returns true only when they were separate. */
    private static boolean link(int[] parent, int a, int b) {
        int rootA = rootOf(parent, a);
        int rootB = rootOf(parent, b);
        if (rootA == rootB) return false;
        parent[rootA] = rootB;
        return true;
    }

    /** Returns the representative of node's set and compresses the walked path. */
    private static int rootOf(int[] parent, int node) {
        int root = node;
        while (parent[root] != root) {
            root = parent[root];
        }
        while (parent[node] != root) {
            int next = parent[node];
            parent[node] = root;
            node = next;
        }
        return root;
    }
}
// C++:
class Solution {
public:
    // Counts islands (groups of '1' cells joined vertically or horizontally).
    // Returns 0 for an empty grid.
    int numIslands(vector<vector<char>>& grid) {
        // Guard before touching grid[0]: indexing an empty vector is undefined
        // behaviour. The other language versions of this solution have the same check.
        if (grid.empty() || grid[0].empty()) return 0;

        int rows = grid.size(), cols = grid[0].size();
        // Cell (r, c) is flattened to index r * cols + c.
        vector<int> parent(rows * cols);

        // Every land cell starts as its own island; only land cells are initialised.
        int count = 0;
        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) {
                if (grid[r][c] == '1') {
                    parent[r * cols + c] = r * cols + c;
                    count++;
                }
            }
        }

        // Locate the representative, then point every visited node straight at it.
        auto find = [&parent](int i) {
            int root = i;
            while (parent[root] != root) {
                root = parent[root];
            }
            while (parent[i] != root) {
                const int next = parent[i];
                parent[i] = root;
                i = next;
            }
            return root;
        };

        // Merge two cells; a real merge means one island fewer.
        auto unite = [&](int a, int b) {
            const int rootA = find(a), rootB = find(b);
            if (rootA != rootB) {
                parent[rootA] = rootB;
                count--;
            }
        };

        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) {
                if (grid[r][c] == '0') continue;
                int idx = r * cols + c;
                // Looking only down and right visits each adjacent pair once.
                if (r + 1 < rows && grid[r + 1][c] == '1') unite(idx, (r + 1) * cols + c);
                if (c + 1 < cols && grid[r][c + 1] == '1') unite(idx, r * cols + (c + 1));
            }
        }
        return count;
    }
};
// Go:
// numIslands counts islands (groups of '1' cells joined vertically or
// horizontally). It returns 0 for an empty grid.
func numIslands(grid [][]byte) int {
	if len(grid) == 0 {
		return 0
	}
	height := len(grid)
	width := len(grid[0])
	// Cell (r, c) is flattened to index r*width + c.
	parent := make([]int, height*width)

	// Every land cell starts as its own island; only land cells are initialised.
	islands := 0
	for r := 0; r < height; r++ {
		for c := 0; c < width; c++ {
			if grid[r][c] != '1' {
				continue
			}
			cell := r*width + c
			parent[cell] = cell
			islands++
		}
	}

	// find locates the representative, then points every visited node at it.
	find := func(node int) int {
		root := node
		for parent[root] != root {
			root = parent[root]
		}
		for parent[node] != root {
			next := parent[node]
			parent[node] = root
			node = next
		}
		return root
	}

	// link merges two cells; a real merge means one island fewer.
	link := func(a, b int) {
		rootA, rootB := find(a), find(b)
		if rootA == rootB {
			return
		}
		parent[rootA] = rootB
		islands--
	}

	for r := 0; r < height; r++ {
		for c := 0; c < width; c++ {
			if grid[r][c] == '0' {
				continue
			}
			here := r*width + c
			// Looking only down and right visits each adjacent pair once.
			if r+1 < height && grid[r+1][c] == '1' {
				link(here, here+width)
			}
			if c+1 < width && grid[r][c+1] == '1' {
				link(here, here+1)
			}
		}
	}
	return islands
}
// Ruby:
# Counts islands (groups of '1' cells joined vertically or horizontally).
# Returns 0 for an empty grid.
def num_islands(grid)
  return 0 if grid.empty?

  height = grid.length
  width = grid[0].length
  # Cell (r, c) is flattened to index r * width + c.
  parent = Array.new(height * width) { |cell| cell }

  # Every land cell starts as its own island.
  islands = 0
  (0...height).each do |r|
    (0...width).each do |c|
      islands += 1 if grid[r][c] == '1'
    end
  end

  # Locate the representative, then point every visited node straight at it.
  find = lambda do |node|
    root = node
    root = parent[root] until parent[root] == root
    until parent[node] == root
      nxt = parent[node]
      parent[node] = root
      node = nxt
    end
    root
  end

  # Merge two cells; the result is true only when they were separate.
  link = lambda do |a, b|
    root_a = find.call(a)
    root_b = find.call(b)
    next false if root_a == root_b
    parent[root_a] = root_b
    true
  end

  (0...height).each do |r|
    (0...width).each do |c|
      next if grid[r][c] == '0'
      here = r * width + c
      # Looking only down and right visits each adjacent pair once.
      if r + 1 < height && grid[r + 1][c] == '1'
        islands -= 1 if link.call(here, here + width)
      end
      if c + 1 < width && grid[r][c + 1] == '1'
        islands -= 1 if link.call(here, here + 1)
      end
    end
  end
  islands
end
# Hard Problems
4. Min Cost to Connect All Points
LeetCode 1584
- Brief: Find the minimum cost to connect all points such that there is a path between any two points.
- Why this pattern?: This is a classic Minimum Spanning Tree (MST) problem. Kruskal’s algorithm uses Union-Find to efficiently build the MST.
- Key Insight: Create all unique edges between points, sort them by Manhattan distance, and use Union-Find to pick the smallest edges that don’t create a cycle.
- Visual:
graph LR P1((0,0)) == cost:2 ==> P2((0,2)) P1 == cost:4 ==> P3((4,0)) P2 -- cost:6 --- P3
JavaScript:
/**
 * Computes the minimum total Manhattan distance needed to connect all points
 * (Kruskal's algorithm over every pair of points).
 *
 * @param {number[][]} points - list of [x, y] coordinates.
 * @return {number} total cost of the selected edges.
 */
var minCostConnectPoints = function(points) {
    const total = points.length;

    // Build every unordered pair (i < j) with its Manhattan distance.
    const candidates = [];
    for (let i = 0; i < total; i++) {
        const [xi, yi] = [points[i][0], points[i][1]];
        for (let j = i + 1; j < total; j++) {
            const cost = Math.abs(xi - points[j][0]) + Math.abs(yi - points[j][1]);
            candidates.push([cost, i, j]);
        }
    }
    candidates.sort((x, y) => x[0] - y[0]);

    const parent = Array.from({length: total}, (_, node) => node);

    // Locate the representative, then point every visited node straight at it.
    const find = (node) => {
        let root = node;
        while (parent[root] !== root) root = parent[root];
        while (parent[node] !== root) {
            const next = parent[node];
            parent[node] = root;
            node = next;
        }
        return root;
    };

    let minCost = 0;
    let picked = 0;
    for (const [cost, a, b] of candidates) {
        const rootA = find(a);
        const rootB = find(b);
        // Skip edges whose endpoints are already connected (they would form a cycle).
        if (rootA === rootB) continue;
        parent[rootA] = rootB;
        minCost += cost;
        picked += 1;
        // A spanning tree over `total` nodes has exactly total - 1 edges.
        if (picked === total - 1) break;
    }
    return minCost;
};
// Python:
class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        """Return the minimum total Manhattan distance that connects all points.

        Uses Kruskal's algorithm over every pair of points.

        Args:
            points: list of ``[x, y]`` coordinates.

        Returns:
            The total cost of the selected edges.
        """
        total = len(points)

        # Build every unordered pair (i < j) with its Manhattan distance.
        candidates = [
            (
                abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1]),
                i,
                j,
            )
            for i in range(total)
            for j in range(i + 1, total)
        ]
        candidates.sort()

        parent = list(range(total))

        def find(node):
            # First pass locates the representative, second pass compresses the path.
            root = node
            while parent[root] != root:
                root = parent[root]
            while parent[node] != root:
                nxt = parent[node]
                parent[node] = root
                node = nxt
            return root

        min_cost = 0
        picked = 0
        for cost, a, b in candidates:
            root_a, root_b = find(a), find(b)
            # Skip edges whose endpoints are already connected (they would form a cycle).
            if root_a == root_b:
                continue
            parent[root_a] = root_b
            min_cost += cost
            picked += 1
            # A spanning tree over `total` nodes has exactly total - 1 edges.
            if picked == total - 1:
                break
        return min_cost
# Java:
class Solution {
    /**
     * Computes the minimum total Manhattan distance needed to connect all points
     * (Kruskal's algorithm, pulling the cheapest remaining pair from a min-heap).
     *
     * @param points list of {x, y} coordinates
     * @return total cost of the selected edges
     */
    public int minCostConnectPoints(int[][] points) {
        final int total = points.length;

        // Heap entries are {cost, i, j}, ordered by cost.
        PriorityQueue<int[]> byCost = new PriorityQueue<>((x, y) -> x[0] - y[0]);
        for (int i = 0; i < total; i++) {
            int[] p = points[i];
            for (int j = i + 1; j < total; j++) {
                int[] q = points[j];
                int cost = Math.abs(p[0] - q[0]) + Math.abs(p[1] - q[1]);
                byCost.add(new int[] {cost, i, j});
            }
        }

        int[] parent = new int[total];
        for (int node = 0; node < total; node++) {
            parent[node] = node;
        }

        int minCost = 0;
        // A spanning tree over `total` nodes has exactly total - 1 edges.
        for (int picked = 0; picked < total - 1; ) {
            int[] next = byCost.poll();
            int rootA = rootOf(parent, next[1]);
            int rootB = rootOf(parent, next[2]);
            // Skip edges whose endpoints are already connected (they would form a cycle).
            if (rootA == rootB) continue;
            parent[rootA] = rootB;
            minCost += next[0];
            picked++;
        }
        return minCost;
    }

    /** Returns the representative of node's set and compresses the walked path. */
    private static int rootOf(int[] parent, int node) {
        int root = node;
        while (parent[root] != root) {
            root = parent[root];
        }
        while (parent[node] != root) {
            int next = parent[node];
            parent[node] = root;
            node = next;
        }
        return root;
    }
}
// C++:
class Solution {
public:
    // Computes the minimum total Manhattan distance needed to connect all
    // points (Kruskal's algorithm over every pair of points).
    int minCostConnectPoints(vector<vector<int>>& points) {
        const int total = static_cast<int>(points.size());

        // Each candidate is {cost, {i, j}} for an unordered pair i < j.
        vector<pair<int, pair<int, int>>> candidates;
        for (int i = 0; i < total; ++i) {
            const vector<int>& p = points[i];
            for (int j = i + 1; j < total; ++j) {
                const vector<int>& q = points[j];
                const int cost = abs(p[0] - q[0]) + abs(p[1] - q[1]);
                candidates.push_back(make_pair(cost, make_pair(i, j)));
            }
        }
        sort(candidates.begin(), candidates.end());

        vector<int> parent(total);
        for (int node = 0; node < total; ++node) {
            parent[node] = node;
        }

        // Locate the representative, then point every visited node straight at it.
        auto find = [&parent](int node) {
            int root = node;
            while (parent[root] != root) {
                root = parent[root];
            }
            while (parent[node] != root) {
                const int next = parent[node];
                parent[node] = root;
                node = next;
            }
            return root;
        };

        int minCost = 0;
        int picked = 0;
        for (const auto& candidate : candidates) {
            const int rootA = find(candidate.second.first);
            const int rootB = find(candidate.second.second);
            // Skip edges whose endpoints are already connected (they would form a cycle).
            if (rootA == rootB) continue;
            parent[rootA] = rootB;
            minCost += candidate.first;
            ++picked;
            // A spanning tree over `total` nodes has exactly total - 1 edges.
            if (picked == total - 1) break;
        }
        return minCost;
    }
};
// Go:
// minCostConnectPoints computes the minimum total Manhattan distance needed to
// connect all points (Kruskal's algorithm over every pair of points).
func minCostConnectPoints(points [][]int) int {
	total := len(points)

	// Each candidate is an unordered pair a < b with its Manhattan distance.
	type link struct{ cost, a, b int }
	candidates := make([]link, 0, total*(total-1)/2)
	for a := 0; a < total; a++ {
		for b := a + 1; b < total; b++ {
			dx := abs(points[a][0] - points[b][0])
			dy := abs(points[a][1] - points[b][1])
			candidates = append(candidates, link{dx + dy, a, b})
		}
	}
	sort.Slice(candidates, func(x, y int) bool {
		return candidates[x].cost < candidates[y].cost
	})

	parent := make([]int, total)
	for node := range parent {
		parent[node] = node
	}

	// find locates the representative, then points every visited node at it.
	find := func(node int) int {
		root := node
		for parent[root] != root {
			root = parent[root]
		}
		for parent[node] != root {
			next := parent[node]
			parent[node] = root
			node = next
		}
		return root
	}

	minCost, picked := 0, 0
	for _, candidate := range candidates {
		rootA, rootB := find(candidate.a), find(candidate.b)
		// Skip edges whose endpoints are already connected (they would form a cycle).
		if rootA == rootB {
			continue
		}
		parent[rootA] = rootB
		minCost += candidate.cost
		picked++
		// A spanning tree over `total` nodes has exactly total-1 edges.
		if picked == total-1 {
			break
		}
	}
	return minCost
}

// abs returns the absolute value of x.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}
// Ruby:
# Computes the minimum total Manhattan distance needed to connect all points
# (Kruskal's algorithm over every pair of points).
#
# points - Array of [x, y] coordinates.
#
# Returns the total cost of the selected edges.
def min_cost_connect_points(points)
  total = points.length

  # Build every unordered pair (i < j) with its Manhattan distance.
  candidates = []
  total.times do |i|
    ((i + 1)...total).each do |j|
      dx = (points[i][0] - points[j][0]).abs
      dy = (points[i][1] - points[j][1]).abs
      candidates.push([dx + dy, i, j])
    end
  end
  candidates.sort_by! { |candidate| candidate[0] }

  parent = Array.new(total) { |node| node }

  # Locate the representative, then point every visited node straight at it.
  find = lambda do |node|
    root = node
    root = parent[root] until parent[root] == root
    until parent[node] == root
      nxt = parent[node]
      parent[node] = root
      node = nxt
    end
    root
  end

  min_cost = 0
  picked = 0
  candidates.each do |cost, a, b|
    root_a = find.call(a)
    root_b = find.call(b)
    # Skip edges whose endpoints are already connected (they would form a cycle).
    next if root_a == root_b
    parent[root_a] = root_b
    min_cost += cost
    picked += 1
    # A spanning tree over `total` nodes has exactly total - 1 edges.
    break if picked == total - 1
  end
  min_cost
end
# Recommended Study Order
- Basics: Start with Number of Provinces (LC 547) to understand basic connectivity and component counting.
- Core Application: Solve Redundant Connection (LC 684) to master cycle detection.
- Grid Modeling: Tackle Number of Islands (LC 200). Learning to model a 2D grid as a graph is critical for interview success.
- Advanced: Attempt Min Cost to Connect All Points (LC 1584) to see how Union-Find powers Kruskal’s MST algorithm.
Pro Tip: For almost all Union-Find problems, the “trick” is in how you model the components. Decide whether to represent grid cells, hash table keys, or array indices as your nodes. The Union-Find logic itself remains constant.