Expression templates
Don Clugston
dac at nospam.com.au
Wed Dec 13 21:35:37 PST 2006
Bill Baxter wrote:
> Don Clugston wrote:
>> Don Clugston wrote:
>>> The recent language improvements (especially, opAssign) have made
>>> expression templates eminently feasible. Here's a proof-of-concept which
>>> implements parsing of the basic vector operations of addition and
>>> scalar multiplication.
>>
>> I've made a slightly improved version...
>
> I'm making some headway on an nd-array class that uses BLAS/LAPACK under
> the hood. It would be nifty if I could somehow get this:
>
> R = A * x + y
>
> to be grokked and translated to a single call to one of BLAS's
> saxpy/daxpy/caxpy/zaxpy routines.
>
> With all the EAX talk, it looks like you're talking about a different
> level of optimization. The theory with BLAS anyway, is that you'll have
> access to a hand-tuned version for the platform you're on, so it'll be
> hard to beat with any kind of register-level compile-time
> micromanagement. For instance BLAS level 3 (matrix-matrix) routines do
> blocking to make things easier on the cache for big inputs. And of
> course there are also parallel BLAS implementations around, too.
Quite true. But the code below spits out very nearly optimal asm for x87,
for double[] vectors with +, -, and scalar multiplication.
For example, the code generated for something like
a = b + 3.5 * c - 7 *(d - e)*s;
where a, b, c, d, e are vectors, and s is a scalar, will (I think) leave
BLAS for dead. Think of it as an extended BLAS1 kernel.
However, it shouldn't be too difficult to generate calls to DAXPY and kin.
-------
/*
* Supports vector addition, subtraction, and multiplication by a real scalar.
* For functions of the form in BLAS1, the code is optimal for early Pentium
* CPUs, or very nearly so.
* For example, the code generated for the crucial DAXPY operation is almost
* identical to that described in Agner Fog's optimisation manual
* (www.agner.org), and is faster than the unrolled solution published by Intel.
*
* Future directions:
* Actually generate the code (instead of just displaying it)! (tricky).
* Support dot product (easy).
* Support += and -= (with optimisation: destination register can be reused).
* Support floats as well as doubles (moderate).
* Support imaginary and complex numbers (tricky).
* Implement loop unrolling for small loops (tricky).
* Make a version that spits out D code (like Blitz++) (moderate).
* Make an SSE2 version (extremely tricky).
*/
import std.stdio; // only for debugging
import std.string; // only for debugging
// This would be a vector. For now, just store a name.
//
// The arithmetic operators do not compute anything; each one returns an
// Expr!(...) value that records the operation and its operands. The work is
// triggered by opAssign, which prints the generated code at compile time and
// calls evaluate() on the expression at run time.
class DVec {
    char [] data; // the vector's display name (read by Expr.getStr)
public:
    this(char [] name) { data = name; }

    // vec = expr: show the generated code for the expression type at
    // compile time, then evaluate the expression.
    DVec opAssign(A)(A expr) {
        pragma(msg, code!(A));
        expr.evaluate(); // Do the actual work here.
        return this;
    }

    // this + w. The operands are stored as (w, this); addition is
    // commutative, so the swapped order does not change the value.
    Expr!('+', A, DVec) opAdd(A)(A w) {
        Expr!('+', A, DVec) q;
        q.right = this;
        q.left = w;
        return q;
    }

    // this - w. Subtraction is NOT commutative, so the operands must be
    // stored in source order: 'this' on the left and 'w' on the right.
    // (Storing them swapped, as opAdd does, would turn b - c into c - b.)
    Expr!('-', DVec, A) opSub(A)(A w) {
        Expr!('-', DVec, A) q;
        q.left = this;
        q.right = w;
        return q;
    }

