Monday, March 28, 2011

Immutable Binary Trees

I have been reading a fair bit about functional programming recently, and about the advantages that immutable data types (or persistent data structures) offer in a concurrent execution environment.

While it's not too hard to wrap your head around an immutable list structure with an efficient prepend operation, I wanted to try implementing a slightly more complicated immutable structure (but not as complicated as a trie). So, I decided to try implementing an immutable binary tree in Java, and then work on improving it a bit.

Note that this is not a "good" implementation. In particular, I'm not going to bother rebalancing the tree, so operations can degrade to O(n) instead of the O(log n) that a balanced tree would provide. For example, inserting elements in sorted order turns the tree into what is effectively a linked list.

The idea I had the most trouble grasping when learning about immutable data structures is that you usually don't need to copy everything when adding or removing an element. So, let's look at an example binary tree in pictures and consider which nodes need to be replaced and which can be reused when we add a new element to the tree.

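Consider the following tree, rooted at 8. The root's left child is 4 and its right child is 13. The children of 13 are 10 (on the left, a leaf) and 16 (on the right, with children of its own):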



The rules are simple: at each node, values less than the current node's value are stored in the left subtree, while values greater are stored in the right subtree.

Now, suppose we want to add the number 11 to the tree. It belongs as the right child of the 10 node. The existing 10 node has no children, but it is itself an immutable tree, so we can't attach 11 to it. Instead, we must allocate a new 10 node with 11 as its right child. The new 10 node will be the left child of a new 13 node, which can reuse the old 16 node (along with all of 16's children) as its right child. Finally, a new 8 node must be allocated with its right child pointing to the new 13 node and its left child pointing to the existing 4 node.

Here is a picture with the newly allocated nodes highlighted in red:



In total, four new nodes were allocated: one for each level of the tree along the path to the new element. For a balanced tree, this generalizes to O(log n) node allocations per insertion, which matches the O(log n) asymptotic cost of inserting into a balanced mutable binary tree. The rest of the tree continues to point to the old nodes. Assuming we aren't holding onto a reference to the old tree root, the original 8, 13, and 10 nodes are now eligible for garbage collection. (And if another thread were still reading the old version of the tree, it would not be affected by this change, which avoids a potential race condition.)
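To make the sharing concrete, here is a minimal sketch built around a hypothetical `Node` class. It is an int-only stand-in, not the class developed in this post, and its `insert` copies only the nodes on the search path. The sketch rebuilds just the nodes named above. The children of 16 can't be recovered here, so 16 is treated as a leaf.

```java
/** Hypothetical int-only immutable BST node, used only to illustrate path copying. */
final class Node {
    final int value;
    final Node left;
    final Node right;

    Node(int value, Node left, Node right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }

    /** Returns a tree that also contains v; only nodes on the search path are copied. */
    static Node insert(Node n, int v) {
        if (n == null) {
            return new Node(v, null, null);                 // the new leaf
        }
        if (v < n.value) {
            // Copy this node with a new left child; the right subtree is shared as-is.
            return new Node(n.value, insert(n.left, v), n.right);
        }
        if (v > n.value) {
            // Copy this node with a new right child; the left subtree is shared as-is.
            return new Node(n.value, n.left, insert(n.right, v));
        }
        return n;                                           // v already stored here
    }
}
```

Calling `Node.insert(root, 11)` on `new Node(8, new Node(4, null, null), new Node(13, new Node(10, null, null), new Node(16, null, null)))` returns a new root. Its left child is the same `4` object as before, and its new `13` node points at the same `16` object. The old root still has no 11 anywhere beneath it.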

Now, let's take a look at how to implement a simple immutable binary tree in Java.

To start with, I'll create a simple ImmutableSet interface. It doesn't follow the java.util.Set interface, because that interface is fundamentally based on mutability. For example, Set.add modifies the set in place and returns a boolean indicating whether the set changed, rather than returning a new set:
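Here is a minimal sketch of such an interface. It assumes elements are `Comparable` and uses the method names mentioned in this post (`add`, `remove`, `contains`, `toList`). The exact signatures are my assumption.

```java
import java.util.List;

/**
 * A set whose operations never modify it: add and remove return a new set
 * (possibly sharing structure with this one) and leave the receiver unchanged.
 */
interface ImmutableSet<T extends Comparable<T>> {
    /** Returns a set containing every element of this set plus element. */
    ImmutableSet<T> add(T element);

    /**
     * Returns a set without element. In a design with no empty-set value,
     * removing the last element has to return null.
     */
    ImmutableSet<T> remove(T element);

    /** Returns true if element is in this set. */
    boolean contains(T element);

    /** Returns the elements in ascending order (a testing convenience). */
    List<T> toList();
}
```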



I've included a toList method for testing purposes, since it was easier than implementing the Iterable interface.

The fields and constructor of the ImmutableBinaryTree implementation are as follows:
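Here is a sketch of those fields and constructors. It assumes a null child means "no subtree", which matches the problems listed at the end of this post. The `implements ImmutableSet<T>` clause is left off so the fragment compiles on its own. The fields are package-private rather than private only so that other classes in the same package can inspect the structure.

```java
import java.util.Objects;

/** Immutable BST node; the class is final so a subclass cannot add mutable state. */
final class ImmutableBinaryTree<T extends Comparable<T>> {
    final T value;                        // the element stored at this node
    final ImmutableBinaryTree<T> left;    // subtree of smaller values, or null
    final ImmutableBinaryTree<T> right;   // subtree of larger values, or null

    /** Creates a leaf node holding value. */
    ImmutableBinaryTree(T value) {
        this(value, null, null);
    }

    /** Creates a node with the given children, either of which may be null. */
    ImmutableBinaryTree(T value, ImmutableBinaryTree<T> left, ImmutableBinaryTree<T> right) {
        this.value = Objects.requireNonNull(value, "value");
        this.left = left;
        this.right = right;
    }
}
```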


All fields must be final in an immutable data structure: once a node is constructed, neither its value nor its children can change.

Here is the implementation of the add method:
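Here is a minimal sketch of `add`, built on a private `addSubtree` helper. The helper's name comes from the text that follows, but its body is my reconstruction. It grafts the new subtree where its root value belongs and copies only the nodes along the way. The class is trimmed to what `add` needs, and its fields are package-private so callers in the same package can inspect the result.

```java
final class ImmutableBinaryTree<T extends Comparable<T>> {
    final T value;
    final ImmutableBinaryTree<T> left;    // smaller values, or null
    final ImmutableBinaryTree<T> right;   // larger values, or null

    ImmutableBinaryTree(T value) {
        this(value, null, null);
    }

    ImmutableBinaryTree(T value, ImmutableBinaryTree<T> left, ImmutableBinaryTree<T> right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }

    /** Returns a tree that also contains element; this tree is left unchanged. */
    public ImmutableBinaryTree<T> add(T element) {
        return addSubtree(new ImmutableBinaryTree<>(element));
    }

    /**
     * Grafts subtree below this node at the position chosen by its root value.
     * This keeps the tree ordered only if every value in subtree belongs in
     * that position, which is always true for a single leaf.
     */
    private ImmutableBinaryTree<T> addSubtree(ImmutableBinaryTree<T> subtree) {
        int cmp = subtree.value.compareTo(value);
        if (cmp < 0) {
            // Goes left: copy this node with a new left child and share the right one.
            ImmutableBinaryTree<T> newLeft = (left == null) ? subtree : left.addSubtree(subtree);
            return new ImmutableBinaryTree<>(value, newLeft, right);
        }
        if (cmp > 0) {
            // Goes right: copy this node with a new right child and share the left one.
            ImmutableBinaryTree<T> newRight = (right == null) ? subtree : right.addSubtree(subtree);
            return new ImmutableBinaryTree<>(value, left, newRight);
        }
        return this;  // value already present: nothing to add at this node
    }
}
```

For example, `new ImmutableBinaryTree<>(8).add(4).add(13)` builds a three-node tree. Adding 10 to that tree returns a new root whose `left` field is the same 4 node, while the earlier tree is left untouched.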


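We create a leaf node for the new element and add it as a subtree of the existing tree. The addSubtree method is reused in the implementation of remove below. It could also be a starting point for a public addAll method that merges two ImmutableBinaryTrees, with one caveat. If addSubtree places a whole subtree according to its root value (as its use in remove suggests), the result stays ordered only when every value in that subtree belongs at the chosen position. That holds for a single new leaf, and for the right subtree of a removed node grafted into its left subtree, but not for two arbitrary trees. Those would need to be merged element by element.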
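The remove method can be sketched as follows, reusing an `addSubtree` helper for the merge step. The method bodies are my reconstruction from the description in this post, not the original code. As written, it still copies every node on the search path even when the element is absent.

```java
final class ImmutableBinaryTree<T extends Comparable<T>> {
    final T value;
    final ImmutableBinaryTree<T> left;    // smaller values, or null
    final ImmutableBinaryTree<T> right;   // larger values, or null

    ImmutableBinaryTree(T value, ImmutableBinaryTree<T> left, ImmutableBinaryTree<T> right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }

    /** Returns a tree without element, or null if that would leave no nodes. */
    public ImmutableBinaryTree<T> remove(T element) {
        int cmp = element.compareTo(value);
        if (cmp < 0) {
            if (left == null) {
                return this;  // base case: element is not in the tree
            }
            return new ImmutableBinaryTree<>(value, left.remove(element), right);
        }
        if (cmp > 0) {
            if (right == null) {
                return this;  // base case: element is not in the tree
            }
            return new ImmutableBinaryTree<>(value, left, right.remove(element));
        }
        // Base case: this node holds element, so merge its two subtrees.
        if (left == null) {
            return right;     // null when this node was a leaf
        }
        return (right == null) ? left : left.addSubtree(right);
    }

    /** Grafts subtree where its root value belongs (all its values exceed ours here). */
    private ImmutableBinaryTree<T> addSubtree(ImmutableBinaryTree<T> subtree) {
        int cmp = subtree.value.compareTo(value);
        if (cmp < 0) {
            return new ImmutableBinaryTree<>(value, left == null ? subtree : left.addSubtree(subtree), right);
        }
        if (cmp > 0) {
            return new ImmutableBinaryTree<>(value, left, right == null ? subtree : right.addSubtree(subtree));
        }
        return this;
    }
}
```

Removing 13 from a tree with root 8, children 4 and 13, and 13's children 10 and 16 grafts the existing 16 node as the right child of a new 10 node. The 4 node is shared with the original tree.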



The remove method has two base cases (where we don't recurse into a child branch). If the element is not in the tree, we return this unmodified. If the current node holds the element to be removed, we merge its left and right subtrees, arbitrarily choosing to add the right subtree as a descendant of the left. If the left subtree is null, we simply return the right subtree.

The contains method is effectively the same as it would be in a mutable binary tree:
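Here is a sketch of `contains`, written iteratively (a recursive version works just as well). As in a mutable tree, each comparison selects one branch, and no allocation is needed. The fields and constructor are included so the fragment compiles on its own.

```java
final class ImmutableBinaryTree<T extends Comparable<T>> {
    final T value;
    final ImmutableBinaryTree<T> left;    // smaller values, or null
    final ImmutableBinaryTree<T> right;   // larger values, or null

    ImmutableBinaryTree(T value, ImmutableBinaryTree<T> left, ImmutableBinaryTree<T> right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }

    /** Walks down one branch per level, exactly as in a mutable BST. */
    public boolean contains(T element) {
        ImmutableBinaryTree<T> node = this;
        while (node != null) {
            int cmp = element.compareTo(node.value);
            if (cmp == 0) {
                return true;
            }
            node = (cmp < 0) ? node.left : node.right;
        }
        return false;
    }
}
```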


The toList method does an in-order traversal of the tree to return an ordered list:
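Here is a sketch of that traversal. It uses a hypothetical private helper, `appendInOrder`, which accumulates into a single `ArrayList` instead of concatenating lists at each level.

```java
import java.util.ArrayList;
import java.util.List;

final class ImmutableBinaryTree<T extends Comparable<T>> {
    final T value;
    final ImmutableBinaryTree<T> left;    // smaller values, or null
    final ImmutableBinaryTree<T> right;   // larger values, or null

    ImmutableBinaryTree(T value, ImmutableBinaryTree<T> left, ImmutableBinaryTree<T> right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }

    /** Returns the elements in ascending order via an in-order traversal. */
    public List<T> toList() {
        List<T> result = new ArrayList<>();
        appendInOrder(result);
        return result;
    }

    /** Visits the left subtree, then this node, then the right subtree. */
    private void appendInOrder(List<T> out) {
        if (left != null) {
            left.appendInOrder(out);
        }
        out.add(value);
        if (right != null) {
            right.appendInOrder(out);
        }
    }
}
```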


For debugging purposes, and to visualize the tree depth, I decided to override toString:
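The exact output format can't be recovered from this post, so here is one possible sketch. It is a pre-order dump that prints each value on its own line, indents it two spaces per level of depth, and tags it `L:` or `R:` so the side of a lone child is visible. The helper `appendIndented` is hypothetical.

```java
final class ImmutableBinaryTree<T extends Comparable<T>> {
    final T value;
    final ImmutableBinaryTree<T> left;    // smaller values, or null
    final ImmutableBinaryTree<T> right;   // larger values, or null

    ImmutableBinaryTree(T value, ImmutableBinaryTree<T> left, ImmutableBinaryTree<T> right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        appendIndented(sb, 0, "");
        return sb.toString();
    }

    /** Pre-order: indent by depth, print the side label and value, then the children. */
    private void appendIndented(StringBuilder sb, int depth, String label) {
        for (int i = 0; i < depth; i++) {
            sb.append("  ");
        }
        sb.append(label).append(value).append('\n');
        if (left != null) {
            left.appendIndented(sb, depth + 1, "L: ");
        }
        if (right != null) {
            right.appendIndented(sb, depth + 1, "R: ");
        }
    }
}
```

For a tree with root 8, children 4 and 13, and 13's children 10 and 16, `toString()` produces:

```
8
  L: 4
  R: 13
    L: 10
    R: 16
```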


Finally, here is a runner I used to check the correctness of my add and remove implementations against the standard mutable TreeSet from the Java standard library:
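Here is a self-contained sketch of such a runner. The class name `ImmutableTreeRunner` comes from this post. Its body, the helper `run`, and the compact tree class bundled with it (null children, no empty tree) are my reconstruction. The runner applies the same random adds and removes to both sets and compares their sorted contents after every step. It skips any removal that would empty the tree, because this design has no empty-tree value.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;

/** Compact immutable BST: null means "no child", and there is no empty tree. */
final class ImmutableBinaryTree<T extends Comparable<T>> {
    final T value;
    final ImmutableBinaryTree<T> left, right;

    ImmutableBinaryTree(T value, ImmutableBinaryTree<T> left, ImmutableBinaryTree<T> right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }

    ImmutableBinaryTree<T> add(T element) {
        return addSubtree(new ImmutableBinaryTree<>(element, null, null));
    }

    /** Returns null only if element was the sole value in the tree. */
    ImmutableBinaryTree<T> remove(T element) {
        int cmp = element.compareTo(value);
        if (cmp < 0) {
            return left == null ? this : new ImmutableBinaryTree<>(value, left.remove(element), right);
        }
        if (cmp > 0) {
            return right == null ? this : new ImmutableBinaryTree<>(value, left, right.remove(element));
        }
        if (left == null) {
            return right;
        }
        return right == null ? left : left.addSubtree(right);
    }

    private ImmutableBinaryTree<T> addSubtree(ImmutableBinaryTree<T> s) {
        int cmp = s.value.compareTo(value);
        if (cmp < 0) {
            return new ImmutableBinaryTree<>(value, left == null ? s : left.addSubtree(s), right);
        }
        if (cmp > 0) {
            return new ImmutableBinaryTree<>(value, left, right == null ? s : right.addSubtree(s));
        }
        return this;
    }

    List<T> toList() {
        List<T> out = new ArrayList<>();
        if (left != null) out.addAll(left.toList());
        out.add(value);
        if (right != null) out.addAll(right.toList());
        return out;
    }
}

/** Applies identical random operations to both sets and checks they always agree. */
final class ImmutableTreeRunner {
    static boolean run(long seed, int operations, int maxValue) {
        Random random = new Random(seed);
        int first = random.nextInt(maxValue);   // no empty tree, so start from one value
        ImmutableBinaryTree<Integer> tree = new ImmutableBinaryTree<>(first, null, null);
        TreeSet<Integer> expected = new TreeSet<>();
        expected.add(first);
        for (int i = 0; i < operations; i++) {
            int v = random.nextInt(maxValue);
            if (random.nextBoolean()) {
                tree = tree.add(v);
                expected.add(v);
            } else if (expected.size() > 1 || !expected.contains(v)) {
                // Skip removing the last remaining value: the tree would become null.
                tree = tree.remove(v);
                expected.remove(v);
            }
            if (!tree.toList().equals(new ArrayList<>(expected))) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(run(42L, 10_000, 100) ? "OK" : "MISMATCH");
    }
}
```

Running `main` should print `OK`. A `MISMATCH` would point to a bug in `add` or `remove`.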


There are several problems with this implementation that I plan to address in a future post:

  • The remove method creates a new node at each level, even if the element to remove does not exist.
  • The use of null to indicate the lack of a left or right child exposes us to potential NullPointerExceptions and wastes space on those fields. (In particular, in a balanced binary tree, about half the nodes are leaves, whose two child fields are always null.)
  • Currently, there is no way to create an empty tree, since every node has a value. If you look at the ImmutableTreeRunner implementation above, I had to generate my first random value before I could create the ImmutableBinaryTree.
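Here is one possible fix for the first problem. It is a sketch, not necessarily the approach a follow-up post would take. If a recursive call hands back the very same child object, nothing beneath it changed, so the node can return `this` instead of allocating a copy. The class is trimmed to the fields, a constructor, and `remove` with its `addSubtree` merge helper.

```java
final class ImmutableBinaryTree<T extends Comparable<T>> {
    final T value;
    final ImmutableBinaryTree<T> left;    // smaller values, or null
    final ImmutableBinaryTree<T> right;   // larger values, or null

    ImmutableBinaryTree(T value, ImmutableBinaryTree<T> left, ImmutableBinaryTree<T> right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }

    /** Returns this itself, with no copies, when element is not in the tree. */
    public ImmutableBinaryTree<T> remove(T element) {
        int cmp = element.compareTo(value);
        if (cmp < 0) {
            if (left == null) {
                return this;
            }
            ImmutableBinaryTree<T> newLeft = left.remove(element);
            // The same child object back means nothing was removed below: reuse this node.
            return (newLeft == left) ? this : new ImmutableBinaryTree<>(value, newLeft, right);
        }
        if (cmp > 0) {
            if (right == null) {
                return this;
            }
            ImmutableBinaryTree<T> newRight = right.remove(element);
            return (newRight == right) ? this : new ImmutableBinaryTree<>(value, left, newRight);
        }
        // This node holds element: merge its subtrees (null if both are missing).
        if (left == null) {
            return right;
        }
        return (right == null) ? left : left.addSubtree(right);
    }

    /** Grafts subtree where its root value belongs (all its values exceed ours here). */
    private ImmutableBinaryTree<T> addSubtree(ImmutableBinaryTree<T> subtree) {
        int cmp = subtree.value.compareTo(value);
        if (cmp < 0) {
            return new ImmutableBinaryTree<>(value, left == null ? subtree : left.addSubtree(subtree), right);
        }
        if (cmp > 0) {
            return new ImmutableBinaryTree<>(value, left, right == null ? subtree : right.addSubtree(subtree));
        }
        return this;
    }
}
```

With this change, removing a value that isn't in the tree returns the original root itself, so no nodes are allocated at all.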