Expression templates

Don Clugston dac at nospam.com.au
Wed Dec 13 21:35:37 PST 2006


Bill Baxter wrote:
> Don Clugston wrote:
>> Don Clugston wrote:
>>> The recent language improvements (especially, opAssign) have made 
>>> expression templates eminently feasible. Here's a proof-of-concept which
>>> implements parsing of the basic vector operations of addition and 
>>> scalar multiplication.
>>
>> I've made a slightly improved version...
> 
> I'm making some headway on an nd-array class that uses BLAS/LAPACK under 
> the hood.  It would be nifty if I could somehow get this:
> 
>      R = A * x + y
> 
> to be groked and translated to a single call to of BLAS's 
> saxpy/daxpy/caxpy/zaxpy routines.
> 
> With all the EAX talk, it looks like you're talking about a different 
> level of optimization.  The theory with BLAS anyway, is that you'll have 
> access to a hand-tuned version for the platform you're on, so it'll be 
> hard to beat with any kind of register-level compile-time 
> micromanagement.  For instance BLAS level 3 (matrix-matrix) routines do 
> blocking to make things easier on the cache for big inputs.  And of 
> course there are also parallel BLAS implementations around, too.

Quite true. But the code below spits out very-near optimal asm for x87,
with double[] vectors with +, -, and scalar multiplication.
For example, the code generated for something like

a = b + 3.5 * c - 7 *(d - e)*s;

where a, b, c, d, e are vectors, and s is a scalar, will (I think) leave
BLAS for dead. Think of it as an extended BLAS1 kernel.
However, it shouldn't be too difficult to generate calls to DAXPY and kin.
-------

/*
* Supports vector addition, subtraction, and multiplication by a real scalar.
* For functions of the form in BLAS1, the code is optimal for early Pentium
* CPUs, or very nearly so.
* For example, the code generated for the crucial DAXPY operation is almost
* identical to that described in Agner Fog's optimisation manual
* (www.agner.org), and is faster than the unrolled solution published by Intel.
*
* Future directions:
*   Actually generate the code (instead of just displaying it)! (tricky).
*   Support dot product (easy).
*   Support += and -= (with optimisation: the destination register can be
*     reused).
*   Support floats as well as doubles (moderate).
*   Support imaginary and complex numbers (tricky).
*   Implement loop unrolling for small loops (tricky).
*   Make a version that spits out D code (like Blitz++) (moderate).
*   Make an SSE2 version (extremely tricky).
*/
import std.stdio;  // only for debugging
import std.string; // only for debugging

// This would be a vector. For now, just store a name.
/*
 * The arithmetic operators compute nothing: they build an Expr tree that
 * describes the operation. opAssign then prints the generated code at
 * compile time and calls the tree's evaluate() at run time.
 *
 * Operand order matters: Expr renders/evaluates `left` first and then
 * applies `right`, so for non-commutative operators `this` must be `left`.
 */
class DVec {
     char [] data;   // the vector's name (stand-in for real data)
public:
     this(char [] name) { data = name; }

     // Print the generated code for `expr` (compile time), then run its
     // evaluate() (run time).
     DVec opAssign(A)(A expr) {
         pragma(msg, code!(A));
         expr.evaluate(); // Do the actual work here.
         return this;
     }

     // this + w. Addition is commutative, so w is placed on the left.
     Expr!('+', A, DVec) opAdd(A)(A w) {
         Expr!('+', A, DVec) q;
         q.right = this;
         q.left = w;
         return q;
     }

     // this - w. Subtraction is NOT commutative: `this` must be the left
     // operand, otherwise the generated code would compute w - this.
     Expr!('-', DVec, A) opSub(A)(A w) {
         Expr!('-', DVec, A) q;
         q.left = this;
         q.right = w;
         return q;
     }

     // this * w, where w must be implicitly convertible to real.
     Expr!('*', real, DVec) opMul(A)(A w) {
         static assert(is(A : real), "Can only multiply by scalars");
         Expr!('*', real, DVec) q;
         q.right = this;
         q.left = w;
         return q;
     }
}

