Gotcha! Cartesian products of Java 8 streams

Occasionally, you need to compute the Cartesian product of several sets. Java 8 Streams make this simple, clear, elegant... and extravagantly expensive. I've created a Cartesian helper class that solves the memory-usage problem and allows Cartesian products to be streamed efficiently.

The problem: given n sets S₁, S₂, ..., Sₙ, produce a Stream of all possible items from the n-dimensional Cartesian product set S₁ ⨯ S₂ ⨯ ... ⨯ Sₙ.

...uh...

OK, this is a lot easier to explain with an example. Let's suppose you want to generate all N-digit numbers having digits from 1 to M. If you're making 3-digit numbers having digits from 1 to 3, you'd get { 111, 112, 113, 121, ... }. This is the Cartesian product {1,2,3} ⨯ {1,2,3} ⨯ {1,2,3}.
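For reference, in the usual notation the Cartesian product is the set of all n-tuples that take one member from each set, so its size is the product of the individual set sizes. For the 3-digit example, that gives 27 numbers:

```latex
S_1 \times S_2 \times \cdots \times S_n = \{\, (s_1, s_2, \ldots, s_n) \mid s_i \in S_i,\ 1 \le i \le n \,\}

\lvert S_1 \times S_2 \times \cdots \times S_n \rvert = \prod_{i=1}^{n} \lvert S_i \rvert
\qquad\text{e.g.}\qquad \lvert \{1,2,3\}^3 \rvert = 3^3 = 27
```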

If you happen to speak Python, you might write a generator like this:

top_digit = 3  # highest digit to use (M in the text above)

def permutations(num_dimensions):
    if num_dimensions == 0:
        yield 0
    else:
        for tens in permutations(num_dimensions-1):
            for ones in range(1, top_digit+1):
                yield tens * 10 + ones

(I usually stick with Java, but this Python code is so incredibly legible that we can use it as pseudo-code, even if you don't know Python.)

You can see how we compute all d-digit numbers by taking every possible (d−1)-digit number and adding every possible digit on the end.

While nothing quite matches Python's readability in this example, Java's Stream.flatMap method does provide quite an elegant way to solve this problem:

LongStream permutations = LongStream.of(0L);  // start from the single 0-digit value, 0
for (int d = 0; d < NUM_DIMENSIONS; d++) {
    permutations = permutations.flatMap(tens ->
        LongStream.rangeClosed(1, TOP_DIGIT)
        .map(ones -> 10*tens + ones));
}

As usual with Stream-based programming, there is a lot going on here in just a few lines of code, so let's examine this in detail. You're going to want to understand every word of this code before we go much further.

At the outer level, there is a loop that repeatedly calls flatMap, which is a pretty nifty method. Unfortunately, I don't have the space here to explain flatMap in detail. If you are not familiar with it, definitely take a few minutes to read its documentation: it gives a precise explanation and offers a few good examples. I'll wait.

The role of the flatMap call in the code above is to add one more digit to the numbers in the stream. It does this by taking each incoming d-digit number and producing every possible (d+1)-digit number: it multiplies the incoming number by 10 and adds each possible ones digit.
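To see a single step in isolation, here is a minimal sketch (the class and method names are illustrative, not from the original program) that applies one flatMap to the 1-digit numbers 1 to 3 and gets back every 2-digit number built from those digits:

```java
import java.util.Arrays;
import java.util.stream.LongStream;

public class OneFlatMapStep {
    // Applies one "add a digit" step to the 1-digit numbers 1..topDigit.
    static long[] twoDigitNumbers(int topDigit) {
        return LongStream.rangeClosed(1, topDigit)              // 1, 2, 3
                .flatMap(tens -> LongStream.rangeClosed(1, topDigit)
                        .map(ones -> 10 * tens + ones))         // tens*10 + each ones digit
                .toArray();
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(twoDigitNumbers(3)));
        // [11, 12, 13, 21, 22, 23, 31, 32, 33]
    }
}
```

Each incoming value expands into topDigit outgoing values, so running this step once per dimension multiplies the stream's length by topDigit each time.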

That's it. That's all you need to do to get your Cartesian product.

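...except for one major problem. If you run this loop, it produces the expected stream of N-digit numbers, but at a hefty price: this code computes the entire stream of Cartesian product values before processing the first value. In Java 8, flatMap pushes every element of each inner stream downstream, even when a short-circuiting operation such as findFirst needs only one of them.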

The effect of this varies widely. In some cases, your code just runs more slowly. In other cases, depending on the exact form of your stream processing, your stream's entire contents could be buffered in memory before processing begins. This might merely stress the garbage collector a bit, or it might bring down the JVM with an OutOfMemoryError. Worst of all: if one of your dimensions happens to be infinite, this code will never terminate.

The problem is described in the Java bug I opened. Here's the full example program:

package org.vena.qb;

import java.util.stream.LongStream;

public class BugReport {

    static final int NUM_DIMENSIONS = 3;
    static final int TOP_DIGIT = 3;

    public static void main(String[] args) {
        LongStream permutations = LongStream.of(0L);
        for (int d = 0; d < NUM_DIMENSIONS; d++) {
            permutations = permutations.flatMap(tens ->
                    LongStream.rangeClosed(1, TOP_DIGIT)
                    .map(ones-> 10*tens + ones)
                    .peek(v->System.out.println("+ "+v)));
        }
        long first = permutations.findFirst().getAsLong();
        System.out.println("First is " + first);
    }

}

Note that this version includes a call to peek, which permits us to inject a println for debugging. Very useful.

Ideally, this would produce the following output:

+ 1
+ 11
+ 111
First is 111  

In reality, it produces this:

+ 1
+ 11
+ 111
+ 112
+ 113
+ 12
+ 121
+ 122
+ 123
⋮
+ 322
+ 323
+ 33
+ 331
+ 332
+ 333
First is 111  

I tried a variety of slightly different idioms to produce the stream I wanted, but none of them gave the minimal output I had hoped for. Finally, I resorted to writing ordinary Java code to produce all the permutations, and wrapping that in a stream.

But how?

Conventional computation of a Cartesian product

The obvious way to enumerate a Cartesian product of n sets is with n nested loops. But what if you don't know n in advance?
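For a fixed number of dimensions, that might look like the sketch below, hard-coded to three digits (the class and method names are made up for illustration). Supporting a different n means writing a different number of loops:

```java
import java.util.ArrayList;
import java.util.List;

public class NestedLoops {
    // Enumerates {1..topDigit}^3 with one loop per dimension; the number of
    // dimensions (3) is baked into the shape of the code.
    static List<Long> threeDigitNumbers(int topDigit) {
        List<Long> out = new ArrayList<>();
        for (int hundreds = 1; hundreds <= topDigit; hundreds++) {
            for (int tens = 1; tens <= topDigit; tens++) {
                for (int ones = 1; ones <= topDigit; ones++) {
                    out.add(100L * hundreds + 10L * tens + ones);
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<Long> nums = threeDigitNumbers(3);
        System.out.println(nums.size());   // 27
        System.out.println(nums.get(0));   // 111
        System.out.println(nums.get(26));  // 333
    }
}
```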

The Python generator above does it with recursion. That was pretty simple. But what if your language doesn't support coroutines?

Well, you may be aware that it's possible to turn any recursive algorithm into an iterative one. Some algorithms, like Fibonacci, have simple iterative forms, while others (say, the Ackermann function) do not.

However, it's straightforward (if not necessarily simple) to turn any recursive algorithm into an iterative one if you keep in mind that you can always use an explicit stack data structure in place of the program's call stack. The idea is to take the data that would have been local variables in your recursive algorithm (the stack frames) and represent them explicitly as a stack data structure.
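As a minimal sketch of that idea (the Frame class and allNumbers method are hypothetical, and this is a simpler variant than the spliterator developed later), each pending recursive call becomes a Frame object on an ArrayDeque. Rather than saving a loop counter in each frame, it pushes all of a frame's children at once, in reverse, so they pop in the same order the recursive versions produce them. It is still eager, collecting every value into a list:

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

public class ExplicitStackDemo {
    // One "stack frame": how many digits remain to be chosen, and the
    // number built so far. These were parameters in the recursion.
    static final class Frame {
        final int remaining;
        final long digitsSoFar;

        Frame(int remaining, long digitsSoFar) {
            this.remaining = remaining;
            this.digitsSoFar = digitsSoFar;
        }
    }

    static List<Long> allNumbers(int numDimensions, int topDigit) {
        List<Long> out = new ArrayList<>();
        Deque<Frame> stack = new ArrayDeque<>();
        stack.push(new Frame(numDimensions, 0L));      // the initial "call"
        while (!stack.isEmpty()) {
            Frame f = stack.pop();
            if (f.remaining == 0) {
                out.add(f.digitsSoFar);                 // base case
            } else {
                // Push children in reverse so the smallest digit pops first.
                for (int ones = topDigit; ones >= 1; ones--) {
                    stack.push(new Frame(f.remaining - 1, 10 * f.digitsSoFar + ones));
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(allNumbers(2, 3));
        // [11, 12, 13, 21, 22, 23, 31, 32, 33]
    }
}
```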

So let's start by writing a hypothetical recursive version that sends the values to a given action. This approximates how our Stream version will eventually work. It might look like this:

void recursive(int numDimensions, long digitsSoFar, Consumer<Long> action) {  
    if (numDimensions == 0) {
        action.accept(digitsSoFar);
    } else {
        for (int ones = 1; ones <= TOP_DIGIT; ones++) {
            recursive(numDimensions-1, 10 * digitsSoFar + ones, action);
        }
    }
}

That looks pretty nice, actually. Unfortunately, this code would send the entire stream to action, which is exactly what we're trying to avoid.
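Here is a runnable sketch of that problem (the class name, the fixed TOP_DIGIT of 3, and the firstAndCount helper are assumptions for illustration). The consumer only cares about the first value, but it has no way to tell recursive() to stop, so every value is still generated and delivered:

```java
import java.util.function.Consumer;

public class PushStyleDemo {
    static final int TOP_DIGIT = 3;

    // Same shape as recursive() above: pushes every value to action.
    static void recursive(int numDimensions, long digitsSoFar, Consumer<Long> action) {
        if (numDimensions == 0) {
            action.accept(digitsSoFar);
        } else {
            for (int ones = 1; ones <= TOP_DIGIT; ones++) {
                recursive(numDimensions - 1, 10 * digitsSoFar + ones, action);
            }
        }
    }

    // Returns {first value seen, number of values delivered}.
    static long[] firstAndCount(int numDimensions) {
        long[] result = {-1, 0};
        recursive(numDimensions, 0L, v -> {
            if (result[1] == 0) {
                result[0] = v;   // we only wanted this one...
            }
            result[1]++;         // ...but every value arrives anyway
        });
        return result;
    }

    public static void main(String[] args) {
        long[] r = firstAndCount(3);
        System.out.println("First is " + r[0] + ", delivered " + r[1]);
        // First is 111, delivered 27
    }
}
```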

Spliterator: the heart of a Stream

The fundamental way to make a custom stream from scratch is to implement the Spliterator interface. This interface is like Iterator, with added support for parallel processing. The documentation for this interface is pretty abstract, so let me cut to the chase: if you're familiar with Iterator, you can turn the same logic into a Spliterator as follows:

import java.util.Spliterators.AbstractSpliterator;
import java.util.function.Consumer;

public class MySpliterator<T> extends AbstractSpliterator<T> {
    protected MySpliterator(...) {
        super(
            Long.MAX_VALUE,  // Conservative size estimate (i.e. unknown)
            0                // Conservative characteristics (i.e. none)
        );
    }

    // You can pretend you're writing an Iterator...
    //
    boolean hasNext(){ ... }
    T       next()   { ... }

    // ... and now turn it into a Spliterator.
    //
    @Override
    public boolean tryAdvance(Consumer<? super T> action) {
        if (hasNext()) {
            action.accept(next());
            return true;
        } else {
            return false;
        }
    }
}

Once you have your Spliterator, you can get a Stream using StreamSupport.stream:

StreamSupport.stream(new MySpliterator<T>(...), false);  
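For a concrete instance of that recipe, here is a small sketch built around a hypothetical Countdown spliterator whose Iterator-style logic counts down to 1. (If you already have a real Iterator, the JDK's Spliterators.spliteratorUnknownSize(iterator, characteristics) does this wrapping for you.)

```java
import java.util.List;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

public class CountdownDemo {
    // Hypothetical example: yields start, start-1, ..., 1.
    static class Countdown extends Spliterators.AbstractSpliterator<Integer> {
        private int current;

        Countdown(int start) {
            super(Long.MAX_VALUE, 0);   // unknown size, no characteristics
            this.current = start;
        }

        // The Iterator-style logic...
        boolean hasNext() { return current >= 1; }
        Integer next()    { return current--; }

        // ...wrapped in tryAdvance exactly as in the template.
        @Override
        public boolean tryAdvance(Consumer<? super Integer> action) {
            if (hasNext()) {
                action.accept(next());
                return true;
            } else {
                return false;
            }
        }
    }

    static List<Integer> countdown(int start) {
        return StreamSupport.stream(new Countdown(start), false)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(countdown(3));   // [3, 2, 1]
    }
}
```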

By extending AbstractSpliterator, this code even supports parallel processing to some degree, though there are several ways to improve parallel performance. First, supply a size estimate if that's convenient, along with every characteristic flag that genuinely applies; then provide your own implementation of trySplit that cuts the iteration space roughly in half. The trySplit inherited from AbstractSpliterator is not a terrible default, considering it knows nothing at all about the nature of your stream, but you can usually do much better.
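As a sketch of what a hand-written trySplit can look like (RangeSpliterator is a hypothetical class, not part of the JDK or of the Cartesian code), here is a spliterator over the half-open range [from, to) that reports an exact size and its characteristics, and splits by handing off half of its remaining range. Because it reports ORDERED, trySplit must return a prefix of the encounter order, which is why it hands off the lower half:

```java
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.StreamSupport;

public class RangeSpliteratorDemo {
    // Hypothetical spliterator over the half-open range [from, to).
    static class RangeSpliterator implements Spliterator<Long> {
        private long from;
        private final long to;

        RangeSpliterator(long from, long to) {
            this.from = from;
            this.to = to;
        }

        @Override
        public boolean tryAdvance(Consumer<? super Long> action) {
            if (from >= to) return false;
            action.accept(from++);
            return true;
        }

        @Override
        public Spliterator<Long> trySplit() {
            long remaining = to - from;
            if (remaining < 2) return null;   // too small to be worth splitting
            long mid = from + remaining / 2;
            // Hand the lower half (a prefix) to a new spliterator; keep the rest.
            Spliterator<Long> prefix = new RangeSpliterator(from, mid);
            from = mid;
            return prefix;
        }

        @Override
        public long estimateSize() { return to - from; }  // exact, hence SIZED

        @Override
        public int characteristics() {
            return ORDERED | SIZED | SUBSIZED | DISTINCT | IMMUTABLE | NONNULL;
        }
    }

    static long parallelSum(long from, long to) {
        return StreamSupport.stream(new RangeSpliterator(from, to), true)  // parallel
                .mapToLong(Long::longValue)
                .sum();
    }

    public static void main(String[] args) {
        System.out.println(parallelSum(0, 1_000));   // 499500
    }
}
```

With true as the second argument to StreamSupport.stream, the framework may call trySplit repeatedly and process the pieces on different threads; the sum is the same either way.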

A Cartesian spliterator

As you can see, Spliterator's tryAdvance method should send just one value to action, and then return. That makes the implementation much more awkward than our recursive implementation, because we must explicitly remember where we are in the stream so we can continue the next time tryAdvance is called. The most straightforward way to do this is to use an explicit stack data structure. (As Mustafa mentioned before, Java spells stack "Deque" for historical reasons.)
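For readers who haven't used Deque this way: push, peek, and pop all operate on the head of the deque, which gives last-in-first-out stack behavior. One detail that matters for the code that follows is that ArrayDeque does not permit null elements. A minimal sketch (the class and helper names are illustrative):

```java
import java.util.ArrayDeque;
import java.util.Deque;

public class DequeAsStack {
    // Pushes the values in order; the last one pushed ends up on top.
    static Deque<Long> pushAll(long... values) {
        Deque<Long> stack = new ArrayDeque<>();
        for (long v : values) {
            stack.push(v);   // push is equivalent to addFirst
        }
        return stack;
    }

    public static void main(String[] args) {
        Deque<Long> stack = pushAll(0L, 1L, 11L);
        System.out.println(stack.peek()); // 11  (peek reads the head without removing it)
        System.out.println(stack.pop());  // 11  (pop is equivalent to removeFirst)
        System.out.println(stack.peek()); // 1
        System.out.println(stack.size()); // 2

        try {
            stack.push(null);
        } catch (NullPointerException e) {
            System.out.println("ArrayDeque rejects null elements");
        }
    }
}
```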

The recursive implementation has three pieces of state at each level of recursion:

  • numDimensions
  • partialResult (digitsSoFar in the code above)
  • onesDigit (the loop variable ones)

If we're to turn this into an iterative algorithm, we can represent this state either as a stack of objects with these three fields, or as three stacks. I've opted to use the three-stacks approach because it's more memory-efficient, but both work.

To make this more efficient, there's a trick we can pull: only one of these state variables actually needs a real stack; the others can be arrays that we index using the stack's depth. (This amounts to implementing our own stacks using arrays.) Since those variables are primitives, storing them in primitive arrays also avoids the auto-boxing overhead.

As it turns out, we also don't need to represent numDimensions explicitly. In this algorithm, numDimensions just serves as a depth counter that indicates when we have hit the base case. If we're using explicit stack data structures, we can tell that by inspecting the stack's depth.

Hence, we're left needing just one stack and one array. Here's the code:

static class CartesianDigits extends AbstractSpliterator<Long> {  
    final ArrayDeque<Long> partialResults = new ArrayDeque<>(NUM_DIMENSIONS);
    final int[] onesDigits = new int[NUM_DIMENSIONS];

    protected CartesianDigits() {
        super(Long.MAX_VALUE, 0);
        partialResults.push(0L);
        Arrays.fill(onesDigits, 1);
    }

    public static Stream<Long> stream(){ return StreamSupport.stream(new CartesianDigits(), false); }

    @Override
    public boolean tryAdvance(Consumer<? super Long> action) {
        if (partialResults.isEmpty()) {
            return false;
        } else {
            // "Recurse" to the bottom from wherever we happen to be
            //
            while (partialResults.size() <= NUM_DIMENSIONS) {
                int dim = partialResults.size()-1;

                // Compute the partial result the same way we did in the
                // recursive implementation
                //
                partialResults.push(partialResults.peek() * 10 + onesDigits[dim]);
            }

            // Top of the stack now has the full result.
            //
            action.accept(partialResults.pop());

            // Advance the counters, starting with the least significant, popping
            // obsolete partialResults for any exhausted dimensions we encounter.
            // This corresponds to the increment and exit test from the
            // recursive version's for loop
            //
            for (int dim = NUM_DIMENSIONS-1; dim >= 0; dim--) {
                if (++onesDigits[dim] <= TOP_DIGIT) {
                    break;
                } else {
                    // This is where the recursive implementation returns.
                    // We also re-initialize the "frame" we just left so it's
                    // ready to use next time we need it
                    //
                    partialResults.pop();
                    onesDigits[dim] = 1;
                }
            }

            return true;
        }
    }
}

If you call this code as follows, you will be delighted to find that it only does the minimum work necessary to generate the very first value:

Long first = CartesianDigits.stream().findFirst().get();

So there you have it: an efficient Cartesian product stream.

A general Cartesian product spliterator

Now it's time to make a reusable library so we don't have to go through all this again. Below is a generalized version of the Cartesian product generator.

You use it as follows:

Stream<Long> permutations = Cartesian.productOf(  
    dimensions,                  // array of dimensions
    0L,                          // initial value
    (tens,ones)->10*tens+ones);  // combiner function

The Cartesian class builds each output value by starting from the supplied initial value and using the combiner function to extend it by one dimension at a time, until every dimension has contributed a member.
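Put differently, each output value is a left fold of the combiner over one tuple of the product, starting from the initial value. The standalone fold helper below is hypothetical (it is not part of Cartesian), but it computes the same thing for a single tuple. The spliterator's advantage is that it keeps each prefix's partial result on its partialValues stack, so neighbouring tuples share that work instead of recomputing it:

```java
import java.util.function.BiFunction;

public class CombinerFoldDemo {
    // Hypothetical helper: applies the combiner once per dimension,
    // left to right, to a single tuple of the product.
    static <T, C> C fold(C initialValue, T[] tuple, BiFunction<C, T, C> combiner) {
        C acc = initialValue;
        for (T member : tuple) {
            acc = combiner.apply(acc, member);
        }
        return acc;
    }

    public static void main(String[] args) {
        Long[] digits = {1L, 2L, 3L};
        Long n = fold(0L, digits, (tens, ones) -> 10 * tens + ones);
        System.out.println(n); // 123

        String[] words = {"small", "red", "box"};
        String phrase = fold("", words, (acc, w) -> acc.isEmpty() ? w : acc + " " + w);
        System.out.println(phrase); // small red box
    }
}
```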

The dimensions parameter is an array of arrays of dimension members. For the digits example, we can compute that as follows:

Long[][] dimensions =  
    IntStream.range(0, NUM_DIMENSIONS).mapToObj(d->
    LongStream.rangeClosed(1, TOP_DIGIT)
        .mapToObj(v->v) // auto-boxing
        .toArray(s->new Long[s]))
    .toArray(s->new Long[s][]);
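If you prefer, the same Long[][] can be built a little more directly with LongStream.boxed() and array-constructor references. This is a sketch of an equivalent construction (buildDimensions is an illustrative name), not a change to how Cartesian works:

```java
import java.util.Arrays;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

public class DimensionsDemo {
    // numDimensions copies of the dimension {1, 2, ..., topDigit}, boxed.
    static Long[][] buildDimensions(int numDimensions, int topDigit) {
        return IntStream.range(0, numDimensions)
                .mapToObj(d -> LongStream.rangeClosed(1, topDigit)
                        .boxed()                  // LongStream -> Stream<Long>
                        .toArray(Long[]::new))
                .toArray(Long[][]::new);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.deepToString(buildDimensions(3, 3)));
        // [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
    }
}
```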

Here is the implementation of Cartesian:

import java.util.ArrayDeque;
import java.util.Spliterators.AbstractSpliterator;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class Cartesian<T, C> extends AbstractSpliterator<C> {
    final T[][] dimensions;
    final int[] currentLocation;
    final BiFunction<C,T,C> combiner;
    final ArrayDeque<C> partialValues; // ArrayDeque rejects nulls, so initialValue
                                       // and combiner results must be non-null

    protected Cartesian(T[][] dimensions, C initialValue, BiFunction<C,T,C> combiner) {
        super(
            // Size is the product of the dimension sizes; zero dimensions
            // yield exactly one value (the initial value itself)
            Stream.of(dimensions).mapToLong(d->d.length).reduce(1, Math::multiplyExact),
            // Not DISTINCT: different tuples may combine to equal values
            ORDERED | SIZED | IMMUTABLE);
        this.dimensions = dimensions;
        this.currentLocation = new int[dimensions.length];
        this.partialValues = new ArrayDeque<>(dimensions.length + 1);
        this.combiner = combiner;
        // An empty dimension makes the whole product empty; leaving the
        // stack empty makes tryAdvance return false straight away
        if (Stream.of(dimensions).allMatch(d -> d.length > 0)) {
            this.partialValues.push(initialValue);
        }
    }

    @Override
    public boolean tryAdvance(Consumer<? super C> action) {
        if (partialValues.isEmpty()) {
            return false;
        } else {
            // Populate any dimensions that need it
            //
            while (partialValues.size() <= dimensions.length) {
                int dim = partialValues.size() - 1;
                partialValues.push(
                    combiner.apply(
                        partialValues.peek(),
                        dimensions[dim][currentLocation[dim]]));
            }

            // Top of the partialValues stack is actually the full value.
            //
            action.accept(partialValues.pop());

            // Advance the counters, starting with the least significant, popping
            // obsolete partialValues for any exhausted dimensions we encounter.
            //
            for (int dim = dimensions.length-1; dim >= 0; dim--) {
                if (++currentLocation[dim] < dimensions[dim].length) {
                    break;
                } else {
                    partialValues.pop();
                    currentLocation[dim] = 0;
                }
            }

            return true;
        }
    }

    public static <T,C> Stream<C> productOf(T[][] dimensions, C initialValue, BiFunction<C,T,C> combiner) {
        return StreamSupport.stream(new Cartesian<>(dimensions, initialValue, combiner), false);
    }

}


Vena is hiring in Toronto!
Learn about our culture and, if you think you're a good fit, apply!