    // this * w, where w must be implicitly convertible to real.
    // The scalar is stored on the left, the vector on the right.
    Expr!('*', real, DVec) opMul(A)(A w) {
        static assert(is(A : real), "Can only multiply by scalars");
        Expr!('*', real, DVec) q;
        q.right = this;
        q.left = w;
        return q;
    }
}
/// For compile-time debugging.
/// Yields a textual placeholder for operand type X:
///   - a type implicitly convertible to real -> "Scalar" followed by the
///     character '1' + numscalars
///   - a type exposing isExpr -> its ExprString!(numvecs, numscalars),
///     wrapped in parentheses
///   - DVec -> "Vec" followed by the character '1' + numvecs
template ExprToText(X, int numvecs=0, int numscalars=0)
{
    static if (is (X: real)) {
        const char [] ExprToText = "Scalar" ~ cast(char)(numscalars + '1');
    } else static if (is (X.isExpr)) {
        const char [] ExprToText =
            "(" ~ X.ExprString!(numvecs, numscalars) ~ ")";
    } else static if (is (X: DVec)) {
        const char [] ExprToText = "Vec" ~ cast(char)(numvecs + '1');
    }
}
// Register name used for the vecnum-th source vector (1-based):
// 1 -> "EAX", 2 -> "EBX", 3 -> "ECX", 4 -> "EDX".
template VecName(int vecnum) {
    const char [] VecName = "E" ~ cast(char)(vecnum - 1 + 'A') ~ "X";
}
// Symbolic name used for the num-th scalar operand: "Scalar" followed by
// the character '0' + num (so 1 -> "Scalar1", 2 -> "Scalar2", ...).
template ScalarName(int num) {
    const char [] ScalarName = "Scalar" ~ cast(char)(num + '0');
}
// Vectors 1..4 are stored in EAX, EBX, ECX, EDX.
// ESI is the counter variable.
// EDI is the destination.
// EBP is a scratch register.
// If there are more than 4 source vectors, we'll have to use EBP for the spill
// (not currently implemented).
// Builds, at compile time, the text of an x87 assembly listing for the
// expression type X. The result is a string only: DVec.opAssign passes it to
// pragma(msg) so it is displayed during compilation; nothing is assembled.
//
// X must provide ExprString!(v,s), instr!(v,s).s and finalinstr!(v,s).s.
//
// Shape of the emitted listing:
//   - a header line and the expression text as an asm comment;
//   - a prologue that puts the length in EBP, advances each vector pointer
//     (and EDI, the destination) by 8*EBP, and sets ESI to -length, so that
//     [reg + 8*ESI] addresses successive elements as ESI counts up to zero;
//   - X.instr for the first element, then a jump into the loop at L2;
//   - L1: X.instr for the current element, then fxch/fstp to store the
//     previous element's result at [EDI + 8*ESI-8];
//   - L2: X.finalinstr, increment ESI, loop to L1 while ESI is non-zero;
//   - after the loop, one last fstp for the final element's result.
// NOTE(review): "fxch ST(1), ST" assumes X.instr leaves exactly one new value
// on the FPU stack above the previous result -- TODO confirm for expressions
// whose right operand is itself an Expr.
template code(X) {
const char [] code = "-------- GENERATE CODE --------------"\n
~ "; " ~ X.ExprString!(0,0) ~ \n
// Prologue: length, vector pointers, counter and destination.
~ "; push used registers "\n
~ " mov EBP, vector1.length; "\n
~ ";load vectors into EAX, EBX,..."\n
~ " mov EAX, vector1.ptr;"\n
~ " ... mov EBX, vector2.ptr; ...etc"\n
~ " xor ESI, ESI;"\n
~ " lea EAX, [EAX + 8*EBP];"\n
~ " ... lea EBX, [EBX + 8*EBP]; ... etc"\n
~ " sub ESI, EBP; // counter = -length" \n
~ " mov EDI, destination;"\n
~ " lea EDI, [EDI + 8*EBP];"\n
~ " jz short L3; // test for length==0"\n
// First element: emit its instructions, then enter the loop at L2.
~ X.instr!(0,0).s
~ " jmp short L2;\n"
// Loop body: start the current element, then store the previous result.
~ "L1:\n"
~ X.instr!(0,0).s
~ " fxch ST(1), ST; \t// previous result\n"
~ " fstp double ptr [EDI + 8*ESI-8];\n"
~ "L2:\n"
~ X.finalinstr!(0,0).s
~ " inc ESI;\n jnz L1;\n"
// Epilogue: store the last result.
~ " fstp double ptr [EDI + 8*ESI-8];\n"
~ "L3:" \n
~ "; pop used registers";
}
// Stores all the data from the expression.
// 'dummy' is a workaround to avoid a 'recursive template' error message.
//
// operation  - the operator character: '+', '-' or '*'.
// LeftExpr   - type of the left operand (DVec, another Expr, or a scalar).
// RightExpr  - type of the right operand (same choices).
//
// The operator overloads build larger Expr types; the nested templates
// (ExprString, instr, finalinstr) turn the type into text at compile time;
// evaluate()/getStr() print the expression with its run-time values.
struct Expr(char operation, LeftExpr, RightExpr, int dummy=0) {
    typedef int isExpr; // workaround because is() doesn't work properly
    const char opType = operation;
    LeftExpr left;
    RightExpr right;

