Question
I'm trying to optimize this function using SIMD but I don't know where to start.
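long cubeSum(int x, int y)
{
    /* cast to long so the cubes are computed in 64 bits instead of overflowing int */
    return (long)x * x * x + (long)y * y * y;
}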
The disassembled function looks like this:
4007a0: 48 89 f2 mov %rsi,%rdx
4007a3: 48 89 f8 mov %rdi,%rax
4007a6: 48 0f af d6 imul %rsi,%rdx
4007aa: 48 0f af c7 imul %rdi,%rax
4007ae: 48 0f af d6 imul %rsi,%rdx
4007b2: 48 0f af c7 imul %rdi,%rax
4007b6: 48 8d 04 02 lea (%rdx,%rax,1),%rax
4007ba: c3 retq
4007bb: 0f 1f 44 00 00 nopl 0x0(%rax,%rax,1)
The calling code looks like this:
do {
    for (i = 0; i < maxi; i++) {
        j = nextj[i];
        long sum = cubeSum(i, j);
        while (sum <= p) {
            long x = sum & (psize - 1);
            int flag = table[x];
            if (flag <= guard) {
                table[x] = guard + 1;
            } else if (flag == guard + 1) {
                table[x] = guard + 2;
                count++;
            }
            j++;
            sum = cubeSum(i, j);
        }
        nextj[i] = j;
    }
    p += psize;
    guard += 3;
} while (p <= n);
Answer 1:
- Fill one SSE register with (x|y|0|0), since each SSE register holds four 32-bit elements. Let's call it r1.
- Then copy that register to another register, r2.
- Compute r2 * r1 as a packed 32-bit multiply (the instruction for this, pmulld, requires SSE4.1), storing the result in r2.
- Compute r2 * r1 again, storing the result in r2.
- Now r2 holds (x*x*x|y*y*y|0|0).
- Extract the lower two elements of r2 into separate registers and add them. (SSE3's horizontal add instructions, haddps and haddpd, work only on floats and doubles; SSSE3 later added integer versions such as phaddd.)
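The steps above map fairly directly onto SSE intrinsics. The following is a minimal sketch, not the answerer's code. It assumes GCC or Clang on x86, because it relies on `__attribute__((target("sse4.1")))` and `__builtin_cpu_supports`. It uses the SSE4.1 intrinsic `_mm_mullo_epi32` for the packed multiply and falls back to plain scalar code on other targets or on CPUs without SSE4.1. The names `cube_sum_sse` and `cube_sum` are illustrative. Each lane is only 32 bits, so the sketch also assumes |x| and |y| are at most 1290, which keeps each cube within an `int`.

```c
#if defined(__x86_64__) || defined(__i386__)
#include <smmintrin.h>  /* SSE4.1 intrinsics (also pulls in SSE2) */

/* r1 = (x|y|0|0); r2 = r1*r1*r1; then add lanes 0 and 1.
   Assumes |x|, |y| <= 1290 so each 32-bit lane cannot overflow. */
__attribute__((target("sse4.1")))
static long cube_sum_sse(int x, int y)
{
    __m128i r1 = _mm_set_epi32(0, 0, y, x);  /* lanes low to high: x, y, 0, 0 */
    __m128i r2 = _mm_mullo_epi32(r1, r1);    /* x*x, y*y, 0, 0 */
    r2 = _mm_mullo_epi32(r2, r1);            /* x*x*x, y*y*y, 0, 0 */
    int x3 = _mm_cvtsi128_si32(r2);          /* read lane 0 */
    int y3 = _mm_extract_epi32(r2, 1);       /* read lane 1 (SSE4.1) */
    return (long)x3 + y3;                    /* scalar add of the two lanes */
}
#endif

long cube_sum(int x, int y)
{
#if defined(__x86_64__) || defined(__i386__)
    /* only take the SIMD path if this CPU actually supports SSE4.1 */
    if (__builtin_cpu_supports("sse4.1"))
        return cube_sum_sse(x, y);
#endif
    /* scalar fallback for other targets or older CPUs */
    return (long)x * x * x + (long)y * y * y;
}

/* Usage: cube_sum(3, 4) returns 91 (27 + 64). */
```

Only two of the four lanes do useful work, and the lanes still have to be moved back into general-purpose registers and added. That overhead is why the answer doubts this beats the compiler's scalar code.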
In the end, I'd be surprised if this turned out to be any faster than the simple code the compiler has already generated for you. SIMD is more useful when you have arrays of data to operate on.
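To illustrate the point about arrays: when there are many independent values to cube, every lane of an SSE register does useful work, so each pair of multiplies handles four elements. The following is a minimal sketch under stated assumptions. It assumes GCC or Clang on x86 for the `target("sse4.1")` attribute and `__builtin_cpu_supports`, and it falls back to a scalar loop elsewhere. It assumes every input has magnitude at most 1290 so the cubes fit in an `int`. The function names `cubes` and `cubes_sse` are hypothetical.

```c
#include <stddef.h>

#if defined(__x86_64__) || defined(__i386__)
#include <smmintrin.h>  /* SSE4.1 intrinsics (also pulls in SSE2) */

/* Cubes four ints per iteration; assumes |in[k]| <= 1290. */
__attribute__((target("sse4.1")))
static void cubes_sse(const int *in, int *out, size_t n)
{
    size_t k = 0;
    for (; k + 4 <= n; k += 4) {
        __m128i v = _mm_loadu_si128((const __m128i *)(in + k));  /* unaligned load of 4 ints */
        __m128i c = _mm_mullo_epi32(_mm_mullo_epi32(v, v), v);   /* v*v*v in each lane */
        _mm_storeu_si128((__m128i *)(out + k), c);               /* unaligned store */
    }
    for (; k < n; k++)  /* scalar tail for the last n % 4 elements */
        out[k] = in[k] * in[k] * in[k];
}
#endif

void cubes(const int *in, int *out, size_t n)
{
#if defined(__x86_64__) || defined(__i386__)
    if (__builtin_cpu_supports("sse4.1")) {
        cubes_sse(in, out, n);
        return;
    }
#endif
    for (size_t k = 0; k < n; k++)  /* portable scalar fallback */
        out[k] = in[k] * in[k] * in[k];
}
```

The question's loop doesn't have this shape. Each `sum` is tested against `p` before the next one is needed, and the `table` updates land at scattered indices, so there is no contiguous batch to feed the vector loop.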
Answer 2:
This particular case is not a good fit for SIMD (SSE or otherwise). SIMD only really works well when you have contiguous arrays that you can access sequentially and process homogeneously, applying the same operation to every element.
However, you can at least get rid of some of the redundant operations in the scalar code, e.g. repeatedly calculating i * i * i when i is invariant within the inner while loop:
do {
    for (i = 0; i < maxi; i++) {
        long i3 = (long)i * i * i;  /* i is fixed for the whole while loop, so cube it once (in 64 bits) */
        int j = nextj[i];
        long j3 = (long)j * j * j;
        long sum = i3 + j3;
        while (sum <= p) {
            long x = sum & (psize - 1);
            int flag = table[x];
            if (flag <= guard) {
                table[x] = guard + 1;
            } else if (flag == guard + 1) {
                table[x] = guard + 2;
                count++;
            }
            j++;
            j3 = (long)j * j * j;  /* only the j term changes per iteration */
            sum = i3 + j3;
        }
        nextj[i] = j;
    }
    p += psize;
    guard += 3;
} while (p <= n);
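Going one step further in the same direction (this is not part of the answer), the j * j * j that is still recomputed on every iteration can be replaced by additions using forward differences. The step between cubes is (j+1)^3 - j^3 = 3j^2 + 3j + 1, and that step itself grows by 6(j+1) from one j to the next. The hypothetical helper `cubes_by_differences` produces consecutive cubes this way. In the loop above, `c` would play the role of `j3`, and the three additions would replace the multiplication. Whether this beats two `imul` instructions on a given CPU would need to be measured.

```c
/* Hypothetical helper: writes (j0 + k)^3 to out[k] for k = 0..n-1,
   using only additions inside the loop (forward differences). */
void cubes_by_differences(long j0, long *out, int n)
{
    long c = j0 * j0 * j0;              /* current cube, j^3 */
    long d = 3 * j0 * j0 + 3 * j0 + 1;  /* first difference: (j+1)^3 - j^3 */
    long e = 6 * (j0 + 1);              /* second difference: d(j+1) - d(j) */
    for (int k = 0; k < n; k++) {
        out[k] = c;
        c += d;  /* advance to (j+1)^3 */
        d += e;  /* advance the first difference */
        e += 6;  /* second difference grows by a constant 6 */
    }
}
```

For example, `cubes_by_differences(0, out, 5)` fills `out` with 0, 1, 8, 27, 64.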
Source: https://stackoverflow.com/questions/8357182/multiplication-using-sse-xxxyyy