Persistent Segment Tree in Python

Persistent data structures keep every earlier version of a structure accessible after it is modified. One such structure is the persistent segment tree. A segment tree supports efficient queries and updates over intervals of an array; making it persistent adds the ability to query any historical version, which is useful in competitive programming and in applications that need rollback or point-in-time queries. This article explains the concept and walks through an implementation in Python.


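What is a Persistent Segment Tree?

Persistence in data structures refers to the ability to keep previous versions of the structure accessible after modifications. A common technique is path copying: instead of modifying nodes in place, an update copies only the nodes whose contents change and shares everything else with the previous version. In a segment tree, a point update changes only the nodes on the path from the root to one leaf, so each new version needs O(log n) new nodes rather than a full copy of the tree.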

Representation of Persistent Segment Tree:

Each node in a Persistent Segment Tree contains:

  • Value: The value stored at this node, which typically represents an aggregate (like sum, min, or max) over a segment of the array.
  • Left: A reference to the left child node.
  • Right: A reference to the right child node.
Node Structure
class Node:
    def __init__(self, value=0, left=None, right=None):

        # Initialize a new node
        self.value = value  # The value stored in this node
        self.left = left    # Reference to the left child node
        self.right = right  # Reference to the right child node

Persistent Segment Tree Operations

Persistent Segment Trees allow you to perform updates and queries efficiently while preserving the history of changes. Here are the primary operations you can perform on a Persistent Segment Tree:

  1. Build: Construct the initial segment tree from an array.
  2. Update: Create a new version of the tree with an updated value at a specific index.
  3. Query: Perform range queries on any version of the tree.

1. Building the Initial Segment Tree

The build function constructs the segment tree from an array, exactly as for a regular segment tree. The root it returns serves as the first version, from which all later versions are derived.

Algorithm:

  1. If the current segment is a single element (i.e., left == right), create a leaf node with the value of that element.
  2. Otherwise, split the segment into two halves and recursively build the left and right subtrees.
  3. Create a new node whose value is the sum (or other aggregate function) of the values of the left and right children.

Implementation:

Python
def build(arr, left, right):

    # Function to build the initial segment tree
    if left == right:

        # If the current segment is a single element, create a leaf node
        return Node(value=arr[left])

    # Calculate the mid-point of the current segment
    mid = (left + right) // 2

    # Recursively build the left and right subtrees
    left_child = build(arr, left, mid)
    right_child = build(arr, mid + 1, right)

    # Create a new node with the sum of values of left and right children
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)

2. Updating the Segment Tree Persistently

The update function creates a new version of the segment tree by copying only the nodes on the path from the root to the updated leaf. It returns a new root node for the updated version and never modifies existing nodes, so previous versions stay intact.

Algorithm:

  1. If the current segment is a single element (i.e., left == right), create a new node with the updated value.
  2. Otherwise, determine whether the index to be updated lies in the left or right subtree.
  3. Recursively update the relevant subtree while keeping the other subtree unchanged.
  4. Create a new node whose value is the sum (or other aggregate function) of the values of the updated left and right children.

Implementation:

Python
def update(prev_node, left, right, idx, new_value):

    # Function to perform an update operation on the segment tree persistently
    if left == right:

        # If the current segment is a single element, create a new node with the updated value
        return Node(value=new_value)

    # Calculate the mid-point of the current segment
    mid = (left + right) // 2

    # Determine whether the index to be updated lies in the left or right subtree
    if idx <= mid:

        # Recursively update the left subtree, keep the right subtree unchanged
        left_child = update(prev_node.left, left, mid, idx, new_value)
        right_child = prev_node.right
    else:

        # Recursively update the right subtree, keep the left subtree unchanged
        left_child = prev_node.left
        right_child = update(prev_node.right, mid + 1, right, idx, new_value)

    # Create a new node with the sum of values of the updated left and right children
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)
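
To make the path copying visible, the following sketch repeats compact versions of `Node`, `build`, and `update` so that it runs on its own, then uses `is` to check which nodes the two versions share. The short variable names `lc` and `rc` are choices made for this illustration.

```python
class Node:
    def __init__(self, value=0, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build(arr, left, right):
    if left == right:
        return Node(arr[left])
    mid = (left + right) // 2
    lc, rc = build(arr, left, mid), build(arr, mid + 1, right)
    return Node(lc.value + rc.value, lc, rc)


def update(prev_node, left, right, idx, new_value):
    if left == right:
        return Node(new_value)
    mid = (left + right) // 2
    if idx <= mid:
        lc, rc = update(prev_node.left, left, mid, idx, new_value), prev_node.right
    else:
        lc, rc = prev_node.left, update(prev_node.right, mid + 1, right, idx, new_value)
    return Node(lc.value + rc.value, lc, rc)


arr = [1, 2, 3, 4, 5]
root = build(arr, 0, 4)
new_root = update(root, 0, 4, 2, 10)  # set index 2 to 10

# Nodes on the path [0, 4] -> [0, 2] -> [2, 2] are fresh objects ...
print(new_root is root)                        # False
print(new_root.left is root.left)              # False
print(new_root.left.right is root.left.right)  # False
# ... while every subtree off that path is shared, not copied.
print(new_root.right is root.right)            # True (range [3, 4])
print(new_root.left.left is root.left.left)    # True (range [0, 1])
# The old version still reports the old total.
print(root.value, new_root.value)              # 15 22
```

Because shared nodes are never mutated, the sharing is safe: no later update can change what an older root sees.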

3. Querying the Segment Tree

The query function computes the aggregate over a range of indices. Because every version has its own root, passing a different root runs the same query against a different version of the tree.

Algorithm:

  1. If the query range does not overlap with the current segment, return 0 (or the identity value for the aggregate function).
  2. If the current segment is completely within the query range, return the value of the current node.
  3. Otherwise, split the current segment at its midpoint and recursively query the left and right subtrees.
  4. Return the sum (or other aggregate function) of the results of the left and right subtree queries.

Implementation:

Python
def query(node, left, right, query_left, query_right):
    # Function to perform a range query on the segment tree
    if query_left > right or query_right < left:
        # If the query range does not overlap with the current segment, return 0
        return 0
    
    if query_left <= left and right <= query_right:
        # If the current segment is completely within the query range, return the value of the current node
        return node.value
    
    # Calculate the mid-point of the current segment
    mid = (left + right) // 2
    
    # Recursively query the left and right subtrees and return the sum of results
    return query(node.left, left, mid, query_left, query_right) + query(node.right, mid + 1, right, query_left, query_right)
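
The sum used above can be swapped for another aggregate, provided the no-overlap case returns that aggregate's identity value. As a minimal sketch, assuming a range-minimum tree, the identity becomes `math.inf` and `+` becomes `min`. The function names `build_min`, `update_min`, and `query_min` are illustrative only.

```python
import math


class Node:
    def __init__(self, value=0, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build_min(arr, left, right):
    if left == right:
        return Node(arr[left])
    mid = (left + right) // 2
    lc, rc = build_min(arr, left, mid), build_min(arr, mid + 1, right)
    return Node(min(lc.value, rc.value), lc, rc)


def update_min(prev_node, left, right, idx, new_value):
    if left == right:
        return Node(new_value)
    mid = (left + right) // 2
    if idx <= mid:
        lc, rc = update_min(prev_node.left, left, mid, idx, new_value), prev_node.right
    else:
        lc, rc = prev_node.left, update_min(prev_node.right, mid + 1, right, idx, new_value)
    return Node(min(lc.value, rc.value), lc, rc)


def query_min(node, left, right, query_left, query_right):
    if query_left > right or query_right < left:
        return math.inf  # identity for min: never wins a comparison
    if query_left <= left and right <= query_right:
        return node.value
    mid = (left + right) // 2
    return min(query_min(node.left, left, mid, query_left, query_right),
               query_min(node.right, mid + 1, right, query_left, query_right))


arr = [5, 2, 7, 1, 9]
v0 = build_min(arr, 0, 4)
v1 = update_min(v0, 0, 4, 3, 8)   # index 3 changes from 1 to 8

print(query_min(v0, 0, 4, 2, 4))  # 1 (min of 7, 1, 9)
print(query_min(v1, 0, 4, 2, 4))  # 7 (min of 7, 8, 9)
```

Returning 0 in the no-overlap case would be wrong here, since 0 could be smaller than every real element in the range.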

Illustration of How a Persistent Segment Tree Works in Python:

Let's walk through an example step by step to illustrate how the persistent segment tree works:

Initial Array

  • arr = [1,2,3,4,5]

Building the Initial Segment Tree

  1. Build the leaf nodes:
    • Leaf node for arr[0]: value = 1
    • Leaf node for arr[1]: value = 2
    • Leaf node for arr[2]: value = 3
    • Leaf node for arr[3]: value = 4
    • Leaf node for arr[4]: value = 5
  2. Build the internal nodes:
    • Node for range [0, 1]: value = 1 + 2 = 3
    • Node for range [0, 2]: value = 3 + 3 = 6 (node [0, 1] plus leaf [2, 2])
    • Node for range [3, 4]: value = 4 + 5 = 9
    • Node for range [0, 4]: value = 6 + 9 = 15

Updating the Segment Tree

  1. Update arr[2] from 3 to 10:
    • Create a new leaf node for range [2, 2]: value = 10
    • Create a new node for range [0, 2]: value = 3 + 10 = 13 (reuses the existing node for [0, 1])
    • Create a new root for range [0, 4]: value = 13 + 9 = 22 (reuses the existing node for [3, 4])
    • The original root still has value 15, so both versions remain available.

Querying the Segment Tree

  1. Query range [1, 3] in the original tree:
    • The query range partially overlaps the nodes for [0, 4], [0, 2], [0, 1], and [3, 4], so the query recurses into each of them.
    • Leaves [1, 1], [2, 2], and [3, 3] lie fully inside the query range, giving sum = 2 + 3 + 4 = 9
  2. Query range [1, 3] in the updated tree:
    • The traversal visits the same ranges, starting from the new root.
    • Leaf [2, 2] now holds 10, giving sum = 2 + 10 + 4 = 16

This example demonstrates how the persistent segment tree maintains different versions efficiently and supports range queries on any version.

Implementation of Persistent segment trees in Python

Below is the complete implementation of a Persistent Segment Tree in Python:

Python
class Node:
    def __init__(self, value=0, left=None, right=None):

        # Initialize a new node
        self.value = value  # The value stored in this node
        self.left = left    # Reference to the left child node
        self.right = right  # Reference to the right child node


def build(arr, left, right):

    # Function to build the initial segment tree
    if left == right:

        # If the current segment is a single element, create a leaf node
        return Node(value=arr[left])

    # Calculate the mid-point of the current segment
    mid = (left + right) // 2

    # Recursively build the left and right subtrees
    left_child = build(arr, left, mid)
    right_child = build(arr, mid + 1, right)

    # Create a new node with the sum of values of left and right children
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def update(prev_node, left, right, idx, new_value):

    # Function to perform an update operation on the segment tree persistently
    if left == right:

        # If the current segment is a single element, create a new node with the updated value
        return Node(value=new_value)

    # Calculate the mid-point of the current segment
    mid = (left + right) // 2

    # Determine whether the index to be updated lies in the left or right subtree
    if idx <= mid:

        # Recursively update the left subtree, keep the right subtree unchanged
        left_child = update(prev_node.left, left, mid, idx, new_value)
        right_child = prev_node.right
    else:

        # Recursively update the right subtree, keep the left subtree unchanged
        left_child = prev_node.left
        right_child = update(prev_node.right, mid + 1, right, idx, new_value)

    # Create a new node with the sum of values of the updated left and right children
    return Node(value=left_child.value + right_child.value, left=left_child, right=right_child)


def query(node, left, right, query_left, query_right):

    # Function to perform a range query on the segment tree
    if query_left > right or query_right < left:
        # If the query range does not overlap with the current segment, return 0
        return 0

    if query_left <= left and right <= query_right:

        # If the current segment is completely within the query range, return the value of the current node
        return node.value

    # Calculate the mid-point of the current segment
    mid = (left + right) // 2

    # Recursively query the left and right subtrees and return the sum of results
    return query(node.left, left, mid, query_left, query_right) + query(node.right, mid + 1, right, query_left, query_right)


# Example usage
if __name__ == "__main__":

    # Initial array
    arr = [1, 2, 3, 4, 5]

    # Build the initial segment tree
    root = build(arr, 0, len(arr) - 1)

    # Create a new version with an update (change the value at index 2 to 10)
    new_root = update(root, 0, len(arr) - 1, 2, 10)

    # Query the original and new versions
    print(query(root, 0, len(arr) - 1, 1, 3))      # Output: 9 (2+3+4)
    print(query(new_root, 0, len(arr) - 1, 1, 3))  # Output: 16 (2+10+4)

Output
9
16
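
The usage above keeps `root` and `new_root` in separate variables. With many updates, one option is to store every root in a list, so that a version number is simply an index into that list. The sketch below assumes a sum aggregate; the class name `VersionedSegmentTree` and its methods are hypothetical and not part of any library.

```python
class Node:
    def __init__(self, value=0, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


class VersionedSegmentTree:
    """Hypothetical wrapper that stores one root per version."""

    def __init__(self, arr):
        if not arr:
            raise ValueError("arr must be non-empty")
        self.n = len(arr)
        self.roots = [self._build(arr, 0, self.n - 1)]  # version 0

    def _build(self, arr, lo, hi):
        if lo == hi:
            return Node(arr[lo])
        mid = (lo + hi) // 2
        left = self._build(arr, lo, mid)
        right = self._build(arr, mid + 1, hi)
        return Node(left.value + right.value, left, right)

    def _update(self, node, lo, hi, idx, value):
        if lo == hi:
            return Node(value)
        mid = (lo + hi) // 2
        if idx <= mid:
            left, right = self._update(node.left, lo, mid, idx, value), node.right
        else:
            left, right = node.left, self._update(node.right, mid + 1, hi, idx, value)
        return Node(left.value + right.value, left, right)

    def _query(self, node, lo, hi, ql, qr):
        if qr < lo or hi < ql:
            return 0
        if ql <= lo and hi <= qr:
            return node.value
        mid = (lo + hi) // 2
        return (self._query(node.left, lo, mid, ql, qr)
                + self._query(node.right, mid + 1, hi, ql, qr))

    def update(self, version, idx, value):
        """Derive a new version from `version` and return its number."""
        if not 0 <= idx < self.n:
            raise IndexError("idx out of range")
        new_root = self._update(self.roots[version], 0, self.n - 1, idx, value)
        self.roots.append(new_root)
        return len(self.roots) - 1

    def query(self, version, ql, qr):
        """Sum of indices ql..qr (inclusive) as seen by `version`."""
        return self._query(self.roots[version], 0, self.n - 1, ql, qr)


tree = VersionedSegmentTree([1, 2, 3, 4, 5])
v1 = tree.update(0, 2, 10)  # version 1: index 2 set to 10
v2 = tree.update(0, 4, 0)   # version 2 branches from version 0, not from v1

print(tree.query(0, 1, 3))   # 9
print(tree.query(v1, 1, 3))  # 16
print(tree.query(v2, 0, 4))  # 10
```

Since `update` takes the base version as an argument, versions can branch from any earlier state, and rolling back amounts to querying an older version number.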

Complexity Analysis of Persistent Segment Tree:

Operation | Time Complexity | Extra Space | Description
Build | O(n) | O(n) | Construct the initial segment tree (2n - 1 nodes) from an array of size n
Update | O(log n) | O(log n) | Create a new version of the tree with an updated value at a specific index
Query | O(log n) | O(1) new nodes, O(log n) recursion stack | Perform a range query on any version of the tree
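
The node counts behind these bounds can be checked empirically. The sketch below is an illustration rather than a benchmark: it adds a class-level counter to `Node` (the `created` attribute is an assumption of this example, not part of the `Node` class used earlier) and compares the nodes allocated by the build and by one update against the expected figures.

```python
import math


class Node:
    created = 0  # hypothetical counter of nodes allocated so far

    def __init__(self, value=0, left=None, right=None):
        Node.created += 1
        self.value = value
        self.left = left
        self.right = right


def build(arr, left, right):
    if left == right:
        return Node(arr[left])
    mid = (left + right) // 2
    lc, rc = build(arr, left, mid), build(arr, mid + 1, right)
    return Node(lc.value + rc.value, lc, rc)


def update(prev_node, left, right, idx, new_value):
    if left == right:
        return Node(new_value)
    mid = (left + right) // 2
    if idx <= mid:
        lc, rc = update(prev_node.left, left, mid, idx, new_value), prev_node.right
    else:
        lc, rc = prev_node.left, update(prev_node.right, mid + 1, right, idx, new_value)
    return Node(lc.value + rc.value, lc, rc)


n = 1024
root = build([1] * n, 0, n - 1)
nodes_after_build = Node.created  # a tree with n leaves has 2n - 1 nodes

new_root = update(root, 0, n - 1, 500, 7)
nodes_per_update = Node.created - nodes_after_build

print(nodes_after_build)           # 2047
print(nodes_per_update)            # 11
print(int(math.log2(n)) + 1)       # 11, one node per level of the tree
```

With n = 1024, a full copy per version would cost 2047 nodes, while path copying allocates only 11.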

