This is a guest post by Fabian Giesen. Actually it’s an off-the-cuff email, so please tolerate that it may be a little sketchy. In the next post about working with Braid’s particle systems, I want to visit the idea of fast-forwarding LCGs. Fabian really ran with the idea here, and since I am going to summarize the idea in the next post, it felt best to have the full piece here for reference. -J. Blow, 8 July 2016

Just realized an interesting (maybe?) twist on this. So the basic idea for the fast-forward is just that a LCG-style update

x_{k+1} = a*x_k + b

has

x_{k+2} = a*x_{k+1} + b = a*(a*x_k + b) + b
        = (a^2)*x_k + (a*b+b)

so if you describe a LCG by its coeffs (a,b), then stepping it at twice the rate yields another LCG with coeffs (a^2, (a+1)*b). (This is the direct approach, but of course you can also just do it by writing the LCG update in homogeneous matrix form and multiplying it out as Jon suggests.) You’re back in the same form you started with and now have coeffs for a LCG that proceeds twice as fast as the source LCG.
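
To make that concrete, here’s a tiny sanity check (the constants below are just placeholder LCG parameters, nothing specific):

// two steps of (a, b) should match one step of (a^2, (a+1)*b)
// (placeholder constants; uint arithmetic wraps mod 2^32)
uint a = 1664525u, b = 1013904223u, x = 12345u;
uint two_steps = a*(a*x + b) + b;       // x_{k+2} the slow way
uint one_step  = (a*a)*x + (a + 1)*b;   // one step of the doubled-rate LCG
// two_steps == one_step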

So the basic fast-forward is just the fast-exponentiation style:

// fast-forward: advance x0 by n steps of x -> a*x + b
uint cur_a = a;
uint cur_b = b;
uint x = x0;

while (n) {
    if (n & 1)
        x = cur_a*x + cur_b;
    // square the step: (cur_a, cur_b) goes from advancing 2^i steps to 2^(i+1) steps
    cur_b *= cur_a + 1;
    cur_a *= cur_a;
    n >>= 1;
}

Okay, that’s background.

But we don’t generally have “a” and “b” different for every invocation; usually we pick one set of LCG coeffs and stick with them as our RNG, which means the sequence of “cur_a” and “cur_b” values is completely predetermined, so we might as well precompute it:

// init (or just bake into table ahead of time)
// good for 32-bit "n"
uint lcg_a_bit[32], lcg_b_bit[32];
uint cur_a = a, cur_b = b;

for (int i = 0; i < 32; ++i) {
    lcg_a_bit[i] = cur_a;
    lcg_b_bit[i] = cur_b;
    cur_b *= cur_a + 1;
    cur_a *= cur_a;
}

// runtime eval
uint x = x0;
for (int bit = 0; n != 0; ++bit, n >>= 1)
    if (n & 1)
        x = lcg_a_bit[bit]*x + lcg_b_bit[bit];

So here’s the actual realization: if you’re willing to live with a bit of precomputation, there’s really no need to limit yourself to binary steps. You might as well do multiple bits at a time, something like:

// runtime eval only this time
uint x = x0;
for (int i = 0; n != 0; ++i, n >>= 4)
    x = lcg_a_tab[i][n & 15]*x + lcg_b_tab[i][n & 15];

where lcg_?_tab[i][j] contains the coeffs to step the LCG (j * 16^i) times; this particular choice of parameters needs (32/4) * (2^4) = 8*16 = 128 table entries (so 256 uint32s for a 32-bit LCG).

This is still O(log N), but it’s a log to a larger base, so the constant factor is smaller. You can get relatively far with a moderate table size, which is fine as long as all those calcs happen in a relatively tight loop so the table stays warm in L1. For example, 8 bits at a time needs (32/8) * 2^8 = 1024 table entries working out to 8KB for a 32-bit LCG, which then allows you to step the LCG to an arbitrary position in exactly 4 (serial) multiply-adds.
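
For reference, here’s a sketch of how those tables could be filled in (not spelled out above; names match the runtime snippet, and the sizes assume a 32-bit n split into 8 digits of 4 bits):

// build lcg_a_tab/lcg_b_tab: entry [i][j] steps the LCG j * 16^i times
uint lcg_a_tab[8][16], lcg_b_tab[8][16];
uint step_a = a, step_b = b; // coeffs that step the LCG 16^i times

for (int i = 0; i < 8; ++i) {
    lcg_a_tab[i][0] = 1; // j = 0: step zero times (identity)
    lcg_b_tab[i][0] = 0;
    for (int j = 1; j < 16; ++j) {
        // one more application of the 16^i-step LCG
        lcg_a_tab[i][j] = step_a * lcg_a_tab[i][j-1];
        lcg_b_tab[i][j] = step_a * lcg_b_tab[i][j-1] + step_b;
    }
    // advance (step_a, step_b) from 16^i steps to 16^(i+1) steps
    uint next_a = step_a * lcg_a_tab[i][15];
    uint next_b = step_a * lcg_b_tab[i][15] + step_b;
    step_a = next_a;
    step_b = next_b;
}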

-

That part I think is kinda cool and still reasonably simple. This is the part where it gets a bit rambly.

The computation above is very serial because we funnel everything directly into x. There’s more latent parallelism if we manipulate the generators directly. At this point it probably is easier to switch to matrix form for clarity: (ASCII art, view with fixed-width font)

     x_1 = a*x_0 + b

-> [x_1]   [a  b] [x_0]
   [ 1 ] = [0  1] [ 1 ]

The stuff we’ve been doing so far is raising the matrix (let’s call it A) to the n’th power. I’ll also call the RHS vector ([x_0; 1]) just “x” (hopefully not too confusing).

So what I’ve described so far is basically splitting

n = ... + n_2 + n_1 + n_0

and then computing

  A^n * x
= A^(... + n_2 + n_1 + n_0) * x
= ... * (A^(n_2) * (A^(n_1) * (A^(n_0) * x))).

In the binary split, the n_i are just (i’th bit of n)*2^i.

Computing it this way (immediately applying the matrices to the vector) is cheaper in terms of op count, but you can break the dependency chain up somewhat if you use associativity and set the parens differently. Say we have:

n = n_3 + n_2 + n_1 + n_0

then we can compute it as

((A^(n_3) * A^(n_2)) * (A^(n_1) * A^(n_0))) * x

The n_3-by-n_2 and the n_1-by-n_0 matrix mults can be done in parallel in this version, which is why this can help.

If we write down “LCG arithmetic” for convenience:

struct LCG {
    uint mul, add;
};

// this is just the matrix multiply
LCG mul_lcg(LCG a, LCG b)
{
    LCG x = { a.mul*b.mul, a.mul*b.add + a.add };
    return x;
}

Then we can have an alternative version of the fast exponentiation loop that is left-associative rather than right-associative (i.e. it multiplies out all the matrices and then multiplies the result by the vector at the end):

LCG combined = { 1, 0 }; // identity
LCG cur = base_lcg;

// accumulate LCG matrix
while (n) {
    if (n & 1)
        combined = mul_lcg(cur, combined);
    cur = mul_lcg(cur, cur);
    n >>= 1;
}

// final eval
x = combined.mul*x0 + combined.add;

So far, not very appealing; as-is, this does more work. But it allows splitting into several dependency chains and/or vectorization within a single LCG evaluation; e.g. you can do stuff like this:

LCG combined_lo = { 1, 0 };
LCG combined_hi = { 1, 0 };
LCG cur_lo = base_lcg;
LCG cur_hi = base_lcg_pow16; // precomputed

// accumulate "low" and "high" halves in parallel
// split at bit 16 of n (no early out this time)
for (int i = 0; i < 16; ++i) {
    if (n & (1<< 0)) combined_lo = mul_lcg(cur_lo, combined_lo);
    if (n & (1<<16)) combined_hi = mul_lcg(cur_hi, combined_hi);
    cur_lo = mul_lcg(cur_lo, cur_lo);
    cur_hi = mul_lcg(cur_hi, cur_hi);
    n >>= 1;
}

// final eval
x = combined_hi.mul*(combined_lo.mul * x0 + combined_lo.add) + combined_hi.add;

Here this is written in scalar form, but the (lo,hi) pair calcs could be SIMDed, and you can do more than 2 of these at a time. If you’re going wider, you probably want to do the “final eval” part using a reduction tree of matrix mults.
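
For example, with four lanes the final combine could be a small tree like this (just a sketch; combined_0 holds the lowest bits and combined_3 the highest, reusing mul_lcg from above):

// combine 4 per-lane accumulators; the first two mul_lcg calls
// are independent and can overlap
LCG c01 = mul_lcg(combined_1, combined_0);
LCG c23 = mul_lcg(combined_3, combined_2);
LCG total = mul_lcg(c23, c01);

// final eval
x = total.mul*x0 + total.add;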

…that said, all of this is really only interesting when viewed in isolation, i.e. when you have individual “gimme LCG state at point n” calls.

If they are in a loop, you don’t need any heroics to extract parallelism; multiple iterations of the loop can overlap all by themselves. (And e.g. SIMDing over iters is gonna work a lot better than making that stuff I just wrote work.)

But I think it’s definitely cute.