Good dotProduct

bearophile bearophileHUGS at lycos.com
Sun Jun 27 09:35:12 PDT 2010


> I am now able to write some working SSE* code, but I am not an expert yet. So if someone is willing to write it (or even offer it, if already written), I think Andrei will be willing to add it to the std.numeric module.

Here is my first try (writing x86+SSE asm is still a pain for me):


import std.c.stdio: printf;
import std.date: getUTCtime, ticksPerSecond;
import std.numeric: dotProduct;
import std.contracts: enforce;

/// Current UTC time expressed in seconds: std.date ticks divided by
/// ticksPerSecond, computed in floating point (used for coarse timing).
double clock() {
    double ticks = cast(double) getUTCtime();
    return ticks / ticksPerSecond;
}

/// Plain scalar dot product: accumulates arr1[k] * arr2[k] for k from 0
/// upward into a double starting at 0.0. The two lengths must match
/// (checked by assert only).
double dot1(double[] arr1, double[] arr2) {
    assert(arr1.length == arr2.length);
    double acc = 0.0;
    for (size_t k = 0; k < arr1.length; ++k)
        acc += arr1[k] * arr2[k];
    return acc;
}


/**
 * Dot product of two equal-length double arrays. The bulk of the work is
 * done by an SSE2 inline-asm loop that multiplies and accumulates two
 * doubles per iteration. It uses 32-bit registers for the pointers, so it
 * is only valid for 32-bit x86 builds.
 *
 * Params:
 *     a = first operand.
 *     b = second operand; must have the same length as a (assert only).
 * Returns: the sum of a[i] * b[i]; 0.0 for empty arrays.
 *
 * movapd and mulpd with a memory operand require 16-byte alignment, and an
 * unaligned address raises a fault. Alignment asserts alone disappear under
 * -release, so unaligned inputs are routed to a plain scalar loop instead.
 */
double dot2(double[] a, double[] b) {
    size_t len = a.length;
    assert(len == b.length, "dot(): the two array lengths differ.");
    if (len == 0)
        return 0.0;
    if (len == 1)
        return a[0] * b[0];

    double* a_ptr = a.ptr;
    double* b_ptr = b.ptr;

    // Either pointer not on a 16-byte boundary: the SSE path would fault,
    // so compute the result with a scalar loop.
    if (((cast(size_t)a_ptr | cast(size_t)b_ptr) & cast(size_t)0b1111) != 0) {
        double sum = 0.0;
        foreach (i, x; a)
            sum += x * b[i];
        return sum;
    }

    double tot = void;

    // Byte length of the even-length prefix processed two doubles at a time.
    // len >= 2 here, so this is a non-zero multiple of 16. That matters
    // because the asm loop tests its exit condition only after the body
    // (do-while shape) and exits on exact equality.
    len = (len / 2) * 2 * double.sizeof;

    asm {
        mov EAX, a_ptr;   // EAX = a_ptr
        mov ECX, b_ptr;   // ECX = b_ptr
        mov EDI, len;     // EDI = len
        xorpd XMM0, XMM0; // XMM0[0,1] = 0.0, 0.0  (tot0,tot1)
        xor EDX, EDX;     // EDX = 0 (loop counter)

        align 8;
    LOOP_START:
        movapd XMM1, [EAX + EDX]; // XMM1[0,1] = *(EAX + EDX)[0,1]
        mulpd XMM1, [ECX + EDX];  // XMM1[0,1] *= *(ECX + EDX)[0,1]
        add EDX, 16;              // EDX += 16
        addpd XMM0, XMM1;         // XMM0[0,1] += XMM1[0,1]
        cmp EDI, EDX;
        jne LOOP_START;

        // XMM0[0] = XMM0[0] + XMM0[1]
        movapd  XMM1, XMM0;  // XMM1[0,1] = XMM0[0,1]
        unpckhpd XMM1, XMM1; // XMM1[0] = XMM1[1]
        addsd XMM0, XMM1;    // XMM0[0] += XMM1[0]

        movsd tot, XMM0;     // tot = XMM0[0]
    }

    // Odd length: the asm loop skipped the last element; add it here.
    if (a.length & 1)
        return tot + (a[$-1] * b[$-1]);
    else
        return tot;
}


