Saturday, April 21, 2012

Intro to Functional Programming (with examples in Java) - Part 4

Immutable data structures

In the past few posts, we've seen some interesting things that can be done with first-class and higher-order functions. However, I've glossed over the fact that functional programming also tends to rely on parameters being immutable. For example, in the previous post, memoization relies on an underlying HashMap, where the keys are function arguments. We assumed that the arguments implemented reasonable hashCode and equals methods, but things fall apart if those function arguments can be modified after they've been memoized. (In particular, their hashCode method should return a new value, but the existing memoization HashMap has already placed them in a bucket based on their previous hashCode.) Since functional programming dictates that the same function called with the same arguments should always return the same value, one of the easiest ways to guarantee that is to ensure that inputs and outputs are immutable.
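To make the hashCode problem concrete, here is a minimal sketch that uses a mutable java.util.ArrayList as a HashMap key. It is not the memoization code from the previous post; the class and method names are made up for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical demo class; the names here are illustrative only.
class MutableKeyDemo {
    // Builds a one-entry cache keyed on a mutable list, then mutates the key.
    static Map<List<Integer>, String> cacheWithMutatedKey(List<Integer> key) {
        Map<List<Integer>, String> cache = new HashMap<List<Integer>, String>();
        key.add(1);
        cache.put(key, "memoized result for [1]");
        // The map filed the entry under the hash of [1]. After this call the
        // key is [1, 2], which hashes differently.
        key.add(2);
        return cache;
    }

    public static void main(String[] args) {
        List<Integer> key = new ArrayList<Integer>();
        Map<List<Integer>, String> cache = cacheWithMutatedKey(key);
        // Lookup with the mutated key: its hash no longer matches the entry.
        System.out.println(cache.containsKey(key));              // false
        // Lookup with a fresh [1]: the hash matches, but equals() fails
        // because the stored key is now [1, 2].
        System.out.println(cache.containsKey(Arrays.asList(1))); // false
        // The entry is still in the map, just unreachable by key.
        System.out.println(cache.size());                        // 1
    }
}
```

With an immutable key type, the second add() call would simply not be possible, so the cached entry could never be stranded like this.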

Immutability also makes it easier to guarantee thread safety. If a given value is immutable, then any thread can access it, and read from it, without fearing that another thread may modify it (since it cannot be modified).

In some of my first few posts, I described an implementation of immutable binary trees. In this post, we're going to look at some simpler immutable data structures, namely ImmutableLists and Options.

Immutable Lists

For lists, we're going to implement a singly linked list structure that supports only two "modification" operations, prepend and tail. These add a new element at the head of the list, and return the list without the head element, respectively. These are both O(1) operations, since prepend simply creates a new head node whose tail is the previous list, while tail returns the list pointed to by the tail of the head node. Neither operation modifies the existing list; each returns a "new" list reference instead (where prepend creates a new list reference, and tail returns a list reference that already exists). To extract the actual elements from the list, we'll have a single operation, head, that returns the element at the front of the list. Readers familiar with Lisp should recognize head and tail as car and cdr. For convenient list traversal, we'll use the familiar Java Iterable interface, so we can use the friendly Java 1.5 "foreach" syntax. Also, for convenience, we'll define an isEmpty method.

As with the immutable binary trees, we're going to take advantage of immutability to assert that there is only one empty list (regardless of type arguments), and create a specialized singleton subclass for it, called EmptyList. Everything else will be a NonEmptyList. Here is the code:

import java.util.Iterator;
import java.util.NoSuchElementException;
public abstract class ImmutableList<T> implements Iterable<T> {
public abstract T head();
public abstract ImmutableList<T> tail();
public abstract boolean isEmpty();
public ImmutableList<T> prepend(T element) {
return new NonEmptyList<T>(element, this);
}
@Override
public Iterator<T> iterator() {
return new Iterator<T>() {
private ImmutableList<T> list = ImmutableList.this;
@Override
public boolean hasNext() {
return !list.isEmpty();
}
@Override
public T next() {
T element = list.head();
list = list.tail();
return element;
}
@Override
public void remove() {
// Java iterators must implement remove(), even though they
// make no sense for immutable collections. Fortunately,
// it is specified as an optional operation.
throw new UnsupportedOperationException("Cannot remove from immutable list");
}
};
}
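// The single raw EmptyList instance is shared by all element types, so
// the conversion to ImmutableList<T> is unchecked (but safe, since the
// empty list never produces an element).
@SuppressWarnings({"unchecked"})
public static <T> ImmutableList<T> nil() {
return EMPTY_LIST;
}
private static final EmptyList EMPTY_LIST = new EmptyList();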
private static class EmptyList extends ImmutableList {
@Override
public Object head() {
throw new NoSuchElementException("head() called on empty list");
}
@Override
public ImmutableList tail() {
throw new NoSuchElementException("tail() called on empty list");
}
@Override
public boolean isEmpty() {
return true;
}
}
private static class NonEmptyList<T> extends ImmutableList<T> {
private final T element;
private final ImmutableList<T> tail;
private NonEmptyList(final T element, final ImmutableList<T> tail) {
this.element = element;
this.tail = tail;
}
@Override
public T head() {
return element;
}
@Override
public ImmutableList<T> tail() {
return tail;
}
@Override
public boolean isEmpty() {
return false;
}
}
}

That's not terribly complicated, but also a little unpleasant to use. In particular, the only way we have of creating a list is via the nil() factory method, which returns an empty list, followed by a series of prepend() operations. Let's add a factory method that takes a varargs parameter to produce a list. Since we only have the prepend operation at our disposal, we'll need to iterate through the varargs array backwards:

public abstract class ImmutableList<T> implements Iterable<T> {
/* ... previous method and static class definitions ... */
public static <T> ImmutableList<T> list(T... elements) {
ImmutableList<T> output = nil();
for (int i = elements.length - 1; i >= 0 ; i--) {
output = output.prepend(elements[i]);
}
return output;
}
}

Let's create a quick test to confirm that the list factory method and list iteration actually work:

import org.junit.Test;
import static collections.ImmutableList.list;
import static org.junit.Assert.assertEquals;
public class ImmutableListTest {
@Test
public void testListIteration() {
ImmutableList<Integer> list = list(1, 2, 3);
int i = 1;
for (Integer l : list) {
assertEquals(i++, l.intValue());
}
// Make sure the loop body actually ran for all three elements
assertEquals(4, i);
}
}

Now, since it's a frequently used operation, let's add a size method. The size of a list can be defined recursively, as follows: the size of the empty list is 0, the size of a non-empty list is the size of its tail plus 1. As a first try, we'll define size exactly that way:

public abstract class ImmutableList<T> implements Iterable<T> {
/* ... previous method and static class definitions ... */
public final int size() {
if (isEmpty()) {
return 0;
}
return 1 + tail().size();
}
}

That's nice and simple. Let's create a unit test for it:

/* ... previous imports ... */
public class ImmutableListTest {
/* ... previous test definitions ... */
// This tests the size() method using a relatively large
// list of 100k Bytes. A recursive implementation of
// size() won't be able to handle that many recursive
// calls on the stack.
@Test
public void testSize() throws Exception {
int size = 100000;
Byte[] bytes = new Byte[size];
for (int i = 0; i < size; i++) {
bytes[i] = (byte) (i % 256);
}
ImmutableList<Byte> byteList = list(bytes);
assertEquals(size, byteList.size());
}
}

Uh-oh -- at least on my machine, that test triggered a StackOverflowError. If it doesn't trigger a StackOverflowError on your machine (which should be unlikely -- with default settings the Java stack typically overflows after some tens of thousands of calls, depending on the thread stack size (-Xss) and the size of each frame), try increasing the value of the size local variable in the test. The problem is that our recursive definition is triggering 100k nested calls to size(). But "Michael," I hear you say, "aren't recursive calls a fundamental keystone of functional programming?" Well, yes, but even in Scheme, Haskell, or Erlang, I believe this particular implementation would blow the stack. For those languages, the problem is that our method is not tail-recursive. A tail-recursive method/function is one where any recursive call is the last operation that executes. Since we add 1 to the result of the recursive call to size, that addition (not the recursive call) is the last operation to execute. Many functional languages implement what's called tail-call optimization, which basically turns tail-calls into a goto that jumps back to the beginning of the method. To clarify, let's first implement a tail-recursive version of size:

public abstract class ImmutableList<T> implements Iterable<T> {
/* ... previous method and static class definitions ... */
public final int size() {
return sizeHelper(this, 0);
}
// This method is tail-recursive. When it calls itself, it is the last
// operation before returning.
// Unfortunately, it still overflows the stack in Java.
private static <T> int sizeHelper(ImmutableList<T> input, int output) {
if (input.isEmpty()) {
return output;
}
return sizeHelper(input.tail(), output + 1);
}
}

In this case, we've implemented a helper method that carries a so-called accumulator, which keeps track of the size counted so far, through the recursive calls. The recursive call to sizeHelper is effectively just modifying the input parameters and jumping back to the beginning of the method (which is why tail-call optimization is so easy to implement). Unfortunately, Java compilers do not traditionally implement tail-call optimization, so this code will still cause a StackOverflowError on our unit test. Instead, we can simulate tail-call optimization by performing it by hand:

public abstract class ImmutableList<T> implements Iterable<T> {
/* ... previous method and static class definitions ... */
public final int size() {
return sizeHelper(this, 0);
}
// This method compiles to roughly the same bytecode as the tail-recursive
// version of sizeHelper would, if the Java compiler supported tail-call
// optimization. (The while(true) basically gives us a "goto" back to the
// start of the method.)
private static <T> int sizeHelper(ImmutableList<T> input, int output) {
while (true) {
if (input.isEmpty()) {
return output;
}
output = output + 1;
input = input.tail();
}
}
}

If Java did support tail-call optimization, it would produce roughly the same bytecode as the code above. The code above passes the unit test with flying colours. That said, because Java does not support tail-call optimization, we are probably better off inlining sizeHelper to produce the following:

public abstract class ImmutableList<T> implements Iterable<T> {
/* ... previous method and static class definitions ... */
public final int size() {
ImmutableList<T> input = this;
int size = 0;
while (!input.isEmpty()) {
size++;
input = input.tail();
}
return size;
}
}

Note that Scala, a JVM language, would actually optimize the previous tail-call (by producing the goto bytecode to return to the start of the helper method), so we wouldn't have to optimize it into a while loop. That said, if you come from an object-oriented background, the optimized version may actually be more intuitive. Where Scala is not able to match "traditional" functional programming languages is with "mutual tail-calls", where function a tail-calls function b, which tail-calls function a. I believe you could implement this by hand in C by longjmping to the beginning of the body of the other function/procedure, but my understanding is that the JVM's builtin security requirements prevent you from gotoing to code outside the current method. Basically, the goto bytecode is only designed to make control statements (if, while, for, switch, etc.) work within the current method, I think. If I recall correctly, Clojure (another JVM language) is able to produce optimized mutual tail-calls if you use "trampolines", but I have no idea how those work at the bytecode level.
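For what it's worth, the trampoline idea can be sketched in plain Java without any special bytecode. Roughly, each function returns a description of the next call instead of making the call itself, and an ordinary loop keeps running those descriptions until a final value comes back (Clojure's trampoline function works along these lines, using returned functions). The sketch below uses the classic mutually recursive isEven/isOdd pair; Step, More, done and trampoline are names invented for this example, not part of any library.

```java
// Hypothetical sketch of a trampoline; all names are illustrative.
class TrampolineDemo {
    // One step of a computation: either a finished result, or a
    // deferred call that yields the next step.
    interface Step<T> {
        boolean isDone();
        T result();
        Step<T> next();
    }

    // A finished step holding the final value.
    static <T> Step<T> done(final T value) {
        return new Step<T>() {
            public boolean isDone() { return true; }
            public T result() { return value; }
            public Step<T> next() { throw new IllegalStateException("already done"); }
        };
    }

    // Base class for unfinished steps; subclasses supply next().
    static abstract class More<T> implements Step<T> {
        public boolean isDone() { return false; }
        public T result() { throw new IllegalStateException("not done yet"); }
    }

    // isEven "tail-calls" isOdd by returning a step instead of calling it.
    static Step<Boolean> isEven(final int n) {
        if (n == 0) return done(true);
        return new More<Boolean>() {
            public Step<Boolean> next() { return isOdd(n - 1); }
        };
    }

    // isOdd "tail-calls" isEven in the same way.
    static Step<Boolean> isOdd(final int n) {
        if (n == 0) return done(false);
        return new More<Boolean>() {
            public Step<Boolean> next() { return isEven(n - 1); }
        };
    }

    // The trampoline itself: a loop that bounces from step to step, so the
    // call stack never grows beyond a couple of frames.
    static <T> T trampoline(Step<T> step) {
        while (!step.isDone()) {
            step = step.next();
        }
        return step.result();
    }

    public static void main(String[] args) {
        // A million mutual "calls" without a StackOverflowError.
        System.out.println(trampoline(isEven(1000000))); // true
    }
}
```

The price is one small object allocation per bounce, which is the usual trade-off of this technique compared with a real jump.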

For another example of a function that we can implement with tail-recursion (and would be implemented with tail-recursion in languages that lack while loops), consider reversing a list. When we reverse a list, the first element becomes the last element, then we prepend the second element to that first element, then prepend the third element to that, etc. As a recursive function, we pass in the original list and an empty list as the initial accumulator value. We prepend the head of the original list to the accumulator, and recurse with the tail of the original list and the new accumulator value:

public abstract class ImmutableList<T> implements Iterable<T> {
/* ... previous method and subclass definitions ... */
public ImmutableList<T> reverse() {
return reverseHelper(this, ImmutableList.<T>nil());
}
// This tail-recursive function "pours" values off the input stack
// onto the output stack. Unfortunately, it will also overflow the stack
// for a large input list.
private static <T> ImmutableList<T> reverseHelper(ImmutableList<T> input,
ImmutableList<T> output) {
if (input.isEmpty()) {
return output;
}
return reverseHelper(input.tail(), output.prepend(input.head()));
}
}

As with size, this implementation of reverse will overflow the stack on large lists. And, as with size, we can perform the tail-call optimization by hand in almost exactly the same way:

public abstract class ImmutableList<T> implements Iterable<T> {
/* ... previous method and subclass definitions ... */
public ImmutableList<T> reverse() {
ImmutableList<T> input = this;
ImmutableList<T> output = nil();
while (!input.isEmpty()) {
output = output.prepend(input.head());
input = input.tail();
}
return output;
}
}

Below, we will make extensive use of reverse, and use while loops to implement several other methods that would traditionally (in the functional world) be written with tail-recursive functions.

Anyway, let's move on to an even simpler collection type, which essentially makes null obsolete.

Option

The Option type is effectively a collection of size zero or one, with the semantics of "maybe" having a value. (In fact, the Haskell equivalent of Option is the Maybe monad. I personally like the name Maybe better, since the idea of a function that returns Maybe int nicely expresses that it may return an int, but might not. That said, I've decided to go with Scala's naming in this case, and call it Option.)

Why would we want to return or keep an Option value instead of just returning or keeping an object reference? We do this partly to explicitly document the fact that a method may not return a valid value, or that a field may not be initialized yet. In traditional Java, this is accomplished by returning null or storing a null object reference. Unfortunately, any object reference could be null, so this leads to the common situation of developers adding defensive null-checks everywhere, including on return values from methods that will never return null (since they are guaranteed to produce a valid result in non-exceptional circumstances). With Option, you're able to explicitly say "This method may not have a defined return value for all inputs." Assuming you and your team settle on a coding standard where you never return null from a method, you should be able to guarantee that a method that returns String returns a well-defined (not null) String, while a method that returns Option<String> might not.

As with our other immutable collections, we'll follow the same pattern: abstract base class (in this case Option), with a singleton subclass for the empty collection (called None) and a subclass for non-empty collections (called Some). As with immutable lists, there will be no public constructors, but instances will be made available through factory methods on the base Option class.

import java.util.NoSuchElementException;
public abstract class Option<T> {
public abstract T get();
public abstract boolean isDefined();
// Factory method to return the singleton None instance
@SuppressWarnings({"unchecked"})
public static <T> Option<T> none() {
return NONE;
}
// Factory method to return a non-empty Some instance
public static <T> Option<T> some(final T value) {
return new Some<T>(value);
}
private static final None NONE = new None();
private static class None extends Option {
// None has no element to return from get()
@Override
public Object get() {
throw new NoSuchElementException("get() called on None");
}
// None is never defined
@Override
public boolean isDefined() {
return false;
}
// We'll override toString() to make tests/debugging clearer
@Override
public String toString() {
return "None";
}
}
private static class Some<T> extends Option<T> {
private final T value;
// Some wraps an object value. It is up to the caller
// of the some() factory method above to ensure that
// value is not null.
public Some(final T value) {
this.value = value;
}
// Return the wrapped value
@Override
public T get() {
return value;
}
// Some is always defined
@Override
public boolean isDefined() {
return true;
}
// We'll override toString() to make tests/debugging clearer
@Override
public String toString() {
return "Some(" + value + ")";
}
// We didn't need to override equals() for None, since it is a
// singleton and referential equality is fine. Some needs to
// define equals() in terms of the contained value.
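@Override
public boolean equals(Object other) {
return other instanceof Some && ((Some<?>) other).value.equals(value);
}
// Keep hashCode() consistent with equals(), as the Object contract requires
@Override
public int hashCode() {
return value.hashCode();
}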
}
}

Now, if we have a method that returns an Option, we can call isDefined() before calling get() to avoid having an exception thrown. Of course, this is only marginally better than just using null pointers, since the Option warns us that we should call isDefined(), but we're still using the same clunky system of checks. It would be nicer if we could just say "Do something to the value held in this Option, if it is a Some, or propagate the None otherwise", or simply "Do something with a Some, and nothing with a None". Fortunately, we can do both of these. Let's do the second one first, by making Option iterable:

public abstract class Option<T> implements Iterable<T> {
/* ... previous method and subclass definitions ... */
@Override
public Iterator<T> iterator() {
return isDefined() ? Collections.singleton(get()).iterator() :
Collections.<T>emptySet().iterator();
}
}

Here is a test showing how we can iterate over Options to execute code only if there is an actual value:

import org.junit.Test;
import static collections.Option.none;
import static collections.Option.some;
import static org.junit.Assert.*;
public class OptionTest {
@Test
public void testIterable() throws Exception {
// For Some, the body of the for loop will execute
Option<Integer> a = some(5);
boolean didRun = false;
for (Integer i : a) {
didRun = true;
assertEquals(Integer.valueOf(5), i);
}
assertTrue(didRun);
// For None, it does not execute
Option<Integer> b = none();
for (Integer i : b) {
fail("This should not execute");
}
}
}

Using this idiom, we can wrap the None check and the extraction of the value into the for statement. It's a little more elegant than writing if (a.isDefined()) a.get(), and ties in nicely with the idea of thinking of Option as a collection of size 0 or 1.

We can also replace a common Java null idiom, where we use a default value if a variable is null, using the getOrElse method:

public abstract class Option<T> implements Iterable<T> {
/* ... previous method and subclass definitions ... */
public final T getOrElse(T defaultVal) {
return isDefined() ? get() : defaultVal;
}
}

Here is a test that shows how that works (and verifies the standard get() behaviour):

public class OptionTest {
/* ... previous tests ... */
@Test
public void testGet() throws Exception {
Option<Integer> a = some(5);
Option<Integer> b = none();
assertEquals(5, a.get().intValue());
try {
b.get();
fail("Should have thrown exception");
} catch (NoSuchElementException e) {
// Exception should have been thrown
}
// Some should return its own value
assertEquals(5, a.getOrElse(1).intValue());
// None returns the given default value
assertEquals(1, b.getOrElse(1).intValue());
}
}

Before getting into some other collection operations that we'll add to ImmutableList and Option, let's add a convenient factory method that wraps an object reference in Some or returns None, depending on whether the reference is null.

public abstract class Option<T> implements Iterable<T> {
/* ... previous method and subclass definitions ... */
public static <T> Option<T> option(T value) {
if (value == null) {
return none();
}
return some(value);
}
}

And the test:

/* ... previous imports ... */
import static collections.Option.option;
public class OptionTest {
/* ... previous tests ... */
@Test
public void testWrapNull() throws Exception {
Map<String, Integer> scoreMap = new HashMap<String, Integer>();
scoreMap.put("Michael", 42);
Option<Integer> a = option(scoreMap.get("Michael"));
assertTrue(a.isDefined());
Option<Integer> b = option(scoreMap.get("Bob"));
assertFalse(b.isDefined());
// Here is the ugly "traditional" Java way of dealing with Maps
Integer score = scoreMap.get("Michael");
if (score != null) {
System.out.println("Michael's score is " + score);
}
// This feels more elegant to me
for (Integer myScore : option(scoreMap.get("Michael"))) {
System.out.println("Michael's score is " + myScore);
}
}
}

Higher-order Collection Methods

In the second post in this series, we saw the map and fold higher-order functions. Since these require a List parameter anyway (and, specifically, according to my previous implementation, one of those nasty mutable java.util.Lists), for convenience, we can add map and fold operations to our immutable collections as methods. While we're at it, we'll add two more higher-order functions, filter and flatMap. Since all of our collections should support for-each syntax, we'll specify the contract for our methods as an interface extending Iterable:

public interface AugmentedIterable<T> extends Iterable<T> {
// Apply function f to each element in the collection (in order), collecting
// the results in a new collection, which is returned.
<R> AugmentedIterable<R> map(Function1<R, T> f);
// Apply function f to the seed value and the first element in the
// collection, then apply f to the result and the second element in
// the collection, then apply f to that result and the third element in
// the collection, etc. returning the final computed result.
// If the collection is empty, then return the seed value.
<R> R foldLeft(Function2<R, R, T> f, R seed);
// Apply function f to the last element in the collection and the seed
// value, then apply f to the second-last element and the previous result,
// then apply f to the third-last element and that result, etc.
// returning the final computed result.
// If the collection is empty, then return the seed value.
<R> R foldRight(Function2<R, T, R> f, R seed);
// Apply function f to each element in the collection (in order). Each
// application returns a collection; concatenate the elements of those
// collections (in order) into a single new collection, which is returned.
<R> AugmentedIterable<R> flatMap(Function1<? extends Iterable<R>, T> f);
// Return the elements of this collection (in their original order)
// for which the given predicate returns true.
AugmentedIterable<T> filter(Function1<Boolean, T> predicate);
}

We'll implement these methods on ImmutableList and Option. Let's look at map first:

public abstract class ImmutableList<T> implements AugmentedIterable<T> {
/* ... previous method and subclass definitions ... */
@Override
public final <R> ImmutableList<R> map(Function1<R, T> f) {
ImmutableList<T> input = this;
ImmutableList<R> output = nil();
// This traversal constructs the output list by
// prepending, yielding a list in reverse order.
while (!input.isEmpty()) {
output = output.prepend(f.evaluate(input.head()));
input = input.tail();
}
// Use reverse to get back the original order.
return output.reverse();
}
}
public abstract class Option<T> implements AugmentedIterable<T> {
/* ... previous method and subclass definitions ... */
public final <R> Option<R> map(Function1<R, T> f) {
return isDefined() ? option(f.evaluate(get())) : Option.<R>none();
}
}

To establish that these map methods work, let's try mapping a function that doubles integers:

public class ImmutableListTest {
/* ... previous test methods ... */
@Test
public void testMap() throws Exception {
ImmutableList<Integer> list = list(1, 2, 3);
Function1<Integer, Integer> dbl = new Function1<Integer, Integer>() {
@Override
public Integer evaluate(final Function1<Integer, Integer> self,
final Integer i1) {
return i1 * 2;
}
};
ImmutableList<Integer> doubledList = list.map(dbl);
assertEquals(list(2, 4, 6), doubledList);
}
}
public class OptionTest {
/* ... previous test methods ... */
@Test
public void testMap() throws Exception {
Option<Integer> a = some(5);
Option<Integer> b = none();
Function1<Integer, Integer> dbl = new Function1<Integer, Integer>() {
@Override
public Integer evaluate(final Function1<Integer, Integer> self,
final Integer i1) {
return i1 * 2;
}
};
assertEquals(some(10), a.map(dbl));
assertEquals(Option.<Integer>none(), b.map(dbl));
}
}
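
One caveat about the list test above: assertEquals(list(2, 4, 6), doubledList) can only pass if ImmutableList defines a value-based equals() (and, by convention, a matching hashCode()), and the listings in this post don't show one. Here is a minimal sketch of what such methods might look like, written against a simplified stand-in class. ConsList is a made-up name, not the post's ImmutableList, and the loops are iterative for the same stack-depth reasons discussed earlier.

```java
// Simplified, hypothetical stand-in for the post's ImmutableList, reduced to
// what is needed to demonstrate equals() and hashCode().
final class ConsList<T> {
    private static final ConsList<?> EMPTY = new ConsList<Object>(null, null);

    private final T head;
    private final ConsList<T> tail;

    private ConsList(T head, ConsList<T> tail) {
        this.head = head;
        this.tail = tail;
    }

    @SuppressWarnings("unchecked")
    static <T> ConsList<T> nil() {
        return (ConsList<T>) EMPTY;
    }

    @SafeVarargs
    static <T> ConsList<T> list(T... elements) {
        ConsList<T> output = nil();
        for (int i = elements.length - 1; i >= 0; i--) {
            output = output.prepend(elements[i]);
        }
        return output;
    }

    ConsList<T> prepend(T element) {
        return new ConsList<T>(element, this);
    }

    boolean isEmpty() {
        return this == EMPTY;
    }

    // Two lists are equal if they have the same length and equal elements
    // at every position.
    @Override
    public boolean equals(Object other) {
        if (this == other) return true;
        if (!(other instanceof ConsList)) return false;
        ConsList<?> a = this;
        ConsList<?> b = (ConsList<?>) other;
        while (!a.isEmpty() && !b.isEmpty()) {
            Object x = a.head;
            Object y = b.head;
            if (x == null ? y != null : !x.equals(y)) return false;
            a = a.tail;
            b = b.tail;
        }
        // Equal only if both lists ran out at the same time.
        return a.isEmpty() && b.isEmpty();
    }

    // Same recipe as java.util.List.hashCode(), so equal lists hash equally.
    @Override
    public int hashCode() {
        int hash = 1;
        for (ConsList<?> node = this; !node.isEmpty(); node = node.tail) {
            hash = 31 * hash + (node.head == null ? 0 : node.head.hashCode());
        }
        return hash;
    }

    public static void main(String[] args) {
        System.out.println(list(2, 4, 6).equals(list(2, 4, 6))); // true
        System.out.println(list(2, 4, 6).equals(list(2, 4)));    // false
    }
}
```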

Next, we'll move on to the two folds. For an ImmutableList, the foldLeft operation traverses the list in the natural order, while a foldRight is easier if we reverse the list first. For Option, since there is (at most) one element, the only difference between the folds is the order in which parameters are passed to the function argument. To distinguish between foldLeft and foldRight, we need to use a non-commutative operation, so we'll use string concatenation in the tests.

public abstract class ImmutableList<T> implements AugmentedIterable<T> {
/* ... previous method and subclass definitions ... */
@Override
public final <R> R foldLeft(Function2<R, R, T> f, R seed) {
ImmutableList<T> input = this;
R output = seed;
while (!input.isEmpty()) {
output = f.evaluate(output, input.head());
input = input.tail();
}
return output;
}
@Override
public final <R> R foldRight(Function2<R, T, R> f, R seed) {
// Reverse the list so that we can traverse from right to left
// by traversing from left to right, effectively reducing the
// problem to foldLeft.
ImmutableList<T> input = this.reverse();
R output = seed;
while (!input.isEmpty()) {
output = f.evaluate(input.head(), output);
input = input.tail();
}
return output;
}
}
public abstract class Option<T> implements AugmentedIterable<T> {
@Override
public final <R> R foldLeft(final Function2<R, R, T> f, final R seed) {
return isDefined() ? f.evaluate(seed, get()) : seed;
}
@Override
public final <R> R foldRight(final Function2<R, T, R> f, final R seed) {
return isDefined() ? f.evaluate(get(), seed) : seed;
}
}
public class ImmutableListTest {
/* ... previous test methods ... */
@Test
public void testFold() {
Function2<String, String, String> concat =
new Function2<String, String, String>() {
@Override
public String evaluate(
final Function2<String, String, String> self,
final String i1, final String i2) {
return i1 + i2;
}
};
ImmutableList<String> list = list("a", "b", "c");
assertEquals("dabc", list.foldLeft(concat, "d"));
assertEquals("abcd", list.foldRight(concat, "d"));
ImmutableList<String> emptyList = list();
assertEquals("d", emptyList.foldLeft(concat, "d"));
assertEquals("d", emptyList.foldRight(concat, "d"));
}
}
public class OptionTest {
/* ... previous test methods ... */
@Test
public void testFold() throws Exception {
Function2<String, String, String> concat =
new Function2<String, String, String>() {
@Override
public String evaluate(final Function2<String, String, String> self,
final String i1, final String i2) {
return i1 + i2;
}
};
Option<String> a = some("a");
assertEquals("ba", a.foldLeft(concat, "b"));
assertEquals("ab", a.foldRight(concat, "b"));
Option<String> b = none();
assertEquals("b", b.foldLeft(concat, "b"));
assertEquals("b", b.foldRight(concat, "b"));
}
}

The filter method returns a sub-collection, based on a predicate that gets evaluated on every element of the original collection. If the predicate returns true for a given element, that element is included in the output collection. For Option, filter will return a Some if and only if the original Option was a Some and the predicate returns true for the wrapped value. For an ImmutableList, the code for filter is similar to map, but we only copy into the accumulator the values that satisfy the predicate.

public abstract class ImmutableList<T> implements AugmentedIterable<T> {
/* ... previous method and subclass definitions ... */
@Override
public final ImmutableList<T> filter(Function1<Boolean, T> predicate) {
ImmutableList<T> input = this;
ImmutableList<T> output = nil();
// This traversal constructs the output list by
// prepending, yielding a list in reverse order.
while (!input.isEmpty()) {
if (predicate.evaluate(input.head())) {
output = output.prepend(input.head());
}
input = input.tail();
}
// Use reverse to get back the original order.
return output.reverse();
}
}
public abstract class Option<T> implements AugmentedIterable<T> {
/* ... previous method and subclass definitions ... */
public final Option<T> filter(Function1<Boolean, T> predicate) {
if (isDefined() && predicate.evaluate(get())) {
return this;
}
return none();
}
}
public class ImmutableListTest {
/* ... previous test methods ... */
@Test
public void testFilter() throws Exception {
ImmutableList<Integer> list = list(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
Function1<Boolean, Integer> isEven = new Function1<Boolean, Integer>() {
@Override
public Boolean evaluate(final Function1<Boolean, Integer> self,
final Integer i1) {
return i1 % 2 == 0;
}
};
assertEquals(list(2, 4, 6, 8, 10), list.filter(isEven));
}
}
public class OptionTest {
/* ... previous test methods ... */
@Test
public void testFilter() throws Exception {
Option<Integer> a = some(2);
Option<Integer> b = some(1);
Option<Integer> c = none();
Function1<Boolean, Integer> isEven = new Function1<Boolean, Integer>() {
@Override
public Boolean evaluate(final Function1<Boolean, Integer> self,
final Integer i1) {
return i1 % 2 == 0;
}
};
// a is defined and even
assertEquals(some(2), a.filter(isEven));
// b is defined, but not even
assertEquals(Option.<Integer>none(), b.filter(isEven));
// c is undefined
assertEquals(Option.<Integer>none(), c.filter(isEven));
}
}

Our last higher-order method is a little more complicated. flatMap takes a function that returns collections, maps it across all elements in our collection, and then "flattens" the result down to a collection of elements. Say you have a function that takes an Integer and returns an ImmutableList<Integer>. If you map that function over an ImmutableList<Integer>, the result is an ImmutableList<ImmutableList<Integer>>, or a list of lists of integers. However, if you flatMap it, you get back an ImmutableList<Integer> consisting of the concatenation of all of the lists you would get using map. Interestingly, both map and filter can be implemented in terms of flatMap. For map, we simply wrap the return value from the passed function in a single-element list (or Some), while for filter we return a single-element list (or Some) for elements where the predicate evaluates to true and an empty list (or None) for elements where the predicate is false. Here are implementations and tests for flatMap:

public abstract class ImmutableList<T> implements AugmentedIterable<T> {
/* ... previous method and subclass definitions ... */
public final <R> ImmutableList<R>
flatMap(Function1<? extends Iterable<R>, T> f) {
ImmutableList<T> input = this;
ImmutableList<R> output = nil();
while (!input.isEmpty()) {
// Unlike with map(), we know that the return value
// from f is itself iterable. So, we iterate and
// accumulate those values.
Iterable<R> rs = f.evaluate(input.head());
for (R r : rs) {
output = output.prepend(r);
}
input = input.tail();
}
return output.reverse();
}
}
public abstract class Option<T> implements AugmentedIterable<T> {
/* ... previous method and subclass definitions ... */
public final <R> Option<R> flatMap(Function1<? extends Iterable<R>, T> f) {
if (!isDefined()) {
return none();
}
Iterable<? extends R> val = f.evaluate(get());
Iterator<? extends R> iter = val.iterator();
if (iter.hasNext()) {
Option<R> result = option(iter.next());
if (iter.hasNext()) {
// Option.flatMap only works with functions that return at most
// one element.
throw new RuntimeException("Function passed to Option flatMap " +
"returns more than one element");
}
return result;
}
return none();
}
}
public class ImmutableListTest {
    /* ... previous test methods ... */
    @Test
    public void testFlatMap() throws Exception {
        ImmutableList<Integer> list = list(1, 2, 3);
        // oneToN :: n -> [1..n]
        Function1<ImmutableList<Integer>, Integer> oneToN =
                new Function1<ImmutableList<Integer>, Integer>() {
                    @Override
                    public ImmutableList<Integer>
                    evaluate(final Function1<ImmutableList<Integer>, Integer> self,
                             final Integer i1) {
                        ImmutableList<Integer> list = ImmutableList.nil();
                        for (int i = i1; i >= 1; i--) {
                            list = list.prepend(i);
                        }
                        return list;
                    }
                };
        assertEquals(list(1, 1, 2, 1, 2, 3), list.flatMap(oneToN));
    }
}

public class OptionTest {
    /* ... previous test methods ... */
    @Test
    public void testFlatMap() throws Exception {
        Option<Integer> a = some(5);
        Option<Integer> b = none();
        Option<Integer> c = some(0);
        // safeDivide prevents divByZero errors by returning None
        Function2<Option<Integer>, Integer, Integer> safeDivide =
                new Function2<Option<Integer>, Integer, Integer>() {
                    @Override
                    public Option<Integer>
                    evaluate(final Function2<Option<Integer>, Integer, Integer> self,
                             final Integer i1, final Integer i2) {
                        if (i2 == 0) {
                            return none();
                        }
                        return some(i1 / i2);
                    }
                };
        // 10 / 5 returns 2
        assertEquals(some(2), a.flatMap(safeDivide.apply(10)));
        // 10 / none returns none
        assertEquals(Option.<Integer>none(), b.flatMap(safeDivide.apply(10)));
        // 10 / 0 returns none
        assertEquals(Option.<Integer>none(), c.flatMap(safeDivide.apply(10)));
    }
}
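
The claim that map and filter fall out of flatMap can be sketched in a few lines. To keep the example self-contained, this is only a rough illustration that uses java.util.List and the Java 8 functional interfaces as stand-ins for this series' ImmutableList and Function1; the class name FlatMapDerived and its static helpers are made up for the example and are not part of the code above.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;

// Hypothetical helper class: plain java.util.List stands in for ImmutableList.
final class FlatMapDerived {
    // Apply f to each element and concatenate the resulting lists, in order.
    static <T, R> List<R> flatMap(List<T> in, Function<T, List<R>> f) {
        List<R> out = new ArrayList<>();
        for (T t : in) {
            out.addAll(f.apply(t));
        }
        return Collections.unmodifiableList(out);
    }

    // map: wrap each result in a single-element list, then flatten.
    static <T, R> List<R> map(List<T> in, Function<T, R> f) {
        return flatMap(in, t -> Collections.singletonList(f.apply(t)));
    }

    // filter: a single-element list for matches, an empty list otherwise.
    static <T> List<T> filter(List<T> in, Predicate<T> p) {
        return flatMap(in, t -> p.test(t)
                ? Collections.singletonList(t)
                : Collections.<T>emptyList());
    }
}
```

The same trick should carry over to Option, with Some playing the role of the single-element list and None the empty one.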

A nice way of combining these two collection types is to flatMap a function that returns Option across an ImmutableList. The elements that evaluate to None simply disappear, while the Some values are fully realized (as in, you don't need to invoke get() on them).
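
Here is a rough sketch of that combination. It does not use the real classes from this post: Opt is a hypothetical, stripped-down stand-in for Option (just an Iterable over zero or one element), java.util.List stands in for ImmutableList, and the helper names are invented for the example.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;

final class OptionInListSketch {
    // Hypothetical stand-in for Option: iterates over zero or one element.
    static final class Opt<T> implements Iterable<T> {
        private final List<T> contents;
        private Opt(List<T> contents) { this.contents = contents; }
        static <T> Opt<T> some(T value) {
            return new Opt<>(Collections.singletonList(value));
        }
        static <T> Opt<T> none() {
            return new Opt<>(Collections.<T>emptyList());
        }
        @Override public Iterator<T> iterator() { return contents.iterator(); }
    }

    // flatMap over a list: accepts any function that returns an Iterable.
    static <T, R> List<R> flatMap(List<T> in,
                                  Function<T, ? extends Iterable<R>> f) {
        List<R> out = new ArrayList<>();
        for (T t : in) {
            for (R r : f.apply(t)) {
                out.add(r);
            }
        }
        return out;
    }

    // Same idea as safeDivide in the test above, without currying.
    static Opt<Integer> safeDivide(int numerator, int denominator) {
        return denominator == 0
                ? Opt.<Integer>none()
                : Opt.some(numerator / denominator);
    }

    static List<Integer> divideTenByEach(List<Integer> denominators) {
        return flatMap(denominators, d -> safeDivide(10, d));
    }

    public static void main(String[] args) {
        // The zero produces a "none", which simply vanishes from the result.
        System.out.println(divideTenByEach(Arrays.asList(5, 0, 2))); // [2, 5]
    }
}
```

Because the option type is itself iterable, the list's flatMap needs no special case for it; an empty option contributes nothing and a defined one contributes its bare value.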

Summary

In this post, we got to see a couple of immutable collection types.

While ImmutableList as I've implemented it takes a performance hit by building its higher-order methods on its own public interface (since they tend to have to call reverse to provide consistent ordering), there are a couple of improvements that could be made. One would be to provide additional implementations of the methods that would contain all of the logic before the call to reverse, for the cases where order doesn't matter, and then implement the ordered versions by calling the unordered version followed by a call to reverse. The other would be to make NonEmptyList's tail member non-final, and make use of our private access to implement efficient appends by the higher-order methods. The code for that would be a fair bit more verbose (since we would have to hold a reference to the element preceding the terminal EmptyList). That said, it's possible that the additional operations per iteration might outweigh the cost of traversing (and creating) the list twice.
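
To make the first of those ideas concrete, here is a minimal sketch. Cons is a hypothetical, cut-down linked list written just for this example (it is not the ImmutableList from this post), and mapUnordered and toJavaList are invented names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

final class ReverseOnceSketch {
    // Hypothetical minimal immutable cons list.
    static final class Cons<T> {
        private static final Cons<Object> NIL = new Cons<>(null, null);
        final T head;
        final Cons<T> tail;

        private Cons(T head, Cons<T> tail) {
            this.head = head;
            this.tail = tail;
        }

        @SuppressWarnings("unchecked")
        static <T> Cons<T> nil() { return (Cons<T>) NIL; }

        boolean isEmpty() { return this == NIL; }

        Cons<T> prepend(T value) { return new Cons<>(value, this); }

        Cons<T> reverse() {
            Cons<T> out = nil();
            for (Cons<T> in = this; !in.isEmpty(); in = in.tail) {
                out = out.prepend(in.head);
            }
            return out;
        }

        // Single pass, no reverse: the results come out in reverse order.
        <R> Cons<R> mapUnordered(Function<T, R> f) {
            Cons<R> out = nil();
            for (Cons<T> in = this; !in.isEmpty(); in = in.tail) {
                out = out.prepend(f.apply(in.head));
            }
            return out;
        }

        // The ordered version is the unordered one plus a single reverse.
        <R> Cons<R> map(Function<T, R> f) {
            return mapUnordered(f).reverse();
        }

        // Convenience for inspecting the contents in this sketch.
        List<T> toJavaList() {
            List<T> out = new ArrayList<>();
            for (Cons<T> in = this; !in.isEmpty(); in = in.tail) {
                out.add(in.head);
            }
            return out;
        }
    }
}
```

Callers that only fold over the result (summing it, say) could use mapUnordered and skip the second traversal entirely.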

I think Option is one of my favourite types from Scala. The idea that Option gives you a warning through the type system that a value may not be initialized is fantastic. Sure, you're wrapping your object in another object (at a cost of 8 bytes on a 32-bit JVM, I believe), but (with some discipline) you have some assurance that things that don't return Option will not return null. With the option() factory method, you can even wrap the return values from "old-fashioned" Java APIs. Unfortunately, I haven't yet seen a language where NullPointerExceptions are completely avoidable (or aren't masked by treating uninitialized values as some default). Since Scala works on the JVM, null still exists. Haskell has the "bottom" type (which actually corresponds more closely to Scala's Nothing), which effectively takes the place of null (as I understand it -- I'm a dabbler at best in Haskell), though the use of bottom is heavily discouraged, from what I've read.

In between these collection types, I managed to sneak in an explanation of tail-call optimization, and how it can be used to avoid blowing your stack on recursive calls on a list. Tail-calls are not restricted to recursive functions (or even mutually-recursive sets of functions), but the stack-saving nature of tail-call optimization tends to be most apparent when recursion is involved. The other bonus of TCO (with apologies for my vague memories of some form of assembly) is that the tail-call gets turned into a JMP ("jump" or "goto"), rather than a JSR ("jump to subroutine", which saves a return address), so the final evaluation can immediately return to the last spot where a tail-call didn't occur. That said, I should remind you that tail-call optimization only works when you're delegating your return to another method/function/procedure (that is, the compiler doesn't need to push your current location onto the stack in order to return). When you're traversing a binary tree, and need to visit the left and right children via recursion, at least the first recursive call is not a tail-call. Of course, in Java, with default JVM settings, in my experience, the stack seems to overflow around 50k calls. If you have a balanced binary tree with 50k levels, then you probably have enough elements to make traversal outlast the earth being consumed by the sun, in which case a StackOverflowError is the least of your worries, since you and everyone you love will be long dead. (Actually, since the recursion is likely to be depth-first, you're lucky -- the stack will overflow quite quickly, and you can call your mom to remind her that you love her and you're glad that she didn't die while your program ran. Moms love that kind of thing.)
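
As a quick reminder of what does and doesn't count as a tail-call, here is a small sketch with invented method names. Keep in mind that javac does not perform this optimization itself; the loop version is just what a tail-call-optimizing compiler would effectively turn the accumulator version into.

```java
final class TailCallSketch {
    // Tail-recursive: the recursive call is the very last thing evaluated,
    // so nothing from this frame is needed once the call is made.
    static long sumTo(long n, long acc) {
        if (n == 0) {
            return acc;
        }
        return sumTo(n - 1, acc + n);
    }

    // The loop that TCO would effectively produce: the "call" becomes a
    // jump back to the top with updated arguments, using one stack frame.
    static long sumToLoop(long n) {
        long acc = 0;
        while (n != 0) {
            acc += n;
            n--;
        }
        return acc;
    }

    // NOT a tail-call: the addition happens after the recursive call
    // returns, so every frame has to stay on the stack.
    static long sumToNaive(long n) {
        return n == 0 ? 0 : n + sumToNaive(n - 1);
    }
}
```

On a stock JVM, both recursive versions will eventually throw a StackOverflowError for a large enough n, while the loop will not.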

For some really interesting information about tail-call optimization on the JVM at the bytecode level, I suggest you read John Rose's blog post on the subject from 2007.

I think this may be the longest post in the series so far. In particular, I've gone whole-hog on using unit tests to both establish that my code actually works (give or take copy-paste errors between my IDE and GitHub gists) and show examples of how to use these methods.

I must confess that this post was also more comfortable to write. Java doesn't make writing functions as objects terribly pleasant (since it wasn't really designed for it). I think the code in the last couple of posts involved more type arguments than actual logic. Working with data types is much easier, by comparison.

Where do we go from here?

I had been thinking about trying to implement a mechanism to combine lazy-loading of Function0 values with parallel execution, by replacing uninitialized values with Futures linked to a running thread on a ForkJoinPool. Unfortunately, the more I think about it, the more I realize that "parallelize everything" is a really bad idea (which is why nobody actually does it). I might be able to do it like memoization, where you can selectively memoize certain functions that are frequently used on a small set of inputs, so you would be able to parallelize only your "hot spot" functions. Unfortunately, my memoization implementation is kind of an ugly hack, requiring a parallel MemoizedFunctionN hierarchy, and I'm not particularly enthusiastic about creating a separate ParallelizedFunctionN hierarchy. Sure, I can write code generators for these different function specializations, but it's still not particularly elegant.
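
For what it's worth, the selective version might look something like the following sketch. It is only a guess at the shape of the idea, not something I've built: it uses java.util.function.Supplier in place of Function0 and CompletableFuture from Java 8 in place of a hand-rolled Future, and parallelize is an invented name.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Supplier;

final class ParallelLazySketch {
    // Start evaluating the computation on the common ForkJoinPool right
    // away, and hand back a Supplier that blocks (only if necessary)
    // until the value is ready. Supplier stands in for Function0 here.
    static <T> Supplier<T> parallelize(Supplier<T> computation) {
        CompletableFuture<T> future =
                CompletableFuture.supplyAsync(computation,
                                              ForkJoinPool.commonPool());
        // join() returns the computed value, or rethrows a failure
        // wrapped in an unchecked CompletionException.
        return future::join;
    }
}
```

You would only wrap the few "hot spot" computations this way; everything else stays an ordinary, sequentially evaluated value.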

The last remaining "standard" functional programming concept I can think of that I haven't covered (besides pattern matching, which I don't think I can implement nicely in Java) is tuples. If I dedicate a post exclusively to tuples, I think it should be fairly short. After these last couple of lengthy posts, that may be a good way to go.