/// For compile-time debugging.
/// Renders operand type X as text: a scalar becomes "ScalarN", a vector
/// "VecN", and a sub-expression is rendered recursively in parentheses.
/// numvecs/numscalars count the vectors/scalars that precede this operand;
/// N is formed from a single character ('1' + count), so it is only a
/// proper digit for counts 0..8.
template ExprToText(X, int numvecs=0, int numscalars=0)
{
     static if (is (X: real))
        const char [] ExprToText = "Scalar" ~ cast(char)('1' + numscalars);
     else static if (is (X.isExpr))
         const char [] ExprToText = "(" ~ X.ExprString!(numvecs,
             numscalars) ~ ")";
     else static if (is (X: DVec))
        const char [] ExprToText = "Vec"  ~ cast(char)('1' + numvecs);
     else
        // Without this, an unsupported operand leaves ExprToText undefined
        // and the compiler reports a confusing error far from the cause.
        static assert(false, "ExprToText: operand must be a scalar, a DVec or an Expr");
}

/// Register name for source vector number `vecnum` (1-based):
/// 1 -> "EAX", 2 -> "EBX", 3 -> "ECX", 4 -> "EDX".
/// No range check is made; values above 4 yield names that are not registers.
template VecName(int vecnum) {
     const char [] VecName = "E" ~ cast(char)(vecnum + ('A' - 1)) ~ "X";
}

/// Symbolic memory operand for scalar number `num`: "Scalar" followed by the
/// single character '0' + num (a proper digit only for 0..9).
template ScalarName(int num) {
     const char [] ScalarName = "Scalar" ~ cast(char)(num + '0');
}

   // Vectors 1..4 are stored in EAX, EBX, ECX, EDX.
   // ESI is the counter variable.
   // EDI is the destination.
   // EBP is a scratch register.
   // If there are more than 4 source vectors, we'll have to use EBP for the
   // spill (not currently implemented).

// Builds, at compile time, a text listing of the x87 loop for the
// expression type X (DVec.opAssign prints it with pragma(msg)).
//
// Loop shape, as emitted below: the source/destination pointers are advanced
// to the end of the vectors and ESI counts up from -length to 0. The first
// element is started before the loop. In each iteration the next element's
// instructions are issued before the previous result is stored
// (fxch + fstp), and a final fstp after the loop stores the last element.
template code(X) {
     const char [] code = "-------- GENERATE CODE --------------"\n
     ~ "; " ~ X.ExprString!(0,0) ~ \n
     ~ "; push used registers "\n
     ~ "  mov EBP, vector1.length; "\n
     ~ ";load vectors into EAX, EBX,..."\n
     ~ "  mov EAX, vector1.ptr;"\n
     ~ "  ... mov EBX, vector2.ptr; ...etc"\n
     ~ "  xor ESI, ESI;"\n
     ~ "  lea EAX, [EAX + 8*EBP];"\n
     ~ "  ... lea EBX, [EBX + 8*EBP]; ... etc"\n
     ~ "  sub ESI, EBP; // counter = -length" \n
     ~ "  mov EDI, destination;"\n
     ~ "  lea EDI, [EDI + 8*EBP];"\n
     // mov/lea leave the flags alone, so jz still tests the result of sub.
     ~ "  jz short L3; // test for length==0"\n
     ~ X.instr!(0,0).s
     ~ "  jmp short L2;\n"
     ~ "L1:\n"
     ~ X.instr!(0,0).s
     ~ "  fxch ST(1), ST; \t// previous result\n"
     ~ "  fstp double ptr [EDI + 8*ESI-8];\n"
     ~ "L2:\n"
     ~ X.finalinstr!(0,0).s
     ~ "  inc ESI;\n  jnz L1;\n"
     ~ "  fstp double ptr [EDI + 8*ESI-8];\n"
     ~ "L3:" \n
     ~ "; pop used registers";
}

