From 496683b892a65d19917d169e6a98b86b878ea726 Mon Sep 17 00:00:00 2001 From: owenlu552 Date: Wed, 14 Mar 2012 23:09:15 -0700 Subject: [PATCH 1/6] optimizations: reduced number of "hadd" calls, incorporated simd into transpose --- sgemm-small.c | 263 ++++++++------------------------------------------ 1 file changed, 42 insertions(+), 221 deletions(-) diff --git a/sgemm-small.c b/sgemm-small.c index a7eda31..bf88217 100644 --- a/sgemm-small.c +++ b/sgemm-small.c @@ -1,6 +1,5 @@ #include #include -#include #include #include @@ -24,19 +23,33 @@ void square_sgemm( int n, float *A, float *B, float *C ) { __m128 partialSum5; __m128 partialSum6; __m128 partialSum7; + /* float pSum[4]; float pSum1[4]; float pSum2[4]; float pSum3[4]; + */ float cij=0.0, cij1=0.0, cij2=0.0, cij3=0.0, cij4=0.0, cij5=0.0, cij6=0.0, cij7=0.0; //transpose A - //I was unable to simd this without doing extra store/loads for (i = 0; i < n; i ++) { for (j = 0; j < n/4*4; j += 4) { - At[i+j*n] = A[j+i*n]; - At[i+(j+1)*n] = A[j+i*n + 1]; - At[i+(j+2)*n] = A[j+i*n + 2]; - At[i+(j+3)*n] = A[j+i*n + 3]; + + x = _mm_loadu_ps(A + j + i*n); + _MM_EXTRACT_FLOAT(temp, x, 0); + At[i+j*n] = temp; + _MM_EXTRACT_FLOAT(temp, x, 1); + At[i+(j+1)*n] = temp; + _MM_EXTRACT_FLOAT(temp, x, 2); + At[i+(j+2)*n] = temp; + _MM_EXTRACT_FLOAT(temp, x, 3); + At[i+(j+3)*n] = temp; + + /* + At[i+j*n] = A[j + i*n]; + At[i+(j+1)*n] = A[j + 1 + i*n]; + At[i+(j+2)*n] = A[j + 2 + i*n]; + At[i+(j+3)*n] = A[j + 3 + i*n]; + */ } for (; j Date: Wed, 14 Mar 2012 23:10:16 -0700 Subject: [PATCH 2/6] Update sgemm-small.c --- sgemm-small.c | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/sgemm-small.c b/sgemm-small.c index bf88217..4913746 100644 --- a/sgemm-small.c +++ b/sgemm-small.c @@ -23,12 +23,7 @@ void square_sgemm( int n, float *A, float *B, float *C ) { __m128 partialSum5; __m128 partialSum6; __m128 partialSum7; - /* - float pSum[4]; - float pSum1[4]; - float pSum2[4]; - float pSum3[4]; - */ + float 
cij=0.0, cij1=0.0, cij2=0.0, cij3=0.0, cij4=0.0, cij5=0.0, cij6=0.0, cij7=0.0; //transpose A for (i = 0; i < n; i ++) { @@ -44,12 +39,6 @@ void square_sgemm( int n, float *A, float *B, float *C ) { _MM_EXTRACT_FLOAT(temp, x, 3); At[i+(j+3)*n] = temp; - /* - At[i+j*n] = A[j + i*n]; - At[i+(j+1)*n] = A[j + 1 + i*n]; - At[i+(j+2)*n] = A[j + 2 + i*n]; - At[i+(j+3)*n] = A[j + 3 + i*n]; - */ } for (; j Date: Wed, 14 Mar 2012 23:11:14 -0700 Subject: [PATCH 3/6] removed unused code --- sgemm-small.c | 2 -- 1 file changed, 2 deletions(-) diff --git a/sgemm-small.c b/sgemm-small.c index 4913746..2ce41b3 100644 --- a/sgemm-small.c +++ b/sgemm-small.c @@ -3,8 +3,6 @@ #include #include -#define NUM_REGISTERS 4; - void square_sgemm( int n, float *A, float *B, float *C ) { int i, j , k, l; //int count = 0; //for debug From 9c4278d75e92495908f5f2dd685905e019e0d0fd Mon Sep 17 00:00:00 2001 From: owenlu552 Date: Fri, 16 Mar 2012 16:20:38 -0700 Subject: [PATCH 4/6] 10.6 gflops on 64x64, edge cleanup not finished. 
blocking by 4 on i, 2 on j, simd reads and stores into C --- sgemm-small.c | 208 +++++++++++++++++++++----------------------------- 1 file changed, 87 insertions(+), 121 deletions(-) diff --git a/sgemm-small.c b/sgemm-small.c index 2ce41b3..be54478 100644 --- a/sgemm-small.c +++ b/sgemm-small.c @@ -1,8 +1,11 @@ #include #include +#include #include #include +#define NUM_REGISTERS 4; + void square_sgemm( int n, float *A, float *B, float *C ) { int i, j , k, l; //int count = 0; //for debug @@ -10,9 +13,10 @@ void square_sgemm( int n, float *A, float *B, float *C ) { float temp; __m128 x; __m128 y; - __m128 z; __m128 a; - __m128 zero = _mm_setzero_ps(); + __m128 b; + __m128 c; + __m128 d; __m128 partialSum; __m128 partialSum1; __m128 partialSum2; @@ -21,8 +25,9 @@ void square_sgemm( int n, float *A, float *B, float *C ) { __m128 partialSum5; __m128 partialSum6; __m128 partialSum7; - float cij=0.0, cij1=0.0, cij2=0.0, cij3=0.0, cij4=0.0, cij5=0.0, cij6=0.0, cij7=0.0; + __m128 c1; + __m128 c2; //transpose A for (i = 0; i < n; i ++) { for (j = 0; j < n/4*4; j += 4) { @@ -44,21 +49,13 @@ void square_sgemm( int n, float *A, float *B, float *C ) { } // For each row i of A - for (i = 0; i < n; ++i) { + for (i = 0; i < n/4*4; i+=4) { // For each column j of B - for (j = 0; j < n/8*8; j+=8) + for (j = 0; j < n/2*2; j+=2) { - // Compute C(i,j) - - cij = C[i+j*n]; - cij1 = C[i+(j+1)*n]; - cij2 = C[i+(j+2)*n]; - cij3 = C[i+(j+3)*n]; - cij4 = C[i+(j+4)*n]; - cij5 = C[i+(j+5)*n]; - cij6 = C[i+(j+6)*n]; - cij7 = C[i+(j+7)*n]; - + //load C inital values + c1 = _mm_loadu_ps(C + i + j*n); + c2 = _mm_loadu_ps(C + i + (j+1)*n); //this will hold 4 floats which sum to the dot product partialSum = _mm_setzero_ps(); @@ -71,116 +68,85 @@ void square_sgemm( int n, float *A, float *B, float *C ) { partialSum7 = _mm_setzero_ps(); for(k = 0; k < n/4*4; k += 4) { - x = _mm_loadu_ps(At + k + i*n); - y = _mm_loadu_ps(B + k + j*n); - z = _mm_mul_ps(x, y); - //accumulate dot prduct - y = _mm_loadu_ps(B 
+ k + (j+1) *n); - a = _mm_mul_ps(x, y); - partialSum = _mm_add_ps(partialSum, z); - - //accumulate dot prduct - y = _mm_loadu_ps(B + k + (j+2)*n); - z = _mm_mul_ps(x, y); - partialSum1 = _mm_add_ps(partialSum1, a); - - //accumulate dot prduct - y = _mm_loadu_ps(B + k + (j+3)*n); - a = _mm_mul_ps(x, y); - partialSum2 = _mm_add_ps(partialSum2, z); - - //accumulate dot prduct - y = _mm_loadu_ps(B + k + (j+4)*n); - z = _mm_mul_ps(x, y); - partialSum3 = _mm_add_ps(partialSum3, a); - - //accumulate dot prduct - y = _mm_loadu_ps(B + k + (j+5) *n); - a = _mm_mul_ps(x, y); - partialSum4 = _mm_add_ps(partialSum4, z); - - //accumulate dot prduct - y = _mm_loadu_ps(B + k + (j+6)*n); - z = _mm_mul_ps(x, y); - partialSum5 = _mm_add_ps(partialSum5, a); + a = _mm_loadu_ps(At + k + i*n); + x = _mm_loadu_ps(B + k + j*n); + y = _mm_mul_ps(a,x); + partialSum = _mm_add_ps(partialSum, y); + b = _mm_loadu_ps(At + k + (i+1)*n); + y = _mm_mul_ps(b,x); + partialSum1 = _mm_add_ps(partialSum1, y); + c = _mm_loadu_ps(At + k + (i+2)*n); + y = _mm_mul_ps(c,x); + partialSum2 = _mm_add_ps(partialSum2,y); + d = _mm_loadu_ps(At + k + (i+3)*n); + y = _mm_mul_ps(d,x); + partialSum3 = _mm_add_ps(partialSum3,y); + x = _mm_loadu_ps(B + k + (j+1)*n); + y = _mm_mul_ps(a, x); + partialSum4 = _mm_add_ps(partialSum4,y); + y = _mm_mul_ps(b, x); + partialSum5 = _mm_add_ps(partialSum5, y); + y = _mm_mul_ps(c, x); + partialSum6 = _mm_add_ps(partialSum6, y); + y = _mm_mul_ps(d, x); + partialSum7 = _mm_add_ps(partialSum7, y); - //accumulate dot prduct - y = _mm_loadu_ps(B + k + (j+7)*n); - a = _mm_mul_ps(x, y); - partialSum6 = _mm_add_ps(partialSum6, z); - - //accumulate dot prduct - partialSum7 = _mm_add_ps(partialSum7, a); - } - partialSum = _mm_hadd_ps(partialSum, partialSum1); - partialSum1 = _mm_hadd_ps(partialSum2, partialSum3); - partialSum2 = _mm_hadd_ps(partialSum, partialSum1); // [p0,p1,p2,p3] where p1 = accumulation of partialSum1 - - partialSum4 = _mm_hadd_ps(partialSum4, partialSum5); - partialSum5 = 
_mm_hadd_ps(partialSum6, partialSum7); - partialSum6 = _mm_hadd_ps(partialSum4, partialSum5); // [p4,p5,p6,p7] + partialSum = _mm_hadd_ps(partialSum, partialSum1); + partialSum1 = _mm_hadd_ps(partialSum2, partialSum3); + partialSum2 = _mm_hadd_ps(partialSum, partialSum1); // [p0,p1,p2,p3] where p0 = sum of partialSum0 + + partialSum4 = _mm_hadd_ps(partialSum4, partialSum5); + partialSum5 = _mm_hadd_ps(partialSum6, partialSum7); + partialSum6 = _mm_hadd_ps(partialSum4, partialSum5); // [p4,p5,p6,p7] - _MM_EXTRACT_FLOAT(temp, partialSum2, 0); - cij += temp; - _MM_EXTRACT_FLOAT(temp, partialSum2, 1); - cij1 += temp; - _MM_EXTRACT_FLOAT(temp, partialSum2, 2); - cij2 += temp; - _MM_EXTRACT_FLOAT(temp, partialSum2, 3); - cij3 += temp; - _MM_EXTRACT_FLOAT(temp, partialSum6, 0); - cij4 += temp; - _MM_EXTRACT_FLOAT(temp, partialSum6, 1); - cij5 += temp; - _MM_EXTRACT_FLOAT(temp, partialSum6, 2); - cij6 += temp; - _MM_EXTRACT_FLOAT(temp, partialSum6, 3); - cij7 += temp; - - //cleanup k - for (; k < n; k ++) { - cij += At[k+i*n] * B[k+j*n]; - cij1 += At[k+i*n] * B[k+(j+1)*n]; - cij2 += At[k+i*n] * B[k+(j+2)*n]; - cij3 += At[k+i*n] * B[k+(j+3)*n]; - cij4 += At[k+i*n] * B[k+(j+4)*n]; - cij5 += At[k+i*n] * B[k+(j+5)*n]; - cij6 += At[k+i*n] * B[k+(j+6)*n]; - cij7 += At[k+i*n] * B[k+(j+7)*n]; + c1 = _mm_add_ps(c1, partialSum2); + c2 = _mm_add_ps(c2, partialSum6); + //cleanup k + if (k != n) { + for (; k < n; k ++) { + cij += At[k+i*n] * B[k+j*n]; + cij1 += At[k+(i+1)*n] * B[k+j*n]; + cij2 += At[k+(i+2)*n] * B[k+j*n]; + cij3 += At[k+(i+3)*n] * B[k+j*n]; + cij4 += At[k+i*n] * B[k+(j+1)*n]; + cij5 += At[k+(i+1)*n] * B[k+(j+1)*n]; + cij6 += At[k+(i+2)*n] * B[k+(j+1)*n]; + cij7 += At[k+(i+3)*n] * B[k+(j+1)*n]; + } + float ca[] = {cij, cij1, cij2, cij3, cij4, cij5, cij6, cij7}; //LOOK FOR ALTERNATIVE + partialSum = _mm_loadu_ps(ca); + partialSum1 = _mm_loadu_ps(ca + 3); + c1 = _mm_add_ps(c1, partialSum); + c2 = _mm_add_ps(c2, partialSum1); } - C[i+j*n] = cij; - C[i+(j+1)*n] = cij1; - 
C[i+(j+2)*n] = cij2; - C[i+(j+3)*n] = cij3; - C[i+(j+4)*n] = cij4; - C[i+(j+5)*n] = cij5; - C[i+(j+6)*n] = cij6; - C[i+(j+7)*n] = cij7; - //count += 4; //for debug + _mm_storeu_ps(C + i + j*n, c1); + _mm_storeu_ps(C + i + (j+1)*n, c2); + //count += 4; //for debug } - //cleanup j + //cleanup j NOT FULLY IMPLEMENTED for (; j < n; j++) { - cij = C[i+j*n]; - partialSum = _mm_setzero_ps(); - for (k = 0; k < n/4*4; k+=4) { - x = _mm_loadu_ps(At + k + i*n); - y = _mm_loadu_ps(B + k + j*n); - z = _mm_mul_ps(x, y); - //accumulate dot prduct - partialSum = _mm_add_ps(partialSum, z); - } - partialSum = _mm_hadd_ps(partialSum,partialSum); - _MM_EXTRACT_FLOAT(temp, partialSum, 0); - cij += temp; - _MM_EXTRACT_FLOAT(temp, partialSum, 1); - cij += temp; - for (; k < n; k++) { - cij += At[k+i*n] * B[k+j*n]; - } - C[i+j*n] = cij; - //count += 1; //for debug + cij = C[i+j*n]; + partialSum = _mm_setzero_ps(); + for (k = 0; k < n/4*4; k+=4) { + x = _mm_loadu_ps(At + k + i*n); + y = _mm_loadu_ps(B + k + j*n); + a = _mm_mul_ps(x, y); + //accumulate dot prduct + partialSum = _mm_add_ps(partialSum, a); + } + partialSum = _mm_hadd_ps(partialSum,partialSum); + _MM_EXTRACT_FLOAT(temp, partialSum, 0); + cij += temp; + _MM_EXTRACT_FLOAT(temp, partialSum, 1); + cij += temp; + for (; k < n; k++) { + cij += At[k+i*n] * B[k+j*n]; + } + C[i+j*n] = cij; + //count += 1; //for debug } - } -} \ No newline at end of file + } // NEED TO CLEANUP i +} + From e38d10eb91e14dcf49fefc79c58635b1968affbe Mon Sep 17 00:00:00 2001 From: owenlu552 Date: Tue, 27 Mar 2012 15:37:35 -0300 Subject: [PATCH 5/6] not working --- sgemm-small.c | 450 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 317 insertions(+), 133 deletions(-) diff --git a/sgemm-small.c b/sgemm-small.c index be54478..f830bdd 100644 --- a/sgemm-small.c +++ b/sgemm-small.c @@ -4,149 +4,333 @@ #include #include -#define NUM_REGISTERS 4; - void square_sgemm( int n, float *A, float *B, float *C ) { - int i, j , k, l; - //int count = 0; 
//for debug - float At[n*n] __attribute__ ((aligned(16))); - float temp; - __m128 x; - __m128 y; - __m128 a; - __m128 b; - __m128 c; - __m128 d; - __m128 partialSum; - __m128 partialSum1; - __m128 partialSum2; - __m128 partialSum3; - __m128 partialSum4; - __m128 partialSum5; - __m128 partialSum6; - __m128 partialSum7; - float cij=0.0, cij1=0.0, cij2=0.0, cij3=0.0, cij4=0.0, cij5=0.0, cij6=0.0, cij7=0.0; - __m128 c1; - __m128 c2; - //transpose A + int f, g, h, i, j , k, l; + int blockI = 64, blockJ = 64, blockK = 64; +// if (n < 300) { +// blockI = 64; +// blockJ = 64; +// blockK = 256; +// } else { +// blockI = 16; +// blockJ = 16; +// blockK = 256; +// } + float temp, temp1, temp2, temp3, temp4; + __m128 x; + __m128 y; + __m128 a; + __m128 b; + __m128 c; + __m128 d; + __m128 partialSum; + __m128 partialSum1; + __m128 partialSum2; + __m128 partialSum3; + __m128 partialSum4; + __m128 partialSum5; + __m128 partialSum6; + __m128 partialSum7; + float cij=0.0, cij1=0.0, cij2=0.0, cij3=0.0, cij4=0.0, cij5=0.0, cij6=0.0, cij7=0.0; + __m128 c1; + __m128 c2; + float *At = malloc(n*n*sizeof(float)); for (i = 0; i < n; i ++) { - for (j = 0; j < n/4*4; j += 4) { - - x = _mm_loadu_ps(A + j + i*n); - _MM_EXTRACT_FLOAT(temp, x, 0); - At[i+j*n] = temp; - _MM_EXTRACT_FLOAT(temp, x, 1); - At[i+(j+1)*n] = temp; - _MM_EXTRACT_FLOAT(temp, x, 2); - At[i+(j+2)*n] = temp; - _MM_EXTRACT_FLOAT(temp, x, 3); - At[i+(j+3)*n] = temp; - - } - for (; j Date: Tue, 27 Mar 2012 16:33:13 -0300 Subject: [PATCH 6/6] Update sgemm-small.c --- sgemm-small.c | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sgemm-small.c b/sgemm-small.c index f830bdd..83fbb34 100644 --- a/sgemm-small.c +++ b/sgemm-small.c @@ -5,7 +5,7 @@ #include void square_sgemm( int n, float *A, float *B, float *C ) { - int f, g, h, i, j , k, l; + int f, g, h, i, j , k, l, alpha, beta, gamma; int blockI = 64, blockJ = 64, blockK = 64; // if (n < 300) { // blockI = 64; @@ -119,13 +119,13 @@ void 
square_sgemm( int n, float *A, float *B, float *C ) { } //k cleanup if (k < n) { - for(i = g; i < g + blockI; i ++) { - for (j = h; j < h + blockJ; j++) { - cij = C[i+j*n]; - for (k = f; k < n; k++) { - cij += At[k+i*n] * B[k+j*n]; + for(alpha = g; alpha < g + blockI; alpha ++) { + for (beta = h; beta < h + blockJ; beta++) { + cij = C[alpha+beta*n]; + for (gamma = f; gamma < n; gamma++) { + cij += At[gamma+alpha*n] * B[gamma+beta*n]; } - C[i+j*n] = cij; + C[alpha+beta*n] = cij; } } }