    static if (operation=='*') {
        // On x87, the FMUL instruction has a long latency.
        // We arrange the order to give it a chance to finish before we use
        // it again: the multiplication is kept as the left operand.
        Expr!('+', Expr!(operation, LeftExpr, RightExpr, 1), A) opAdd(A)(A w) {
            Expr!('+', Expr!(operation, LeftExpr, RightExpr, 1), A) q;
            q.left.left = this.left;
            q.left.right = this.right;
            q.right = w;
            return q;
        }
        Expr!('-', Expr!(operation, LeftExpr, RightExpr, 1), A) opSub(A)(A w) {
            Expr!('-', Expr!(operation, LeftExpr, RightExpr, 1), A) q;
            q.left.left = this.left;
            q.left.right = this.right;
            q.right = w;
            return q;
        }
    } else {
        // Addition is commutative, so the new operand may go on the left.
        Expr!('+', A, Expr!(operation, LeftExpr, RightExpr, 1)) opAdd(A)(A w) {
            Expr!('+', A, Expr!(operation, LeftExpr, RightExpr, 1)) q;
            q.right.left = this.left;
            q.right.right = this.right;
            q.left = w;
            return q;
        }
        // Subtraction is NOT commutative: (this) - w must keep 'this' on the
        // left. Storing the operands swapped, as opAdd does, would evaluate
        // w - (this) instead.
        Expr!('-', Expr!(operation, LeftExpr, RightExpr, 1), A) opSub(A)(A w) {
            Expr!('-', Expr!(operation, LeftExpr, RightExpr, 1), A) q;
            q.left.left = this.left;
            q.left.right = this.right;
            q.right = w;
            return q;
        }
    }

    static if ( operation=='*' && is(RightExpr: real)) {
        // Constant fold scalars
        Expr!('*', LeftExpr, real) opMul(real w) {
            Expr!('*', LeftExpr, real) q;
            q.left = this.left;
            q.right = this.right * w;
            return q;
        }
    } else static if ( operation=='*' && is(LeftExpr: real)) {
        // Constant fold scalars
        Expr!('*', RightExpr, real) opMul(real w) {
            Expr!('*', RightExpr, real) q;
            q.left = this.right;
            q.right = this.left * w;
            return q;
        }
    } else {
        // No scalar to fold into: wrap this expression in a new product.
        Expr!('*', Expr!(operation, LeftExpr, RightExpr, 1), real)
        opMul(A)(A w) {
            static assert(is(A : real), "Can only multiply by scalars");
            Expr!('*', Expr!(operation, LeftExpr, RightExpr, 1), real) q;
            q.left.left = this.left;
            q.left.right = this.right;
            q.right = w;
            return q;
        }
    }

    // Number of vector and scalar leaves in the left operand.
    static if (is (LeftExpr : DVec)) {
        const int getNumLeftVectors = 1;
        const int getNumLeftScalars = 0;
    } else static if (is(LeftExpr.isExpr)) {
        const int getNumLeftVectors = LeftExpr.totalVectors;
        const int getNumLeftScalars = LeftExpr.totalScalars;
    } else {
        const int getNumLeftVectors = 0;
        const int getNumLeftScalars = 1;
    }

    // Number of vector and scalar leaves in the whole expression.
    static if (is (RightExpr : DVec)) {
        const int totalVectors = getNumLeftVectors + 1;
        const int totalScalars = getNumLeftScalars;
    } else static if (is(RightExpr.isExpr)) {
        const int totalVectors = getNumLeftVectors + RightExpr.totalVectors;
        const int totalScalars = getNumLeftScalars + RightExpr.totalScalars;
    } else {
        const int totalVectors = getNumLeftVectors;
        const int totalScalars = getNumLeftScalars + 1;
    }

