Good dotProduct
bearophile
bearophileHUGS at lycos.com
Sun Jun 27 09:35:12 PDT 2010
> I am now able to write some working SSE* code, but I am not an expert yet. So if someone is willing to write it (or even offer it, if already written), I think Andrei will be willing to add it to the std.numeric module.
Here is my first try (writing x86+SSE asm is still a pain for me):
import std.c.stdio: printf;
import std.contracts: enforce;
import std.date: getUTCtime, ticksPerSecond;
import std.math: fabs;
import std.numeric: dotProduct;
/// Current UTC time expressed in seconds (as a double), for interval timing.
double clock() {
    auto ticks = getUTCtime();
    return ticks / cast(double)ticksPerSecond;
}
/// Plain sequential dot product: sum of arr1[k] * arr2[k] for k in 0 .. arr1.length.
/// Both arrays must have the same length (checked with assert).
double dot1(double[] arr1, double[] arr2) {
    assert(arr1.length == arr2.length);
    double sum = 0.0;
    foreach (k; 0 .. arr1.length)
        sum += arr1[k] * arr2[k];
    return sum;
}
/// Dot product of two equal-length double arrays, computed with SSE2 packed
/// instructions in x86 inline asm.
///
/// Params:
///     a = first operand; a.ptr must be 16-byte aligned (checked by assert)
///     b = second operand; must have a.length elements, b.ptr 16-byte aligned
/// Returns: the sum of a[i] * b[i]; 0.0 when the arrays are empty.
///
/// The vector loop keeps two running sums, one per lane of XMM0 (lane 0
/// collects even indices, lane 1 odd indices). The lanes are added together at
/// the end, so the rounding can differ slightly from a sequential loop.
///
/// NOTE(review): the length and alignment checks are asserts only. If asserts
/// are compiled out, mismatched lengths make the loop read past the end of b.
/// Misaligned input makes movapd/mulpd (legacy SSE memory operands) fault.
/// Pointers and the byte count are moved into 32-bit registers (EAX, ECX,
/// EDI), so this code targets 32-bit x86 only.
double dot2(double[] a, double[] b) {
size_t len = a.length;
assert(len == b.length, "dot(): the two array lengths differ.");
// Short inputs are handled without touching the asm path.
if (len == 0)
return 0.0;
if (len == 1)
return a[0] * b[0];
// Written by the final movsd inside the asm block.
double tot = void;
double* a_ptr = a.ptr;
double* b_ptr = b.ptr;
// movapd and the mulpd memory operand require 16-byte aligned addresses.
assert((cast(size_t)a_ptr & cast(size_t)0b1111) == 0,
"dot(): the first array is not aligned to 16 bytes.");
assert((cast(size_t)b_ptr & cast(size_t)0b1111) == 0,
"dot(): the second array is not aligned to 16 bytes.");
// Reuse len as the byte size of the largest even-length prefix. The asm loop
// consumes 2 doubles (16 bytes) per iteration. Since len >= 2 here, this is a
// nonzero multiple of 16, which the jne exit test below relies on.
len = (len / 2) * 2 * double.sizeof;
asm {
mov EAX, a_ptr; // EAX = a_ptr
mov ECX, b_ptr; // ECX = b_ptr
mov EDI, len; // EDI = number of bytes to process
xorpd XMM0, XMM0; // XMM0[0,1] = 0.0, 0.0 (tot0,tot1)
xor EDX, EDX; // EDX = 0 (byte offset into both arrays)
align 8; // pad so the loop entry starts on an 8-byte boundary
LOOP_START:
movapd XMM1, [EAX + EDX]; // XMM1[0,1] = *(EAX + EDX)[0,1]
mulpd XMM1, [ECX + EDX]; // XMM1[0,1] *= *(ECX + EDX)[0,1]
add EDX, 16; // EDX += 16 (two doubles)
addpd XMM0, XMM1; // XMM0[0,1] += XMM1[0,1]
cmp EDI, EDX;
jne LOOP_START; // loop until the offset reaches the byte count
// XMM0[0] = XMM0[0] + XMM0[1]
movapd XMM1, XMM0; // XMM1[0,1] = XMM0[0,1]
unpckhpd XMM1, XMM1; // XMM1[0] = XMM1[1]
addsd XMM0, XMM1; // XMM0[0] += XMM1[0]
movsd tot, XMM0; // tot = XMM0[0]
}
// For odd lengths the paired loop skipped the last element; add its product.
if (a.length & 1)
return tot + (a[$-1] * b[$-1]);
else
return tot;
}
/// Micro-benchmark comparing std.numeric.dotProduct, dot1 and dot2.
///
/// Each implementation is called NLOOPS times on every prefix pair
/// a1[0 .. len], a2[0 .. len] with len = 0 .. a1.length. For each one, the
/// elapsed seconds and the accumulated total are printed. The dot1 and dot2
/// totals are then checked against the dotProduct total.
void main() {
    enum int NLOOPS = 100_000;
    // Maximum accepted absolute difference between accumulated totals.
    enum double TOLERANCE = 0.001;

    double[] a1 = [0.97939860007980784, 0.15654832543818165, 0.20876449456836543, 0.13588707622872687, 0.56737250028542408, 0.60890755422949261, 0.72503629431774808, 0.52283227344759831, 0.82107581846425648, 0.9280027000111094, 0.65371212119615851, 0.99162440025345067, 0.93134413423287143, 0.95319320272812291, 0.82373984947977308, 0.09382106227964937, 0.9914424038875832, 0.80601047736119313, 0.023619286739061329, 0.82167558946575081];
    a1 ~= a1; // duplicate the 20 values -> 40 elements
    double[] a2 = [0.37967009318734157, 0.49961633977084308, 0.063379452228570665, 0.016573170529015635, 0.32720445135092779, 0.90558380684677242, 0.59689617644678783, 0.22590204202286546, 0.13701998912150426, 0.21786382155943662, 0.74110776773633547, 0.62437487889391807, 0.41013869338412479, 0.047723768990690196, 0.98567658092497179, 0.19281802583215202, 0.7937119206931792, 0.34128113271035532, 0.90960739148505643, 0.01852954991914102];
    a2 ~= a2;
    assert(a1.length == a2.length);

    double t0, t1;

    t0 = clock();
    double total0 = 0.0;
    foreach (i; 0 .. NLOOPS)
        foreach (len; 0 .. a1.length + 1)
            total0 += dotProduct(a1[0 .. len], a2[0 .. len]);
    t1 = clock();
    printf("dotProduct: %.3f %f\n", t1-t0, total0);

    t0 = clock();
    double total1 = 0.0;
    foreach (i; 0 .. NLOOPS)
        foreach (len; 0 .. a1.length + 1)
            total1 += dot1(a1[0 .. len], a2[0 .. len]);
    t1 = clock();
    printf("dot1: %.3f %f\n", t1-t0, total1);
    // Compare the magnitude of the difference. A one-sided test
    // (total0 - total1) < tol would accept any total1 that is too large.
    enforce(fabs(total0 - total1) < TOLERANCE, "dot1 result differs from dotProduct");

    // Every slice starts at a1.ptr / a2.ptr, so dot2's 16-byte alignment
    // asserts depend on how the arrays were allocated.
    t0 = clock();
    double total2 = 0.0;
    foreach (i; 0 .. NLOOPS)
        foreach (len; 0 .. a1.length + 1)
            total2 += dot2(a1[0 .. len], a2[0 .. len]);
    t1 = clock();
    printf("dot2: %.3f %f\n", t1-t0, total2);
    enforce(fabs(total0 - total2) < TOLERANCE, "dot2 result differs from dotProduct");
}
Suggestions welcome.
This dot2() seems to work, but it can be improved in several ways:
- it needs to work with a and/or b unaligned to 16 bytes
- the loop can be unrolled once or more
- maybe the non-asm code can be improved a bit
- an SSE-detection branch can be added, if it's not too slow
dot2() is slower than std.numeric.dotProduct for small arrays, so a test could be added inside it to detect short arrays and handle them with normal D code. Unrolling its loop once can make it a little faster still for longish arrays.
Is the following instruction not valid?
movapd XMM1, [EAX + EDX * 16];
Bye,
bearophile
More information about the Digitalmars-d
mailing list