// Stores all the data from the expression.
// 'dummy' is a workaround to avoid a 'recursive template' error message.
//
// A node `left operation right` with operation '+', '-' or '*'. Each operand
// is a DVec, a scalar (real), or another Expr. Code is generated by loading or
// evaluating `left` first and then applying `right`, so for '-' the minuend
// must be `left`.
struct Expr(char operation, LeftExpr, RightExpr, int dummy=0) {
     typedef int isExpr; // workaround because is() doesn't work properly
     const char opType = operation;
     LeftExpr left;
     RightExpr right;

  static if (operation=='*') {
     // On x87, the FMUL instruction has a long latency.
     // We arrange the order to give it a chance to finish before we use it
     // again: the product is the left operand, so its instructions come first.
    Expr!('+', Expr!(operation, LeftExpr, RightExpr, 1), A) opAdd(A)(A w) {
        Expr!('+', Expr!(operation, LeftExpr, RightExpr, 1), A) q;
        q.left.left = this.left;
        q.left.right = this.right;
        q.right = w;
        return q;
    }
  } else {
    // Addition is commutative, so the other operand may go first.
    Expr!('+', A, Expr!(operation, LeftExpr, RightExpr, 1)) opAdd(A)(A w) {
        Expr!('+', A, Expr!(operation, LeftExpr, RightExpr, 1)) q;
        q.right.left = this.left;
        q.right.right = this.right;
        q.left = w;
        return q;
    }
  }

    // this - w. Subtraction is not commutative, so this expression must be
    // the left operand in every case (it previously produced w - this when
    // operation was not '*').
    Expr!('-', Expr!(operation, LeftExpr, RightExpr, 1), A) opSub(A)(A w) {
        Expr!('-', Expr!(operation, LeftExpr, RightExpr, 1), A) q;
        q.left.left = this.left;
        q.left.right = this.right;
        q.right = w;
        return q;
    }

  static if ( operation=='*' && is(RightExpr: real)) {
        // Constant fold scalars
        Expr!('*', LeftExpr, real) opMul(real w) {
            Expr!('*', LeftExpr, real) q;
            q.left = this.left;
            q.right = this.right * w;
            return q;
        }
  } else static if ( operation=='*' && is(LeftExpr: real)) {
        // Constant fold scalars
        Expr!('*', RightExpr, real) opMul(real w) {
            Expr!('*', RightExpr, real) q;
            q.left = this.right;
            q.right = this.left * w;
            return q;
        }
  } else {
    Expr!('*', Expr!(operation, LeftExpr, RightExpr, 1), real)
    opMul(A)(A w) {
        static assert(is(A : real), "Can only multiply by scalars");
        Expr!('*', Expr!(operation, LeftExpr, RightExpr, 1), real) q;
        q.left.left = this.left;
        q.left.right = this.right;
        q.right = w;
        return q;
    }
  }

  // Number of vectors / scalars in the left operand.
  static if (is (LeftExpr : DVec)) {
     const int getNumLeftVectors = 1;
     const int getNumLeftScalars = 0;
  } else static if (is(LeftExpr.isExpr)) {
     const int getNumLeftVectors = LeftExpr.totalVectors;
     const int getNumLeftScalars = LeftExpr.totalScalars;
  } else {
     const int getNumLeftVectors = 0;
     const int getNumLeftScalars = 1;
  }

  // Number of vectors / scalars in the whole expression.
  static if (is (RightExpr : DVec)) {
     const int totalVectors = getNumLeftVectors + 1;
     const int totalScalars = getNumLeftScalars;
  } else static if (is(RightExpr.isExpr)) {
     const int totalVectors = getNumLeftVectors + RightExpr.totalVectors;
     const int totalScalars = getNumLeftScalars + RightExpr.totalScalars;
  } else {
     const int totalVectors = getNumLeftVectors;
     const int totalScalars = getNumLeftScalars + 1;
  }