    // Text form of the expression. numvecs/numscalars are the counts of
    // vectors/scalars already numbered to the left of this sub-expression.
    template ExprString(int numvecs, int numscalars) {
        const char [] ExprString = ExprToText!(LeftExpr, numvecs,
            numscalars) ~ " " ~ operation ~ " "
            ~ ExprToText!(RightExpr, numvecs + getNumLeftVectors,
            numscalars + getNumLeftScalars);
    }

    // The single instruction that combines the left value with the right
    // operand: a popping register form when the right operand is an Expr,
    // otherwise a memory operand naming the right vector or scalar.
    template finalinstr(int vecbase, int scalarbase) {
        static if (operation == '*') {
            const char [] oper= " fmul";
        } else static if (operation == '-') {
            const char [] oper= " fsub";
        } else const char [] oper= " fadd";
        static if (is (right.isExpr)) {
            const char [] s = oper ~ "p ST(1), ST;\t\t\n";
        } else static if (is (RightExpr: DVec)) {
            const char [] s = oper ~ " double ptr [" ~
                VecName!(vecbase+1+getNumLeftVectors) ~ " + 8*ESI];\t\n";
                // " ~ getStr ~ " is " ~ right.data ~ \n;
        } else const char [] s = oper ~ " real ptr [" ~
            ScalarName!(scalarbase+1+getNumLeftScalars)~ "];\n"; // " ~ getStr ~ \n;
    }

    /// Print the instructions required to evaluate it.
    /// Emits everything except this node's own finalinstr: the left operand
    /// (computed or loaded), then the right operand if it is an Expr.
    template instr(int vecbase, int scalarbase) {
        static if (is (left.isExpr)) {
            const char [] s1 = left.instr!(vecbase, scalarbase).s ~
                left.finalinstr!(vecbase, scalarbase).s;
        } else static if (is (LeftExpr : DVec)) {
            const char [] s1 = " fld double ptr [" ~
                VecName!(vecbase+1) ~ " + 8*ESI];\n"; // left.data
        } else const char [] s1 = " fld real ptr [" ~
            ScalarName!(scalarbase+1)~ "];\n";
            //= format(" fld real ptr [ %g ];\n", left);
        static if (is (right.isExpr)) {
            const char [] s = s1 ~
                right.instr!(vecbase+getNumLeftVectors, scalarbase+getNumLeftScalars).s
                ~ right.finalinstr!(vecbase+getNumLeftVectors,
                scalarbase+getNumLeftScalars).s;
        } else const char [] s = s1;
    }

    // Prints the expression's text form and its form with run-time values.
    // The D strings are printed through "%s" with toStringz: a char[] is not
    // guaranteed to be NUL-terminated, and must not be used as the printf
    // format string itself.
    void evaluate() {
        printf("----------\nEvaluate: ");
        printf("%s", toStringz(ExprString!(0,0)));
        printf("\nWith values: ");
        printf("%s", toStringz(getStr()));
        printf("\n");
    }

    // Just for debugging.
    // Renders the expression with vector names and scalar values;
    // sub-expressions are wrapped in parentheses.
    char [] getStr() {
        char [] s;
        static if (is (left.isExpr)) {
            s = "(" ~ left.getStr() ~ ")";
        } else static if (is (LeftExpr : DVec)) {
            s = left.data;
        } else s ~= format("%g", left);
        s ~= " " ~ operation ~ " ";
        static if (is (right.isExpr)) {
            s ~= "(" ~ right.getStr() ~ ")";
        } else static if (is (RightExpr: DVec)) {
            s ~= right.data;
        } else s ~= format("%g", right);
        return s;
    }
}
// Demo driver: each assignment below goes through DVec.opAssign, which
// displays the generated code at compile time and calls evaluate() on the
// expression at run time.
void main()
{
    // Three named placeholder vectors.
    auto a = new DVec("a");
    auto b = new DVec("b");
    auto c = new DVec("c");

    // Vector + vector.
    b = a + c;

    // Vector * scalar.
    b = c*2;

    // Nested sub-expressions with several scalar factors.
    a = 27.4*a + 2*((b-c) + 3.2*c)*1.234*2;

    // Scaled vector + vector.
    a = 7*b + a;
}
More information about the Digitalmars-d
mailing list