Rerooting Techniques in Python
Rerooting techniques are typically used with rooted trees in computational problems such as dynamic programming on trees and network design. The idea is to compute a value for the tree rooted at one node, then derive the same value for every other choice of root from that first result instead of starting over.
To demonstrate rerooting in Python, let’s consider a common problem: calculating, for every node of an unweighted tree, the sum of its distances to all other nodes. Running a separate traversal from every node takes O(N^2) time for a tree with N nodes, while rerooting solves the problem in O(N).
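Before rerooting, it helps to see the straightforward approach it improves on: run a breadth-first search from every node and add up the distances. The sketch below (the function name brute_force_sum_of_distances is illustrative, not from the original) takes O(N^2) time, but it is a handy reference for checking a faster implementation on small trees.

```python
from collections import deque

def brute_force_sum_of_distances(n, edges):
    # Build an adjacency list: adj[u] holds every neighbour of u.
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    result = []
    for source in range(n):
        # Breadth-first search gives shortest distances in an unweighted tree.
        dist = [-1] * n
        dist[source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for neighbor in adj[node]:
                if dist[neighbor] == -1:
                    dist[neighbor] = dist[node] + 1
                    queue.append(neighbor)
        result.append(sum(dist))
    return result

edges = [(0, 1), (0, 2), (2, 3), (2, 4), (2, 5)]
print(brute_force_sum_of_distances(6, edges))  # [8, 12, 6, 10, 10, 10]
```

Each BFS is O(N) on a tree, and there are N of them, which is exactly the repeated work rerooting avoids.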
What is the Rerooting Technique?
The rerooting technique is an algorithmic method for tree data structures that efficiently computes a property of every node as if that node were the root. It works because moving the root from a node to an adjacent node reverses the parent-child relationship along that one edge only, so the value for the new root can often be derived from the value for the old root with a constant amount of work. It is used in problems involving dynamic programming on trees, path queries, and subtree properties.
Steps of the Rerooting Technique in Python:
Step 1: Representing the Tree
We’ll represent the tree using an adjacency list.
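from collections import defaultdict

class Tree:
    def __init__(self, n):
        self.n = n                       # number of nodes, labelled 0 .. n - 1
        self.graph = defaultdict(list)   # node -> list of adjacent nodes

    def add_edge(self, u, v):
        # The tree is undirected, so store the edge in both directions.
        self.graph[u].append(v)
        self.graph[v].append(u)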
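Because the nodes are numbered 0 to n - 1, a plain list of lists works as well as defaultdict(list). A minimal sketch of that alternative, using a hypothetical helper named build_adjacency, applied to the example tree used later in this article:

```python
def build_adjacency(n, edges):
    # One neighbour list per node; nodes are assumed to be labelled 0 .. n - 1.
    adj = [[] for _ in range(n)]
    for u, v in edges:
        # The tree is undirected, so record each edge in both directions.
        adj[u].append(v)
        adj[v].append(u)
    return adj

adj = build_adjacency(6, [(0, 1), (0, 2), (2, 3), (2, 4), (2, 5)])
print(adj)  # [[1, 2], [0], [0, 3, 4, 5], [2], [2], [2]]
```

Node 2 has four neighbours (its parent 0 and its children 3, 4, 5), while the leaves have one each.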
Step 2: Initial Calculation of Distances
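First, we root the tree at node 0 and run a depth-first search (DFS) that fills two arrays: subtree_size[v], the number of nodes in v’s subtree, and dp[v], the sum of distances from v to the nodes in that subtree. A child c contributes dp[c] + subtree_size[c] to its parent, because each of the subtree_size[c] nodes under c is exactly one edge farther from the parent than from c. After this pass, dp[0] already equals the answer for node 0, while the other entries only account for their own subtrees.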
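def dfs(node, parent, tree, dp, subtree_size):
    # Post-order DFS: when it returns, subtree_size[node] is the size of node's
    # subtree and dp[node] is the sum of distances from node to those nodes.
    for neighbor in tree.graph[node]:
        if neighbor == parent:
            continue
        dfs(neighbor, node, tree, dp, subtree_size)
        subtree_size[node] += subtree_size[neighbor]
        # Every node under neighbor is one edge farther from node than from neighbor.
        dp[node] += dp[neighbor] + subtree_size[neighbor]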
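To see what the first pass produces on the example tree, here is a sketch that computes the same dp and subtree_size arrays without recursion. It records a DFS visiting order with an explicit stack (each node is appended before any of its children), then walks that order backwards so every child is finished before its parent. The function name first_pass is illustrative.

```python
def first_pass(n, edges, root=0):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    parent = [-1] * n
    order = []              # DFS preorder: every parent appears before its children
    visited = [False] * n
    visited[root] = True
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        for neighbor in adj[node]:
            if not visited[neighbor]:
                visited[neighbor] = True
                parent[neighbor] = node
                stack.append(neighbor)

    dp = [0] * n            # dp[v]: sum of distances from v to nodes in its subtree
    subtree_size = [1] * n  # subtree_size[v]: number of nodes in v's subtree
    # Reversed preorder handles every child before its parent.
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            subtree_size[p] += subtree_size[node]
            dp[p] += dp[node] + subtree_size[node]
    return dp, subtree_size

print(first_pass(6, [(0, 1), (0, 2), (2, 3), (2, 4), (2, 5)]))
# ([8, 0, 3, 0, 0, 0], [6, 1, 4, 1, 1, 1])
```

Node 2 collects 1 from each of its three leaves (dp 3, size 4), and node 0 then gets (0 + 1) from node 1 plus (3 + 4) from node 2, giving 8.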
Step 3: Rerooting to Calculate Distances for All Nodes
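Next, a second DFS moves the root from node 0 to every other node, one edge at a time. To move the root from node to its child neighbor, the code first removes neighbor’s subtree from node’s values, then attaches the rest of the tree to neighbor as a new subtree. After these four updates, dp[neighbor] holds the sum of distances from neighbor to all n nodes, which the recursive call records in result. When that call returns, the updates are undone in reverse order so that node’s values are correct again for its next child.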
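The move from a node p to its child c also has a simple closed form. Write S(v) for the sum of distances from v, size(c) for the size of c’s subtree when the tree is rooted at node 0, and n for the number of nodes. Crossing the edge (p, c) brings the size(c) nodes in c’s subtree one step closer and pushes the remaining n - size(c) nodes one step farther away:

```latex
S(c) = S(p) - \operatorname{size}(c) + \bigl(n - \operatorname{size}(c)\bigr)
     = S(p) + n - 2\,\operatorname{size}(c)
```

On the example tree with edges (0, 1), (0, 2), (2, 3), (2, 4), (2, 5), where S(0) = 8, this gives S(2) = 8 + 6 - 2·4 = 6, S(1) = 8 + 6 - 2·1 = 12 and S(3) = 6 + 6 - 2·1 = 10.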
def reroot(node, parent, tree, dp, subtree_size, result):
    # On entry, dp[node] and subtree_size[node] describe the whole tree rooted at node.
    result[node] = dp[node]
    for neighbor in tree.graph[node]:
        if neighbor == parent:
            continue
        # Move root from node to neighbor:
        # detach neighbor's subtree from node ...
        dp[node] -= dp[neighbor] + subtree_size[neighbor]
        subtree_size[node] -= subtree_size[neighbor]
        # ... and hang the rest of the tree (rooted at node) under neighbor.
        dp[neighbor] += dp[node] + subtree_size[node]
        subtree_size[neighbor] += subtree_size[node]
        # Recursively reroot
        reroot(neighbor, node, tree, dp, subtree_size, result)
        # Restore original values so node is the root again for its next child
        subtree_size[neighbor] -= subtree_size[node]
        dp[neighbor] -= dp[node] + subtree_size[node]
        subtree_size[node] += subtree_size[neighbor]
        dp[node] += dp[neighbor] + subtree_size[neighbor]
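The update-and-undo steps are one way to move the root. An alternative sketch applies the identity S(c) = S(p) + n - 2·size(c) directly, where size(c) is c’s subtree size with node 0 as the root. Visiting nodes parents-first means nothing has to be restored. It assumes the first pass has already produced subtree_size, the root total dp[0], a parent array, and an order in which every parent precedes its children; the function name reroot_closed_form and these parameter names are illustrative.

```python
def reroot_closed_form(n, parent, order, subtree_size, root_total):
    # result[v] will hold the sum of distances from v to every other node.
    result = [0] * n
    for node in order:  # assumed: every parent appears before its children
        p = parent[node]
        if p == -1:
            # The root's answer is the total computed by the first pass.
            result[node] = root_total
        else:
            # subtree_size[node] nodes get one step closer, the other
            # n - subtree_size[node] nodes get one step farther away.
            result[node] = result[p] + n - 2 * subtree_size[node]
    return result

# First-pass values for the example tree rooted at node 0.
parent = [-1, 0, 0, 2, 2, 2]
order = [0, 1, 2, 3, 4, 5]
subtree_size = [6, 1, 4, 1, 1, 1]
print(reroot_closed_form(6, parent, order, subtree_size, 8))  # [8, 12, 6, 10, 10, 10]
```

Because each answer depends only on its parent’s answer, this pass is a single loop and needs no recursion.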
Step 4: Putting It All Together
Now we can combine everything into a function to compute the sum of distances from each node to all other nodes.
def sum_of_distances_in_tree(n, edges):
    tree = Tree(n)
    for u, v in edges:
        tree.add_edge(u, v)
    dp = [0] * n            # distance sums, filled in by dfs
    subtree_size = [1] * n  # every subtree contains at least its own root
    result = [0] * n
    # Initial DFS from node 0
    dfs(0, -1, tree, dp, subtree_size)
    # Rerooting to calculate the result for all nodes
    reroot(0, -1, tree, dp, subtree_size, result)
    return result

# Example usage
n = 6
edges = [(0, 1), (0, 2), (2, 3), (2, 4), (2, 5)]
print(sum_of_distances_in_tree(n, edges))  # [8, 12, 6, 10, 10, 10]
Complete Implementation of the Rerooting Technique in Python:
Below is the complete program, with the functions from the steps above in one file:
from collections import defaultdict

class Tree:
    def __init__(self, n):
        self.n = n
        self.graph = defaultdict(list)

    def add_edge(self, u, v):
        self.graph[u].append(v)
        self.graph[v].append(u)

def reroot(node, parent, tree, dp, subtree_size, result):
    result[node] = dp[node]
    for neighbor in tree.graph[node]:
        if neighbor == parent:
            continue
        # Move root from node to neighbor
        dp[node] -= dp[neighbor] + subtree_size[neighbor]
        subtree_size[node] -= subtree_size[neighbor]
        dp[neighbor] += dp[node] + subtree_size[node]
        subtree_size[neighbor] += subtree_size[node]
        # Recursively reroot
        reroot(neighbor, node, tree, dp, subtree_size, result)
        # Restore original values
        subtree_size[neighbor] -= subtree_size[node]
        dp[neighbor] -= dp[node] + subtree_size[node]
        subtree_size[node] += subtree_size[neighbor]
        dp[node] += dp[neighbor] + subtree_size[neighbor]

def dfs(node, parent, tree, dp, subtree_size):
    for neighbor in tree.graph[node]:
        if neighbor == parent:
            continue
        dfs(neighbor, node, tree, dp, subtree_size)
        subtree_size[node] += subtree_size[neighbor]
        dp[node] += dp[neighbor] + subtree_size[neighbor]

# Function to find the sum of distances from each node to all other nodes
def sum_of_distances_in_tree(n, edges):
    tree = Tree(n)
    for u, v in edges:
        tree.add_edge(u, v)
    dp = [0] * n
    subtree_size = [1] * n
    result = [0] * n
    # Initial DFS from node 0
    dfs(0, -1, tree, dp, subtree_size)
    # Rerooting to calculate the result for all nodes
    reroot(0, -1, tree, dp, subtree_size, result)
    return result

# Example usage
n = 6
edges = [(0, 1), (0, 2), (2, 3), (2, 4), (2, 5)]
print(sum_of_distances_in_tree(n, edges))  # [8, 12, 6, 10, 10, 10]
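Time Complexity: O(N), where N is the number of nodes in the tree. Each of the two DFS passes visits every node once and looks at every edge twice (once from each endpoint), doing O(1) work per step.
Auxiliary Space: O(N) for the adjacency list and the dp, subtree_size and result arrays, plus a recursion stack as deep as the height of the tree, which reaches N on a path-shaped tree.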
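Both recursive functions nest one call per level of the tree, so on a deep tree such as a long path they can exceed CPython’s default recursion limit of 1000 and raise RecursionError. Below is a sketch of the same two-pass idea written with explicit loops. It keeps O(N) time and space, but swaps the update-and-undo second pass for the closed form S(c) = S(p) + n - 2·size(c). The name sum_of_distances_iterative is hypothetical, and the sketch assumes nodes are labelled 0 to n - 1 with n >= 1.

```python
def sum_of_distances_iterative(n, edges):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Record a preorder from node 0 (parents before children) with an explicit stack.
    parent = [-1] * n
    order = []
    visited = [False] * n
    visited[0] = True
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for neighbor in adj[node]:
            if not visited[neighbor]:
                visited[neighbor] = True
                parent[neighbor] = node
                stack.append(neighbor)

    # Bottom-up pass: subtree sizes and distance sums inside each subtree.
    subtree_size = [1] * n
    dp = [0] * n
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            subtree_size[p] += subtree_size[node]
            dp[p] += dp[node] + subtree_size[node]

    # Top-down pass: move the root across one edge at a time.
    result = [0] * n
    result[0] = dp[0]
    for node in order[1:]:  # order[0] is the root, node 0
        result[node] = result[parent[node]] + n - 2 * subtree_size[node]
    return result

print(sum_of_distances_iterative(6, [(0, 1), (0, 2), (2, 3), (2, 4), (2, 5)]))
# [8, 12, 6, 10, 10, 10]

# A 100,000-node path, far deeper than the default recursion limit allows.
n = 100_000
path = [(i, i + 1) for i in range(n - 1)]
print(sum_of_distances_iterative(n, path)[0])  # n * (n - 1) // 2 = 4999950000
```

For the path, the endpoint’s answer is 1 + 2 + ... + (n - 1), which is why the printed value equals n(n - 1)/2.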