[OT] Generating distribution of N dice rolls
H. S. Teoh
hsteoh at qfbox.info
Fri Nov 11 01:27:07 UTC 2022
On Thu, Nov 10, 2022 at 03:15:24PM -0800, H. S. Teoh via Digitalmars-d wrote:
[...]
> Now, just for fun, I added a writeln before the `while (total != N)`
> to print out just how big of a discrepancy to N the array sum is.
> Turns out, quite a bit: for N = 100_000_000 as above, the
> discrepancies range from approximately -25000 to 25000, with the
> typical discrepancy being about 3-4 digits long, which means that the
> code is spending quite a bit of time in that final loop. Which
> suggests that a possible improvement is to recursively run the initial
> approximation step, setting sub_N = (N - array.sum).
>
> I'll leave that to a future iteration, though. Being able to compute
> a hundred million dice rolls in a split second is already good enough
> for what I need. :-D
[...]
OK, I couldn't resist; the idea was too tempting. So I made a new
implementation that iteratively uses the Box-Muller transform to close
the gap between array.sum and N, resorting to individual rolls only once
the absolute difference is at most k. For N = 100_000_000, it typically
takes about 3-5 iterations to bring the difference down to that range, so
the entire algorithm takes roughly k+5 iterations to compute the result.
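For reference, each pass treats the remaining |d| rolls (d = N - array.sum) as if they were spread over the k faces. Each face's count is then marginally Binomial(|d|, 1/k), and its normal approximation has the mean and standard deviation that diceDistrib computes as `mean` and `deviation`. Each element is then moved by the rounded draw, with x_i a standard normal from gaussianPoint:

```latex
\mu = \frac{|d|}{k}, \qquad
\sigma = \sqrt{|d| \cdot \frac{1}{k}\left(1 - \frac{1}{k}\right)}
       = \sqrt{\mu\left(1 - \frac{1}{k}\right)}

c_i \leftarrow c_i + \operatorname{sgn}(d)\,\operatorname{round}(z_i),
\qquad z_i = \mu + \operatorname{round}(\sigma x_i),
\quad x_i \sim \mathcal{N}(0, 1)
```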
Now, in theory, I could just run the Box-Muller estimate until the
difference is 0, but I found that once the difference is small, it tends
to bounce around 0 for several iterations before converging on 0. So I
arbitrarily decided to stop once the absolute difference is at most k,
and use individual rolls to do the rest.
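When the remaining gap is positive, the final adjustment loop amounts to rolling the leftover dice one at a time, which is just the naive algorithm. A standalone version of that naive algorithm samples the exact multinomial distribution, so it can serve as a baseline when testing the faster code on small N. This is a minimal sketch; naiveDiceDistrib is a hypothetical name and is not part of the code posted here.

```d
import std.random : Random, uniform, unpredictableSeed;

/**
 * Hypothetical reference implementation: rolls each of the n k-sided
 * dice individually and tallies the faces. O(n) time, so only practical
 * for small n, but it samples the exact multinomial distribution.
 * Element i of the result counts the rolls that came up i+1.
 */
int[] naiveDiceDistrib(RNG)(int k, int n, ref RNG rng)
in (k > 0)
in (n >= 0)
{
    auto counts = new int[k];
    foreach (_; 0 .. n)
        counts[uniform(0, k, rng)]++;   // face index in [0, k)
    return counts;
}

void main()
{
    import std.stdio : writeln;

    auto rng = Random(unpredictableSeed);
    writeln(naiveDiceDistrib(6, 10, rng));
}
```

Passing the generator explicitly lets a test seed it and get reproducible counts.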
One wrinkle that came up is that for small values of N, the result array
can sometimes end up with negative elements. This happens either during
the iterative Box-Muller stage, when the signed adjustment sign*z is
negative and larger in magnitude than the current array element (e.g.
when overcompensating after discrepancy has gone below 0), or during the
final adjustment loop, when it picks an array element that's already 0
and tries to decrement it. So I had to insert extra checks to discard
generated z values that would make a count negative, and to skip zero
elements when decrementing. Generally, this happens only for small
values of N; for large N the initial estimate is large enough that a
later adjustment is extremely unlikely to overshoot into negative
values. As a safety net, I added another out-contract that checks that
no count is negative.
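A rough way to see why negative counts are a small-N problem: on the first pass d = N, so with the mean and standard deviation used by the code, each draw's mean sits this many standard deviations above zero:

```latex
\frac{\mu}{\sigma}
  = \frac{N/k}{\sqrt{\frac{N}{k}\left(1 - \frac{1}{k}\right)}}
  = \sqrt{\frac{N}{k-1}}
```

For k = 6 this is sqrt(2), about 1.41, when N = 10, so draws that would push a count below zero are fairly common. It is sqrt(2 * 10^7), about 4472, when N = 100_000_000, where such a draw is practically impossible.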
Code:
------------
import std.algorithm;
import std.conv : text;

double[2] gaussianPoint(double[2] mean, double deviation)
{
    import std.math : cos, log, sin, sqrt, PI, round;
    import std.random : uniform01;

    // uniform01 returns values in [0, 1); use 1 - u so that u is in (0, 1]
    // and log(u) can never be log(0) = -infinity.
    auto u = 1.0 - uniform01();
    auto v = uniform01();
    // Box-Muller transform: two independent standard normal variates.
    auto x0 = sqrt(-2.0*log(u)) * cos(2.0*PI*v);
    auto y0 = sqrt(-2.0*log(u)) * sin(2.0*PI*v);
    double[2] result;
    result[0] = mean[0] + cast(int)round(deviation * x0);
    result[1] = mean[1] + cast(int)round(deviation * y0);
    return result;
}
/**
 * Simulates rolling N k-sided dice.
 *
 * Returns: An array of length k whose element i is the number of rolls
 * that came up i+1.
 *
 * Bugs: While the remaining discrepancy is large, the counts are drawn from
 * a rounded Gaussian (Box-Muller) approximation, so the result is only
 * approximately multinomially distributed.
 */
int[] diceDistrib(int k, int N)
in (k > 0)
in (N > 0)
out (r; r.sum == N)
out (r; r.all!(c => c >= 0), r.text)
{
    import std.math : abs, sgn, sqrt, round;
    import std.range : chunks;
    import std.random : uniform;
    debug import std.stdio;

    // Populate output array with initial values with the correct mean and
    // standard deviation, repeating until the remaining discrepancy is at
    // most k in absolute value.
    auto result = new int[k];
    auto discrepancy = N - result.sum;
    while (abs(discrepancy) > k)
    {
        debug writefln("discrepancy=%d", discrepancy);
        const mean = cast(double)abs(discrepancy) / k;
        const deviation = sqrt(mean * (1.0 - 1.0 / k));
        double[2] means = [ mean, mean ];
        double sign = sgn(discrepancy);
        foreach (chunk; result.chunks(2))
        {
            // do { ... } while (false) runs exactly once; `continue` jumps to
            // the (false) condition, so a rejected draw leaves this chunk
            // unchanged for this pass.
            do
            {
                auto z = gaussianPoint(means, deviation);
                auto c0 = chunk[0] + cast(int) round(sign*z[0]);
                if (c0 < 0)
                    continue; // discard nonsensical result
                if (chunk.length > 1)
                {
                    auto c1 = chunk[1] + cast(int) round(sign*z[1]);
                    if (c1 < 0)
                        continue; // discard nonsensical result
                    chunk[1] = c1;
                }
                chunk[0] = c0;
            } while (false);
        }
        discrepancy = N - result.sum;
    }
    debug writefln("final discrepancy=%d", discrepancy);

    // Tweak resulting array until it sums exactly to N, adding or removing
    // one roll at a time on a randomly chosen face and never going below 0.
    auto total = result.sum;
    while (total != N)
    {
        auto i = uniform(0, k);
        if (total < N)
        {
            result[i]++;
            total++;
        }
        else if (result[i] > 0)
        {
            result[i]--;
            total--;
        }
    }
    return result;
}
unittest
{
    import std.stdio;
    foreach (i; 0 .. 5)
    {
        auto N = 100_000_000;
        auto dist = diceDistrib(6, N);
        debug writefln("%s sum=%d", dist, dist[].sum);
        assert(dist.sum == N);
    }
    foreach (i; 0 .. 5)
    {
        auto N = 10;
        auto dist = diceDistrib(6, N);
        debug writefln("%s sum=%d", dist, dist[].sum);
        assert(dist.sum == N);
    }
}
------------
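To sanity-check the Box-Muller step in isolation, one can draw many pairs and compare the sample mean and standard deviation against 0 and 1. This is a minimal sketch using hypothetical helpers boxMuller and sampleStats. Unlike gaussianPoint it does not round, and it draws u as 1.0 - uniform01() so that log(u) is always finite.

```d
import std.math : cos, log, sin, sqrt, PI;
import std.random : Random, uniform01;

/// Hypothetical helper: Box-Muller transform turning two uniforms into
/// two independent standard normal variates (no rounding).
double[2] boxMuller(RNG)(ref RNG rng)
{
    const u = 1.0 - uniform01(rng);   // in (0, 1], so log(u) is finite
    const v = uniform01(rng);         // in [0, 1)
    const r = sqrt(-2.0 * log(u));
    return [r * cos(2.0 * PI * v), r * sin(2.0 * PI * v)];
}

/// Hypothetical helper: sample mean and standard deviation of about n
/// Box-Muller values (n rounded down to an even number).
double[2] sampleStats(RNG)(size_t n, ref RNG rng)
{
    double s = 0, s2 = 0;
    foreach (_; 0 .. n / 2)
    {
        const pair = boxMuller(rng);
        foreach (x; pair)
        {
            s += x;
            s2 += x * x;
        }
    }
    const count = n / 2 * 2;          // number of values actually drawn
    const mean = s / count;
    return [mean, sqrt(s2 / count - mean * mean)];
}

void main()
{
    import std.stdio : writefln;

    auto rng = Random(12345);
    const st = sampleStats(1_000_000, rng);
    writefln("mean=%.4f stddev=%.4f", st[0], st[1]);
}
```

With a fixed seed the check is reproducible; the printed mean should be close to 0 and the standard deviation close to 1.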
Typical output (the large-N runs below have sum=1000000000, i.e. they were made with N = 1_000_000_000):
------------
1 modules passed unittests
discrepancy=1000000000
discrepancy=-94251
discrepancy=287
final discrepancy=5
[166665591, 166651866, 166681043, 166654104, 166673559, 166673837] sum=1000000000
discrepancy=1000000000
discrepancy=20314
discrepancy=-88
discrepancy=-9
final discrepancy=6
[166668490, 166665588, 166648866, 166658164, 166684475, 166674417] sum=1000000000
discrepancy=1000000000
discrepancy=14062
discrepancy=18
final discrepancy=5
[166658497, 166674399, 166652111, 166667954, 166672801, 166674238] sum=1000000000
discrepancy=1000000000
discrepancy=15926
discrepancy=-175
discrepancy=-19
final discrepancy=-4
[166664463, 166659610, 166655243, 166678783, 166673507, 166668394] sum=1000000000
discrepancy=1000000000
discrepancy=39916
discrepancy=201
discrepancy=-16
final discrepancy=2
[166666774, 166663400, 166672276, 166686269, 166652312, 166658969] sum=1000000000
discrepancy=10
final discrepancy=-1
[2, 2, 2, 3, 0, 1] sum=10
discrepancy=10
final discrepancy=-5
[2, 1, 3, 2, 1, 1] sum=10
discrepancy=10
final discrepancy=-1
[1, 0, 1, 1, 4, 3] sum=10
discrepancy=10
final discrepancy=-3
[3, 2, 2, 1, 0, 2] sum=10
discrepancy=10
final discrepancy=-2
[1, 1, 2, 2, 2, 2] sum=10
------------
Note that the last 5 outputs are for a different test case (N=10), just
to make sure that we don't produce nonsensical outputs when N is small.
T
--
Leather is waterproof. Ever see a cow with an umbrella?