#include <emmintrin.h>
// Horizontal add of adjacent 32-bit lanes, i.e. the SSSE3 PHADDD instruction.
//
// Returns { a0+a1, a2+a3, b0+b1, b2+b3 } (lane 0 listed first; additions wrap).
//
// The compiler's own _mm_hadd_epi32 is declared in <tmmintrin.h> (SSSE3), not
// in <emmintrin.h> (SSE2), which is why it is not available here. This
// replacement emits PHADDD directly through inline asm, so the CPU running the
// code must support SSSE3. Do not also include <tmmintrin.h> in this
// translation unit: the two definitions of the same name would clash.
//
// Declared `static inline`: a plain C99/C11 `inline` definition provides no
// external definition, so any non-inlined call (e.g. at -O0) fails to link.
static inline __m128i _mm_hadd_epi32(__m128i a, __m128i b)
{
    // PHADDD src, dst (AT&T order): dst is both read and overwritten, hence "+x".
    __asm__("phaddd %[b], %[a]" : [a] "+x"(a) : [b] "x"(b));
    return a;
}
// Multiply row of N short values pointed by a
// by 4 rows of N short values pointed by b.
// Difference between the rows is specified by bsize (in bytes).
// So, the b rows' pointers are:
// b
// (const short*)((const char*)b+bsize)
// (const short*)((const char*)b+2*bsize)
// (const short*)((const char*)b+3*bsize)
//
// Products for all 4 rows are then accumulated as int values and stored
// into 4 sequential int values pointed by m.
// N>0, and is multiple of 8.
// The pointers m, a, b, and the offset bsize are 16-byte aligned.
inline void row4_mul(int N, int* m, const short* a, const short* b, int bsize)
{
// s0, s1, s2, s3 are accumulators for the products
__m128i s0 = _mm_setzero_si128(), s1 = _mm_setzero_si128();
__m128i s2 = _mm_setzero_si128(), s3 = _mm_setzero_si128();
// v0, v1, v2, v3 are temporary registers
__m128i v0, v1, v2, v3;
// prepare pointers and offsets to
const short* be = b + N; // end pointer
int b2a = (const char*)a - (const char*)b; // offset of a relative to b
int bsize3 = 3*bsize; // offset of 3rd row in b
do {
// multiple and add 8 input values from each of the 4 rows at a time
asm("movdqa (%[b],%[b2a],1),%[v3] \n\t"
"movdqa (%[b]),%[v0] \n\t"
"pmaddwd %[v3],%[v0] \n\t"
"movdqa (%[b],%[bsize],1),%[v1] \n\t"
"pmaddwd %[v3],%[v1] \n\t"
"movdqa (%[b],%[bsize],2),%[v2] \n\t"
"pmaddwd %[v3],%[v2] \n\t"
"pmaddwd (%[b],%[bsize3],1),%[v3] \n\t"
"add $0x10,%[b] \n\t"
"paddd %[v0],%[s0] \n\t"
"paddd %[v1],%[s1] \n\t"
"paddd %[v2],%[s2] \n\t"
"paddd %[v3],%[s3] \n\t"
: [s0]"=&x"(s0), [s1]"=&x"(s1), [s2]"=&x"(s2), [s3]"=&x"(s3),
[v0]"=&x"(v0), [v1]"=&x"(v1), [v2]"=&x"(v2), [v3]"=&x"(v3),
[b]"=&r"(b)
: "[s0]"(s0), "[s1]"(s1), "[s2]"(s2), "[s3]"(s3),
[b2a]"r"(b2a), [bsize]"r"(bsize), [bsize3]"r"(bsize3));
} while ( b < be );
// add accumulators into a single xmm register
s0 = _mm_hadd_epi32(s0, s1);
s2 = _mm_hadd_epi32(s2, s3);
s0 = _mm_hadd_epi32(s0, s2);
_mm_store_si128((__m128i*)m, s0);
}
/* Externally visible wrapper around row4_mul: every argument is passed
 * through unchanged, so the caller must meet row4_mul's preconditions
 * (N > 0 and a multiple of 8; m, a, b and bsize 16-byte aligned). */
void foo(int N, int* m, const short* a, const short* b, int bsize) {
    row4_mul(N, m, a, b, bsize);
}