/**
 * Benchmark driver. For std.numeric.dotProduct, dot1 and dot2 in turn, it
 * computes the dot product of every prefix (lengths 0 through 40) of two
 * fixed 40-element arrays, NLOOPS times. It prints the elapsed seconds and
 * the accumulated total for each, and uses enforce to check that the dot1
 * and dot2 totals lie within 0.001 of the dotProduct total.
 */
void main() {
    enum int NLOOPS = 100_000;

    // True when |x - y| < 0.001. The test must be two-sided: the one-sided
    // form (x - y) < 0.001 also accepts every y that is larger than x.
    static bool nearlyEqual(double x, double y) {
        double diff = x - y;
        return diff > -0.001 && diff < 0.001;
    }

    double[] a1 = [0.97939860007980784, 0.15654832543818165, 0.20876449456836543, 0.13588707622872687, 0.56737250028542408, 0.60890755422949261, 0.72503629431774808, 0.52283227344759831, 0.82107581846425648, 0.9280027000111094, 0.65371212119615851, 0.99162440025345067, 0.93134413423287143, 0.95319320272812291, 0.82373984947977308, 0.09382106227964937, 0.9914424038875832, 0.80601047736119313, 0.023619286739061329, 0.82167558946575081];
    a1 ~= a1; // 20 values repeated -> 40 elements

    double[] a2 = [0.37967009318734157, 0.49961633977084308, 0.063379452228570665, 0.016573170529015635, 0.32720445135092779, 0.90558380684677242, 0.59689617644678783, 0.22590204202286546, 0.13701998912150426, 0.21786382155943662, 0.74110776773633547, 0.62437487889391807, 0.41013869338412479, 0.047723768990690196, 0.98567658092497179, 0.19281802583215202, 0.7937119206931792, 0.34128113271035532, 0.90960739148505643, 0.01852954991914102];
    a2 ~= a2;
    assert(a1.length == a2.length);

    double t0, t1;

    t0 = clock();
    double total0 = 0.0;
    foreach (i; 0 .. NLOOPS)
        foreach (len; 0 .. a1.length+1)
            total0 += dotProduct(a1[0 .. len], a2[0 .. len]);
    t1 = clock();
    printf("dotProduct: %.3f %f\n", t1-t0, total0);

    t0 = clock();
    double total1 = 0.0;
    foreach (i; 0 .. NLOOPS)
        foreach (len;  0 .. a1.length+1)
            total1 += dot1(a1[0 .. len], a2[0 .. len]);
    t1 = clock();
    printf("dot1:       %.3f %f\n", t1-t0, total1);
    enforce(nearlyEqual(total0, total1));

    t0 = clock();
    double total2 = 0.0;
    foreach (i; 0 .. NLOOPS)
        foreach (len; 0 .. a1.length+1)
            total2 += dot2(a1[0 .. len], a2[0 .. len]);
    t1 = clock();
    printf("dot2:       %.3f %f\n", t1-t0, total2);
    enforce(nearlyEqual(total0, total2));
}


Suggestions welcome.
This dot2() seems to work, but it can be improved in several ways:
- it needs to work when a and/or b are not aligned to 16 bytes
- the loop can be unrolled once or more
- maybe the non-asm code can be improved a bit
- an SSE-detection branch can be added, if it isn't too slow

dot2() is slower than std.numeric.dotProduct for small arrays, so a test can be added inside it to detect short arrays and use plain D code for them. Unrolling its loop once can make it a little faster still for longish arrays.

Is this not valid?
movapd XMM1, [EAX + EDX * 16];

Bye,
bearophile


More information about the Digitalmars-d mailing list