   // Symbolic text of the expression, numbering operands from
   // numvecs/numscalars onwards.
   template ExprString(int numvecs, int numscalars) {
     const char [] ExprString = ExprToText!(LeftExpr, numvecs,
         numscalars) ~ " " ~ operation ~ " "
         ~ ExprToText!(RightExpr, numvecs + getNumLeftVectors,
             numscalars + getNumLeftScalars);
   }

   // The instruction that applies `right` to the value on top of the stack.
   template finalinstr(int vecbase, int scalarbase) {
         static if (operation == '*') {
             const char [] oper= "  fmul";
         } else static if (operation == '-') {
             const char [] oper= "  fsub";
         }  else const char [] oper= "  fadd";

         static if (is (right.isExpr)) {
             // Right operand was evaluated onto the stack: combine and pop.
             const char [] s = oper ~ "p ST(1), ST;\t\t\n";
         } else static if (is (RightExpr: DVec)) {
             const char [] s = oper ~ " double ptr [" ~
                 VecName!(vecbase+1+getNumLeftVectors) ~ " + 8*ESI];\t\n";
         } else const char [] s = oper ~ " real ptr [" ~
                 ScalarName!(scalarbase+1+getNumLeftScalars) ~ "];\n";
   }

   /// Print the instructions required to evaluate it.
   /// (Everything except this node's own finalinstr.)
   template instr(int vecbase, int scalarbase) {
         static if (is (left.isExpr)) {
             const char [] s1 = left.instr!(vecbase, scalarbase).s ~
                 left.finalinstr!(vecbase, scalarbase).s;
         } else static if (is (LeftExpr : DVec)) {
             const char [] s1 =  "  fld double ptr [" ~
                 VecName!(vecbase+1) ~ " + 8*ESI];\n";
         } else const char [] s1 = "  fld real ptr [" ~
                 ScalarName!(scalarbase+1) ~ "];\n";

         static if (is (right.isExpr)) {
             const char [] s = s1 ~
                 right.instr!(vecbase+getNumLeftVectors,
                     scalarbase+getNumLeftScalars).s
                 ~ right.finalinstr!(vecbase+getNumLeftVectors,
                     scalarbase+getNumLeftScalars).s;
         } else const char [] s = s1;
     }

   // Debugging stand-in for the real evaluation: prints the symbolic form
   // and the form with operand names/values.
   void evaluate() {
       printf("----------\nEvaluate:    ");
       // Print through "%.*s": getStr() is built at run time and is not
       // guaranteed to be NUL-terminated, and neither string may be used as
       // a printf format (it could contain '%').
       printf("%.*s", cast(int) ExprString!(0,0).length, ExprString!(0,0).ptr);
       printf(\n"With values: ");
       char [] values = getStr();
       printf("%.*s", cast(int) values.length, values.ptr);
       printf(\n);
   }

     // Just for debugging: the expression with vector names and scalar
     // values substituted.
     char [] getStr() {
         char [] s;
         static if (is (left.isExpr)) {
             s = "(" ~ left.getStr() ~ ")";
         } else static if (is (LeftExpr : DVec)) {
             s =  left.data;
         } else s ~= format("%g", left);
         s ~= " " ~ operation ~ " ";
         static if (is (right.isExpr)) {
             s ~= "(" ~ right.getStr() ~ ")";
         } else static if (is (RightExpr: DVec)) {
             s ~= right.data;
         } else s ~= format("%g", right);
         return s;
     }
}

// Demonstration driver. Each assignment below goes through DVec.opAssign,
// which prints the generated listing at compile time and calls the
// expression's evaluate() at run time.
void main()
{
     DVec a = new DVec("a"),
          b = new DVec("b"),
          c = new DVec("c");

     // Plain vector sum, then a vector scaled by a constant.
     b = a + c;
     b = c*2;

     // A nested expression exercising constant folding of scalars.
     a = 27.4*a + 2*((b-c) + 3.2*c)*1.234*2;

     // DAXPY-shaped: scalar * vector + vector.
     a = 7*b + a;
}



More information about the Digitalmars-d mailing list