From 7112b73d0d57f93563f37efd6f2a3ddf6d826dd7 Mon Sep 17 00:00:00 2001
From: Madan mohan Manokar
Date: Fri, 12 Mar 2021 15:15:24 +0530
Subject: [PATCH] Disabled zgemm induced method and gemm sqp temporarily.

1. mx1, mx4 kernel addition and framework modification.
2. 8mx6n kernel addition.
3. NULL check added in dgemm_sqp malloc.
4. Memory tracing added.
5. Restricted 3m_sqp to limited matrix sizes.
6. Induced methods disabled temporarily for debugging.

AMD-Internal: [CPUPL-1352]
Change-Id: I31671859b32bfbb359687fb7c9056f9eb904c8b2
---
 frame/compat/bla_gemm.c      |   25 +-
 kernels/zen/3/bli_gemm_sqp.c | 1479 ++++++++++++++++++++++++----------
 2 files changed, 1090 insertions(+), 414 deletions(-)

diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c
index 184361e69..e3041b1a8 100644
--- a/frame/compat/bla_gemm.c
+++ b/frame/compat/bla_gemm.c
@@ -38,7 +38,7 @@
 //
 // Define BLAS-to-BLIS interfaces.
 //
-
+#define ENABLE_INDUCED_METHOD 0
 #ifdef BLIS_BLAS3_CALLS_TAPI
 #undef GENTFUNC
@@ -639,10 +639,26 @@ void zgemm_
     }

// The code below will be called when number of threads = 1.
+#if ENABLE_INDUCED_METHOD
+    /* 3m_sqp is optimal for certain matrix shapes.
+       Initial studies show that it works well for square sizes and sizes close to square shape.
-
-    dim_t m8rem = m0&7;
-    if( ((blis_transa==BLIS_TRANSPOSE) || (blis_transa==BLIS_NO_TRANSPOSE)) && (blis_transb==BLIS_NO_TRANSPOSE) &&(m8rem==0)&&(n0>40))
+     * Usage of 3m_sqp is restricted to sizes where it is found to be more efficient than the native, sup and other induced methods.
+     * Further investigation is necessary to make the usage choices more generic. */
+    bool sqp_on = false;
+    if((m0==n0)&&(n0==k0)&&(m0==128))
+    {
+        sqp_on = true;
+    }
+#if 0
+    // Though this range gives 60 gflops/s standalone, integrating it into the application causes performance degradation.
+    // To be enabled after fixing.
+    if((m0>=4200) && (m0<=4600) && (n0==326)&&(k0==1120)) //to be tuned further.
+    {
+        sqp_on = true;
+    }
+#endif
+    if( ((blis_transa==BLIS_TRANSPOSE) || (blis_transa==BLIS_NO_TRANSPOSE)) && (blis_transb==BLIS_NO_TRANSPOSE) && (sqp_on==true))
     { //sqp algo is found better for n > 40
         if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS)
         {
@@ -660,6 +676,7 @@ void zgemm_
         return;
     }
     else
+#endif//ENABLE_INDUCED_METHOD
     {
         err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
         if(status==BLIS_SUCCESS)
diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c
index fab6950d2..84924a57f 100644
--- a/kernels/zen/3/bli_gemm_sqp.c
+++ b/kernels/zen/3/bli_gemm_sqp.c
@@ -39,6 +39,8 @@
 #define MEM_ALLOC 1//malloc performs better than bli_malloc.
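The sqp_on gate added above in zgemm_ currently routes only one shape to 3m_sqp: a square problem with m0 == n0 == k0 == 128, with A either transposed or not transposed and B not transposed. A minimal standalone sketch of that decision follows; the helper name and the boolean transpose flags are illustrative only and not part of the patch:

#include <stdbool.h>

/* Hypothetical helper mirroring the dispatch gate in zgemm_: returns true only
   for the shapes currently routed to bli_gemm_sqp(). */
static bool sqp_shape_supported( long m0, long n0, long k0,
                                 bool a_trans_or_notrans, bool b_notrans )
{
    bool sqp_on = ( m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 );
    return a_trans_or_notrans && b_notrans && sqp_on;
}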
#define BLIS_MX8 8 +#define BLIS_MX4 4 +#define BLIS_MX1 1 #define DEBUG_3M_SQP 0 typedef struct { @@ -48,8 +50,8 @@ typedef struct { void* unalignedBuf; }mem_block; -static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA); -static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha); +static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA, gint_t mx, gint_t* p_istart); +static err_t bli_dgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha, gint_t mx, gint_t* p_istart); /* * The bli_gemm_sqp (square packed) function performs dgemm and 3m zgemm. @@ -115,15 +117,14 @@ err_t bli_gemm_sqp } dim_t m8rem = m - ((m>>3)<<3); - if(m8rem!=0) - { - /* Residue kernel m4 and m1 to be implemented */ - return BLIS_NOT_YET_IMPLEMENTED; - } double* ap = ( double* )bli_obj_buffer( a ); double* bp = ( double* )bli_obj_buffer( b ); double* cp = ( double* )bli_obj_buffer( c ); + gint_t istart = 0; + gint_t* p_istart = &istart; + *p_istart = 0; + err_t status; if(dt==BLIS_DCOMPLEX) { dcomplex* alphap = ( dcomplex* )bli_obj_buffer( alpha ); @@ -139,7 +140,21 @@ err_t bli_gemm_sqp return BLIS_NOT_YET_IMPLEMENTED; } /* 3m zgemm implementation for C = AxB and C = AtxB */ - return bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA); +#if 0 + return bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, 8, p_istart); +#else + status = bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, 8, p_istart); + if(m8rem==0) + { + return status;// No residue: done + } + else + { + //complete residue m blocks + status = bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, 1, p_istart); + return status; + } +#endif } else if(dt == BLIS_DOUBLE) { @@ -156,7 +171,21 @@ err_t bli_gemm_sqp return BLIS_NOT_YET_IMPLEMENTED; } /* dgemm implementation with 8mx5n major kernel and column preferred storage */ - return bli_dgemm_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast)); + status = bli_dgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast), 8, p_istart); + if(status==BLIS_SUCCESS) + { + if(m8rem==0) + { + return status;// No residue: done + } + else + { + //complete residue m blocks + status = bli_dgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast), 1, p_istart); + return status; + } + } + } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); @@ -166,7 +195,133 @@ err_t bli_gemm_sqp /************************************************************************************************************/ /************************** dgemm kernels (8mxn) column preffered ******************************************/ /************************************************************************************************************/ -/* Main dgemm kernel 8mx5n with single load and store of C matrix block + +/* Main dgemm kernel 8mx6n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t 
ldc) +{ + gint_t p; + + __m256d av0, av1; + __m256d bv0, bv1; + __m256d cv0, cv1, cv2, cv3, cv4, cv5; + __m256d cx0, cx1, cx2, cx3, cx4, cx5; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc6 = ldc * 6; inc_t ldb6 = ldb * 6; + + for (j = 0; j <= (n - 6); j += 6) { + + //printf("x"); + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; double* pcldc5 = pcldc4 + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; double* pbldb5 = pbldb4 + ldb; + +#if BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc5), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb5), _MM_HINT_T0); +#endif + /* C matrix column major load */ +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); + cv5 = _mm256_loadu_pd(pcldc5); cx5 = _mm256_loadu_pd(pcldc5 + 4); +#else + cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); + cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); + cv5 = _mm256_setzero_pd(); cx5 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(x); x += 4; av1 = _mm256_loadu_pd(x); x += 4; + bv0 = _mm256_broadcast_sd (pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cx0 = _mm256_fmadd_pd(av1, bv0, cx0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cx1 = _mm256_fmadd_pd(av1, bv1, cx1); + + bv0 = _mm256_broadcast_sd(pbldb2);pbldb2++; + bv1 = _mm256_broadcast_sd(pbldb3);pbldb3++; + cv2 = _mm256_fmadd_pd(av0, bv0, cv2); + cx2 = _mm256_fmadd_pd(av1, bv0, cx2); + cv3 = _mm256_fmadd_pd(av0, bv1, cv3); + cx3 = _mm256_fmadd_pd(av1, bv1, cx3); + + bv0 = _mm256_broadcast_sd(pbldb4);pbldb4++; + bv1 = _mm256_broadcast_sd(pbldb5);pbldb5++; + cv4 = _mm256_fmadd_pd(av0, bv0, cv4); + cx4 = _mm256_fmadd_pd(av1, bv0, cx4); + cv5 = _mm256_fmadd_pd(av0, bv1, cv5); + cx5 = _mm256_fmadd_pd(av1, bv1, cx5); + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); + cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); + + av0 = _mm256_loadu_pd(pcldc); av1 = _mm256_loadu_pd(pcldc + 4); + cv1 = _mm256_add_pd(cv1, av0); cx1 = _mm256_add_pd(cx1, av1); + + bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); + cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); + + av0 = _mm256_loadu_pd(pcldc3); av1 = _mm256_loadu_pd(pcldc3 + 4); + cv3 = _mm256_add_pd(cv3, av0); cx3 = _mm256_add_pd(cx3, av1); + + bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); + cv4 
= _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); + + av0 = _mm256_loadu_pd(pcldc5); av1 = _mm256_loadu_pd(pcldc5 + 4); + cv5 = _mm256_add_pd(cv5, av0); cx5 = _mm256_add_pd(cx5, av1); +#endif + /* C matrix column major store */ + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + _mm256_storeu_pd(pcldc4, cv4); + _mm256_storeu_pd(pcldc4 + 4, cx4); + + _mm256_storeu_pd(pcldc5, cv5); + _mm256_storeu_pd(pcldc5 + 4, cx5); + + pc += ldc6;pb += ldb6; + } + + return j; +} + +/* alternative Main dgemm kernel 8mx5n with single load and store of C matrix block alpha = +/-1 and beta = +/-1,0 handled while packing.*/ inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) { @@ -485,6 +640,174 @@ inc_t bli_kernel_8mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld return j; } +#if 0 +/************************************************************************************************************/ +/************************** dgemm kernels (4mxn) column preffered ******************************************/ +/************************************************************************************************************/ +/* Residue dgemm kernel 4mx10n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_kernel_4mx10n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + /* incomplete */ + __m256d av0; + __m256d bv0, bv1, bv2, bv3; + __m256d cv0, cv1, cv2, cv3; + __m256d cx0, cx1, cx2, cx3; + __m256d bv4, cv4, cx4; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc10 = ldc * 10; inc_t ldb10 = ldb * 10; + + for (j = 0; j <= (n - 10); j += 10) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; + +#if BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); +#endif + /* C matrix column major load */ +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); + cv1 = _mm256_loadu_pd(pcldc); + cv2 = _mm256_loadu_pd(pcldc2); + cv3 = _mm256_loadu_pd(pcldc3); + cv4 = _mm256_loadu_pd(pcldc4); +#else + cv0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); + cv3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; + bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = 
_mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + cv3 = _mm256_fmadd_pd(av0, bv3, cv3); + cv4 = _mm256_fmadd_pd(av0, bv4, cv4); + + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); + cv0 = _mm256_add_pd(cv0, bv0); + + bv2 = _mm256_loadu_pd(pcldc); + cv1 = _mm256_add_pd(cv1, bv2); + + bv0 = _mm256_loadu_pd(pcldc2); + cv2 = _mm256_add_pd(cv2, bv0); + + bv2 = _mm256_loadu_pd(pcldc3); + cv3 = _mm256_add_pd(cv3, bv2); + + bv0 = _mm256_loadu_pd(pcldc4); + cv4 = _mm256_add_pd(cv4, bv0); +#endif + /* C matrix column major store */ + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc4, cv4); + + + pc += ldc10;pb += ldb10; + } + + return j; +} + +/* residue dgemm kernel 4mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_kernel_4mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0; + __m256d cv0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + cv0 = _mm256_loadu_pd(pc); + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + } + _mm256_storeu_pd(pc, cv0); + pc += ldc;pb += ldb; + }// j loop 1 multiple + return j; +} + +#endif +/************************************************************************************************************/ +/************************** dgemm kernels (1mxn) column preffered ******************************************/ +/************************************************************************************************************/ + +/* residue dgemm kernel 1mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. 
+*/ +inc_t bli_kernel_1mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + double a0; + double b0; + double c0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + c0 = *pc; + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + b0 = *pb0; pb0++; + a0 = *x; x++; + c0 += (a0 * b0); + } + *pc = c0; + pc += ldc;pb += ldb; + }// j loop 1 multiple + return j; +} + /* Ax8 packing subroutine */ void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha) { @@ -543,26 +866,106 @@ void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isT } } -/* A8x4 packing subroutine */ -void bli_prepackA_8x4(double* pa, double* aPacked, gint_t k, guint_t lda) +/* Ax4 packing subroutine */ +void bli_prepackA_4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha) { - __m256d av00, av10; - __m256d av01, av11; - __m256d av02, av12; - __m256d av03, av13; + __m256d av0, ymm0; + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); pa += lda; + _mm256_storeu_pd(aPacked, av0); + aPacked += BLIS_MX4; + } + } + else if(alpha==-1.0) + { + ymm0 = _mm256_setzero_pd();//set zero + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); pa += lda; + av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; + _mm256_storeu_pd(aPacked, av0); + aPacked += BLIS_MX4; + } + } + } + else + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX4 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX4; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX4 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX4; + *(aPacked + sidx + i) = -ar_; + } + } + } + } +} - for (gint_t p = 0; p < k; p += 4) { - av00 = _mm256_loadu_pd(pa); av10 = _mm256_loadu_pd(pa + 4); pa += lda; - av01 = _mm256_loadu_pd(pa); av11 = _mm256_loadu_pd(pa + 4); pa += lda; - av02 = _mm256_loadu_pd(pa); av12 = _mm256_loadu_pd(pa + 4); pa += lda; - av03 = _mm256_loadu_pd(pa); av13 = _mm256_loadu_pd(pa + 4); pa += lda; - - _mm256_storeu_pd(aPacked, av00); _mm256_storeu_pd(aPacked + 4, av10); - _mm256_storeu_pd(aPacked + 8, av01); _mm256_storeu_pd(aPacked + 12, av11); - _mm256_storeu_pd(aPacked + 16, av02); _mm256_storeu_pd(aPacked + 20, av12); - _mm256_storeu_pd(aPacked + 24, av03); _mm256_storeu_pd(aPacked + 28, av13); - - aPacked += 32; +/* Ax1 packing subroutine */ +void bli_prepackA_1(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha) +{ + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + *aPacked = *pa; + pa += lda; + aPacked++; + } + } + else if(alpha==-1.0) + { + for (gint_t p = 0; p < k; p += 1) { + *aPacked = -(*pa); + pa += lda; + aPacked++; + } + } + } + else + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+p); + *(aPacked + p) = ar_; + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+p); + *(aPacked + p) = -ar_; + } + } } } @@ -575,23 +978,28 @@ void bli_prepackA_8x4(double* pa, double* aPacked, gint_t k, guint_t lda) In majority of use-case, alpha are +/-1, so instead of 
explicitly multiplying alpha its done during packing itself by changing sign. */ -static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha) +static err_t bli_dgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha, gint_t mx, gint_t* p_istart) { double* aPacked; double* aligned = NULL; + gint_t i; bool pack_on = false; - if((m!=BLIS_MX8)||(m!=lda)||isTransA) + if((m!=mx)||(m!=lda)||isTransA) { pack_on = true; } if(pack_on==true) { - aligned = (double*)bli_malloc_user(sizeof(double) * k * BLIS_MX8); + aligned = (double*)bli_malloc_user(sizeof(double) * k * mx); + if(aligned==NULL) + { + return BLIS_MALLOC_RETURNED_NULL; + } } - for (gint_t i = 0; i < m; i += BLIS_MX8) //this loop can be threaded. no of workitems = m/8 + for (i = (*p_istart); i <= (m-mx); i += mx) //this loop can be threaded. no of workitems = m/8 { inc_t j = 0; double* ci = c + i; @@ -603,31 +1011,57 @@ static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, { pa = a + (i*lda); } - bli_prepackA_8(pa, aPacked, k, lda, isTransA, alpha); - //bli_prepackA_8x4(a + i, aPacked, k, lda); + /* should be changed to func pointer */ + if(mx==8) + { + bli_prepackA_8(pa, aPacked, k, lda, isTransA, alpha); + } + else if(mx==4) + { + bli_prepackA_4(pa, aPacked, k, lda, isTransA, alpha); + } + else if(mx==1) + { + bli_prepackA_1(pa, aPacked, k, lda, isTransA, alpha); + } } else { aPacked = a+i; } - - j = bli_kernel_8mx5n(n, k, j, aPacked, lda, b, ldb, ci, ldc); - if (j <= n - 4) + if(mx==8) { - j = bli_kernel_8mx4n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + //printf(" mx8i:%3ld ", i); + //8mx6n currently turned off to isolate a bug. + //j = bli_kernel_8mx6n(n, k, j, aPacked, lda, b, ldb, ci, ldc); + if (j <= n - 5) + { + j = bli_kernel_8mx5n(n, k, j, aPacked, lda, b, ldb, ci, ldc); + } + if (j <= n - 4) + { + j = bli_kernel_8mx4n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + if (j <= n - 3) + { + j = bli_kernel_8mx3n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + if (j <= n - 2) + { + j = bli_kernel_8mx2n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + if (j <= n - 1) + { + j = bli_kernel_8mx1n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } } - if (j <= n - 3) + /* mx==4 to be implemented */ + else if(mx==1) { - j = bli_kernel_8mx3n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); - } - if (j <= n - 2) - { - j = bli_kernel_8mx2n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); - } - if (j <= n - 1) - { - j = bli_kernel_8mx1n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + //printf(" mx1i:%3ld ", i); + j = bli_kernel_1mx1n(n, k, j, aPacked, lda, b, ldb, ci, ldc); } + *p_istart = i + mx; } if(pack_on==true) @@ -648,6 +1082,10 @@ gint_t bli_getaligned(mem_block* mem_req) } memSize += 128;// extra 128 bytes added for alignment. Could be minimized to 64. 
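/* The extra headroom lets bli_getaligned() round the raw allocation up to an
   aligned boundary without running past the end of the buffer; the rounding
   itself is not part of this hunk. A typical 64-byte align-up would be, for
   instance:
       aligned = (void*)(((uintptr_t)raw + 63) & ~(uintptr_t)63);
   in which case 64 bytes of headroom are sufficient, as the comment above notes. */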
#if MEM_ALLOC +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "malloc(): size %ld\n",( long )memSize; + fflush( stdout ); +#endif mem_req->unalignedBuf = (double*)malloc(memSize); if (mem_req->unalignedBuf == NULL) { @@ -677,9 +1115,24 @@ gint_t bli_allocateWorkspace(gint_t n, gint_t k, mem_block *mxr, mem_block *mxi, if (!((bli_getaligned(mxr) == 0) && (bli_getaligned(mxi) == 0) && (bli_getaligned(msx) == 0))) { +#if MEM_ALLOC + if(mxr->unalignedBuf) + { + free(mxr->unalignedBuf); + } + if(mxi->unalignedBuf) + { + free(mxi->unalignedBuf); + } + if(msx->unalignedBuf) + { + free(msx->unalignedBuf); + } +#else bli_free_user(mxr->alignedBuf); bli_free_user(mxi->alignedBuf); bli_free_user(msx->alignedBuf); +#endif return -1; } return 0; @@ -731,375 +1184,492 @@ void bli_sub_m(gint_t m, gint_t n, double* w, double* c) } } -void bli_packX_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul) +void bli_packX_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx) { gint_t j, p; __m256d av0, av1, zerov; __m256d tv0, tv1; - - if((mul ==1.0)||(mul==-1.0)) + if(mx==8) { - if(mul ==1.0) + if((mul ==1.0)||(mul==-1.0)) { - for (j = 0; j < n; j++) + if(mul ==1.0) { - for (p = 0; p <= ((k*2)-8); p += 8) + for (j = 0; j < n; j++) { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; + pbr++; pbi++; + } + pb = pb + ldb; } - - for (; p < (k*2); p += 2)// (real + imag)*k + } + else + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - pbr++; pbi++; + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + pbr++; pbi++; + } + pb = pb + ldb; } - pb = pb + ldb; } } else { - zerov = _mm256_setzero_pd(); for (j = 0; j < n; j++) { - for (p = 0; p <= ((k*2)-8); p += 8) + for (p = 0; p < (k*2); p += 2)// (real + imag)*k { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - - //negate - av0 = 
_mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; pbr++; pbi++; } pb = pb + ldb; } } - } - else + }//mx==8 +#if 0//already taken care in previous loop + else//mx==1 { - for (j = 0; j < n; j++) + if((mul ==1.0)||(mul==-1.0)) { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k + if(mul ==1.0) { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; - pbr++; pbi++; + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (mx*2); ii += 2) + { + double cr_ = c[(j * ldc) + i + ii]; + double ci_ = c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + else + { + //mul = -1.0 + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (mx*2); ii += 2) + { + double cr_ = -c[(j * ldc) + i + ii]; + double ci_ = -c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } } - pb = pb + ldb; } - } + else + { + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (mx*2); ii += 2) + { + double cr_ = mul*c[(j * ldc) + i + ii]; + double ci_ = mul*c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + }//mx==1 +#endif } -void bli_packX_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul) +void bli_packX_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx) { gint_t j, p; __m256d av0, av1, zerov; __m256d tv0, tv1, sum; - if((mul ==1.0)||(mul==-1.0)) + if(mx==8) { - if(mul ==1.0) + if((mul ==1.0)||(mul==-1.0)) { - for (j = 0; j < n; j++) + if(mul ==1.0) { - for (p = 0; p <= ((k*2)-8); p += 8) + for (j = 0; j < n; j++) { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; } - - for (; p < (k*2); p += 2)// (real + imag)*k + } + else + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); - pbr++; pbi++; pbs++; + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = 
_mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; } - pb = pb + ldb; } } else { - zerov = _mm256_setzero_pd(); for (j = 0; j < n; j++) { - for (p = 0; p <= ((k*2)-8); p += 8) + for (p = 0; p < (k*2); p += 2)// (real + imag)*k { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; pbr++; pbi++; pbs++; } pb = pb + ldb; } } - } + }//mx==8 +#if 0 else { - for (j = 0; j < n; j++) + if((alpha ==1.0)||(alpha==-1.0)) { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k + if(alpha ==1.0) { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = b[(j * ldb) + p]; + double bi_ = b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; - pbr++; pbi++; pbs++; + pbr++; pbi++; pbs++; + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = -b[(j * ldb) + p]; + double bi_ = -b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = alpha * b[(j * ldb) + p]; + double bi_ = alpha * b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } } - pb = pb + ldb; } } + #endif } -void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA) +void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA, gint_t mx) { __m256d av0, av1, av2, av3; __m256d tv0, tv1, sum; gint_t p; - if(isTransA==false) + + if(mx==8) { - pa = pa +i; - for (p = 0; p < k; p += 1) + if(isTransA==false) { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. - #if 1 - av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); + pa = pa +i; + for (p = 0; p < k; p += 1) + { + //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
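/* Each iteration below loads 16 consecutive doubles of the column-major dcomplex
   matrix A, i.e. eight interleaved (real,imag) pairs. The permute2f128/unpacklo/
   unpackhi sequence de-interleaves each pair of 256-bit loads: {r0,i0,r1,i1} and
   {r2,i2,r3,i3} become {r0,r1,r2,r3} and {i0,i1,i2,i3}. The real parts are packed
   to par, the imaginary parts to pai, and their element-wise sum to pas, which is
   the third operand required by the 3m method. */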
+ av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(par, av0); par += 4; + _mm256_storeu_pd(pai, av1); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - #else //method 2 - __m128d high, low, real, img, sum; - av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - high = _mm256_extractf128_pd(av0, 1); - low = _mm256_castpd256_pd128(av0); - real = _mm_shuffle_pd(low, high, 0b00); - img = _mm_shuffle_pd(low, high, 0b11); - sum = _mm_add_pd(real, img); - _mm_storeu_pd(par, real); par += 2; - _mm_storeu_pd(pai, img); pai += 2; - _mm_storeu_pd(pas, sum); pas += 2; + tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); + tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); + av2 = _mm256_unpacklo_pd(tv0, tv1); + av3 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av2, av3); + _mm256_storeu_pd(par, av2); par += 4; + _mm256_storeu_pd(pai, av3); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; - high = _mm256_extractf128_pd(av1, 1); - low = _mm256_castpd256_pd128(av1); - real = _mm_shuffle_pd(low, high, 0b00); - img = _mm_shuffle_pd(low, high, 0b11); - sum = _mm_add_pd(real, img); - _mm_storeu_pd(par, real); par += 2; - _mm_storeu_pd(pai, img); pai += 2; - _mm_storeu_pd(pas, sum); pas += 2; + pa = pa + lda; + } + } + else + { + gint_t idx = (i/2) * lda; + pa = pa + idx; + #if 0 + for (int p = 0; p <= ((2*k)-8); p += 8) + { + //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
+ av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); - high = _mm256_extractf128_pd(av2, 1); - low = _mm256_castpd256_pd128(av2); - real = _mm_shuffle_pd(low, high, 0b00); - img = _mm_shuffle_pd(low, high, 0b11); - sum = _mm_add_pd(real, img); - _mm_storeu_pd(par, real); par += 2; - _mm_storeu_pd(pai, img); pai += 2; - _mm_storeu_pd(pas, sum); pas += 2; + //transpose 4x4 + tv0 = _mm256_unpacklo_pd(av0, av1); + tv1 = _mm256_unpackhi_pd(av0, av1); + tv2 = _mm256_unpacklo_pd(av2, av3); + tv3 = _mm256_unpackhi_pd(av2, av3); - high = _mm256_extractf128_pd(av3, 1); - low = _mm256_castpd256_pd128(av3); - real = _mm_shuffle_pd(low, high, 0b00); - img = _mm_shuffle_pd(low, high, 0b11); - sum = _mm_add_pd(real, img); - _mm_storeu_pd(par, real); par += 2; - _mm_storeu_pd(pai, img); pai += 2; - _mm_storeu_pd(pas, sum); pas += 2; + av0 = _mm256_permute2f128_pd(tv0, tv2, 0x20); + av1 = _mm256_permute2f128_pd(tv1, tv3, 0x20); + av2 = _mm256_permute2f128_pd(tv0, tv2, 0x31); + av3 = _mm256_permute2f128_pd(tv1, tv3, 0x31); + + //get real, imag and sum + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(par, av0); par += 4; + _mm256_storeu_pd(pai, av1); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); + tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); + av2 = _mm256_unpacklo_pd(tv0, tv1); + av3 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av2, av3); + _mm256_storeu_pd(par, av2); par += 4; + _mm256_storeu_pd(pai, av3); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + pa = pa + lda; + } #endif - pa = pa + lda; + //A Transpose case: + for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) + { + gint_t idx = ii * lda; + gint_t sidx; + for (p = 0; p <= ((k*2)-8); p += 8) + { + double ar0_ = *(pa + idx + p); + double ai0_ = *(pa + idx + p + 1); + + double ar1_ = *(pa + idx + p + 2); + double ai1_ = *(pa + idx + p + 3); + + double ar2_ = *(pa + idx + p + 4); + double ai2_ = *(pa + idx + p + 5); + + double ar3_ = *(pa + idx + p + 6); + double ai3_ = *(pa + idx + p + 7); + + sidx = (p/2) * BLIS_MX8; + *(par + sidx + ii) = ar0_; + *(pai + sidx + ii) = ai0_; + *(pas + sidx + ii) = ar0_ + ai0_; + + sidx = ((p+2)/2) * BLIS_MX8; + *(par + sidx + ii) = ar1_; + *(pai + sidx + ii) = ai1_; + *(pas + sidx + ii) = ar1_ + ai1_; + + sidx = ((p+4)/2) * BLIS_MX8; + *(par + sidx + ii) = ar2_; + *(pai + sidx + ii) = ai2_; + *(pas + sidx + ii) = ar2_ + ai2_; + + sidx = ((p+6)/2) * BLIS_MX8; + *(par + sidx + ii) = ar3_; + *(pai + sidx + ii) = ai3_; + *(pas + sidx + ii) = ar3_ + ai3_; + + } + + for (; p < (k*2); p += 2) + { + double ar_ = *(pa + idx + p); + double ai_ = *(pa + idx + p + 1); + gint_t sidx = (p/2) * BLIS_MX8; + *(par + sidx + ii) = ar_; + *(pai + sidx + ii) = ai_; + *(pas + sidx + ii) = ar_ + ai_; + } + } } - } - else + } //mx==8 + else//mx==1 { - gint_t idx = (i/2) * lda; - pa = pa + idx; - -#if 0 - for (int p = 0; p <= ((2*k)-8); p += 8) + if(isTransA==false) { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
- av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - - //transpose 4x4 - tv0 = _mm256_unpacklo_pd(av0, av1); - tv1 = _mm256_unpackhi_pd(av0, av1); - tv2 = _mm256_unpacklo_pd(av2, av3); - tv3 = _mm256_unpackhi_pd(av2, av3); - - av0 = _mm256_permute2f128_pd(tv0, tv2, 0x20); - av1 = _mm256_permute2f128_pd(tv1, tv3, 0x20); - av2 = _mm256_permute2f128_pd(tv0, tv2, 0x31); - av3 = _mm256_permute2f128_pd(tv1, tv3, 0x31); - - //get real, imag and sum - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - pa = pa + lda; - } -#endif - //A Transpose case: - for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - for (p = 0; p <= ((k*2)-8); p += 8) + pa = pa +i; + //A No transpose case: + for (gint_t p = 0; p < k; p += 1) { - double ar0_ = *(pa + idx + p); - double ai0_ = *(pa + idx + p + 1); - - double ar1_ = *(pa + idx + p + 2); - double ai1_ = *(pa + idx + p + 3); - - double ar2_ = *(pa + idx + p + 4); - double ai2_ = *(pa + idx + p + 5); - - double ar3_ = *(pa + idx + p + 6); - double ai3_ = *(pa + idx + p + 7); - - sidx = (p/2) * BLIS_MX8; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - - sidx = ((p+2)/2) * BLIS_MX8; - *(par + sidx + ii) = ar1_; - *(pai + sidx + ii) = ai1_; - *(pas + sidx + ii) = ar1_ + ai1_; - - sidx = ((p+4)/2) * BLIS_MX8; - *(par + sidx + ii) = ar2_; - *(pai + sidx + ii) = ai2_; - *(pas + sidx + ii) = ar2_ + ai2_; - - sidx = ((p+6)/2) * BLIS_MX8; - *(par + sidx + ii) = ar3_; - *(pai + sidx + ii) = ai3_; - *(pas + sidx + ii) = ar3_ + ai3_; - - } - - for (; p < (k*2); p += 2) - { - double ar_ = *(pa + idx + p); - double ai_ = *(pa + idx + p + 1); - gint_t sidx = (p/2) * BLIS_MX8; - *(par + sidx + ii) = ar_; - *(pai + sidx + ii) = ai_; - *(pas + sidx + ii) = ar_ + ai_; + gint_t idx = p * lda; + for (gint_t ii = 0; ii < (mx*2) ; ii += 2) + { //real + imag : Rkernel needs 8 elements each. + double ar_ = *(pa + idx + ii); + double ai_ = *(pa + idx + ii + 1); + *par = ar_; + *pai = ai_; + *pas = ar_ + ai_; + par++; pai++; pas++; + } } } - } + else + { + gint_t idx = (i/2) * lda; + pa = pa + idx; + + //A Transpose case: + for (gint_t ii = 0; ii < mx ; ii++) + { + gint_t idx = ii * lda; + gint_t sidx; + for (p = 0; p < (k*2); p += 2) + { + double ar0_ = *(pa + idx + p); + double ai0_ = *(pa + idx + p + 1); + + sidx = (p/2) * mx; + *(par + sidx + ii) = ar0_; + *(pai + sidx + ii) = ai0_; + *(pas + sidx + ii) = ar0_ + ai0_; + + } + } + } + }//mx==1 } /************************************************************************************************************/ @@ -1109,8 +1679,14 @@ void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, doubl 3m_sqp focuses mainly on square matrixes but also supports non-square matrix. Current support is limiteed to m multiple of 8 and column storage. 
*/ -static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA) +static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA, gint_t mx, gint_t* p_istart) { + inc_t m2 = m<<1; + inc_t mxmul2 = mx<<1; + if((*p_istart) > (m2-mxmul2)) + { + return BLIS_SUCCESS; + } /* B matrix */ double* br, * bi, * bs; mem_block mbr, mbi, mbs; @@ -1127,59 +1703,75 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l ldb = ldb * 2; ldc = ldc * 2; -//debug to be removed. -#if DEBUG_3M_SQP - double ax[8][16] = { {10,-10,20,-20,30,-30,40,-40,50,-50,60,-60,70,-70,80,-80}, - {1.1,-1.1,2.1,-2.1,3.1,-3.1,4.1,-4.1,5.1,-5.1,6.1,-6.1,7.1,-7.1,8.1,-8.1}, - {1.2,-1.2,2.2,-2.2,3.2,-3.2,4.2,-4.2,5.2,-5.2,6.2,-6.2,7.2,-7.2,8.2,-8.2}, - {1.3,-1.3,2.3,-2.3,3.3,-3.3,4.3,-4.3,5.3,-5.3,6.3,-6.3,7.3,-7.3,8.3,-8.3}, - - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8} }; - - - double bx[6][16] = { {10,-10,20,-20,30,-30,40,-40,50,-50,60,-60,70,-70,80,-80}, - {1.1,-1.1,2.1,-2.1,3.1,-3.1,4.1,-4.1,5.1,-5.1,6.1,-6.1,7.1,-7.1,8.1,-8.1}, - {1.2,-1.2,2.2,-2.2,3.2,-3.2,4.2,-4.2,5.2,-5.2,6.2,-6.2,7.2,-7.2,8.2,-8.2}, - - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8} }; - - double cx[8][12] = { {10,-10,20,-20,30,-30,40,-40,50,-50,60,-60}, - {1.1,-1.1,2.1,-2.1,3.1,-3.1,4.1,-4.1,5.1,-5.1,6.1,-6.1}, - {1.2,-1.2,2.2,-2.2,3.2,-3.2,4.2,-4.2,5.2,-5.2,6.2,-6.2}, - {1.3,-1.3,2.3,-2.3,3.3,-3.3,4.3,-4.3,5.3,-5.3,6.3,-6.3}, - - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6} }; - - b = &bx[0][0]; - a = &ax[0][0]; - c = &cx[0][0]; -#endif - /* Split b (br, bi) and compute bs = br + bi */ double* pbr = br; double* pbi = bi; double* pbs = bs; - gint_t j; + gint_t j, p; /* b matrix real and imag packing and compute. 
*/ - bli_packX_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha); + //bli_packX_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha, mx); +#if 1//bug in above api to be fixed for mx = 1 + if((alpha ==1.0)||(alpha==-1.0)) + { + if(alpha ==1.0) + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = b[(j * ldb) + p]; + double bi_ = b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + pbr++; pbi++; pbs++; + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = -b[(j * ldb) + p]; + double bi_ = -b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = alpha * b[(j * ldb) + p]; + double bi_ = alpha * b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } +#endif /* Workspace memory allocation currently done dynamically This needs to be taken from already allocated memory pool in application for better performance */ /* A matrix */ double* ar, * ai, * as; mem_block mar, mai, mas; - if(bli_allocateWorkspace(8, k, &mar, &mai, &mas) !=0) + if(bli_allocateWorkspace(mx, k, &mar, &mai, &mas) !=0) { return BLIS_FAILURE; } @@ -1192,7 +1784,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* w; mem_block mw; mw.data_size = sizeof(double); - mw.size = 8 * n; + mw.size = mx * n; if (bli_getaligned(&mw) != 0) { return BLIS_FAILURE; @@ -1203,7 +1795,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* cr; mem_block mcr; mcr.data_size = sizeof(double); - mcr.size = 8 * n; + mcr.size = mx * n; if (bli_getaligned(&mcr) != 0) { return BLIS_FAILURE; @@ -1215,14 +1807,14 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* ci; mem_block mci; mci.data_size = sizeof(double); - mci.size = 8 * n; + mci.size = mx * n; if (bli_getaligned(&mci) != 0) { return BLIS_FAILURE; } ci = (double*)mci.alignedBuf; - - for (inc_t i = 0; i < (2*m); i += (2*BLIS_MX8)) //this loop can be threaded. + inc_t i; + for (i = (*p_istart); i <= (m2-mxmul2); i += mxmul2) //this loop can be threaded. { ////////////// operation 1 ///////////////// @@ -1233,17 +1825,67 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* pas = as; /* a matrix real and imag packing and compute. */ - bli_packA_real_imag_sum(a, i, k, lda, par, pai, pas, isTransA); + bli_packA_real_imag_sum(a, i, k, lda, par, pai, pas, isTransA, mx); double* pcr = cr; double* pci = ci; //Split Cr and Ci and beta multiplication done. 
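/* The split below copies the interleaved C block into separate real (pcr) and
   imaginary (pci) working panels and folds beta in at the same time: for
   beta == +/-1 only the sign is adjusted, otherwise every element is scaled by
   beta. The real-valued gemm calls that follow can then accumulate into these
   panels with an implicit beta of 1. */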
double* pc = c + i; - bli_packX_real_imag(pc, n, BLIS_MX8, ldc, pcr, pci, beta); - + //bli_packX_real_imag(pc, n, mx, ldc, pcr, pci, beta, mx); +#if 1 //bug in above api to be fixed for mx = 1 + if((beta ==1.0)||(beta==-1.0)) + { + if(beta ==1.0) + { + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < mxmul2; ii += 2) + { + double cr_ = c[(j * ldc) + i + ii]; + double ci_ = c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + else + { + //beta = -1.0 + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < mxmul2; ii += 2) + { + double cr_ = -c[(j * ldc) + i + ii]; + double ci_ = -c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < mxmul2; ii += 2) + { + double cr_ = beta*c[(j * ldc) + i + ii]; + double ci_ = beta*c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } +#endif //Ci := rgemm( SA, SB, Ci ) - bli_dgemm_m8(BLIS_MX8, n, k, as, BLIS_MX8, bs, k, ci, BLIS_MX8, false, 1.0); + gint_t istart = 0; + gint_t* p_is = &istart; + *p_is = 0; + bli_dgemm_sqp_m8(mx, n, k, as, mx, bs, k, ci, mx, false, 1.0, mx, p_is); @@ -1251,18 +1893,19 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l //Wr: = dgemm_sqp(Ar, Br, 0) // Wr output 8xn double* wr = w; for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < BLIS_MX8; ii += 1) { + for (gint_t ii = 0; ii < mx; ii += 1) { *wr = 0; wr++; } } wr = w; - bli_dgemm_m8(BLIS_MX8, n, k, ar, BLIS_MX8, br, k, wr, BLIS_MX8, false, 1.0); + *p_is = 0; + bli_dgemm_sqp_m8(mx, n, k, ar, mx, br, k, wr, mx, false, 1.0, mx, p_is); //Cr : = addm(Wr, Cr) - bli_add_m(BLIS_MX8, n, wr, cr); + bli_add_m(mx, n, wr, cr); //Ci : = subm(Wr, Ci) - bli_sub_m(BLIS_MX8, n, wr, ci); + bli_sub_m(mx, n, wr, ci); @@ -1271,18 +1914,19 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l //Wi : = dgemm_sqp(Ai, Bi, 0) // Wi output 8xn double* wi = w; for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < BLIS_MX8; ii += 1) { + for (gint_t ii = 0; ii < mx; ii += 1) { *wi = 0; wi++; } } wi = w; - bli_dgemm_m8(BLIS_MX8, n, k, ai, BLIS_MX8, bi, k, wi, BLIS_MX8, false, 1.0); + *p_is = 0; + bli_dgemm_sqp_m8(mx, n, k, ai, mx, bi, k, wi, mx, false, 1.0, mx, p_is); //Cr : = subm(Wi, Cr) - bli_sub_m(BLIS_MX8, n, wi, cr); + bli_sub_m(mx, n, wi, cr); //Ci : = subm(Wi, Ci) - bli_sub_m(BLIS_MX8, n, wi, ci); + bli_sub_m(mx, n, wi, ci); pcr = cr; @@ -1290,41 +1934,56 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < (2*BLIS_MX8); ii += 2) + for (gint_t ii = 0; ii < mxmul2; ii += 2) { c[(j * ldc) + i + ii] = *pcr; c[(j * ldc) + i + ii + 1] = *pci; pcr++; pci++; } } - + *p_istart = i + mxmul2; } -//debug to be removed. 
-#if DEBUG_3M_SQP - for (gint_t jj = 0; jj < n;jj++) - { - for (gint_t ii = 0; ii < m;ii++) - { - printf("( %4.2lf %4.2lf) ", *cr, *ci); - cr++;ci++; - } - printf("\n"); - } -#endif #if MEM_ALLOC - free(mar.unalignedBuf); - free(mai.unalignedBuf); - free(mas.unalignedBuf); + if(mar.unalignedBuf) + { + free(mar.unalignedBuf); + } + if(mai.unalignedBuf) + { + free(mai.unalignedBuf); + } + if(mas.unalignedBuf) + { + free(mas.unalignedBuf); + } + if(mw.unalignedBuf) + { + free(mw.unalignedBuf); + } + if(mcr.unalignedBuf) + { + free(mcr.unalignedBuf); + } - free(mw.unalignedBuf); + if(mci.unalignedBuf) + { + free(mci.unalignedBuf); + } + if(mbr.unalignedBuf) + { + free(mbr.unalignedBuf); + } - free(mcr.unalignedBuf); - free(mci.unalignedBuf); + if(mbi.unalignedBuf) + { + free(mbi.unalignedBuf); + } - free(mbr.unalignedBuf); - free(mbi.unalignedBuf); - free(mbs.unalignedBuf); + if(mbs.unalignedBuf) + { + free(mbs.unalignedBuf); + } #else /* free workspace buffers */ bli_free_user(mbr.alignedBuf);
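For reference, the per-block work above is the standard 3m scheme: with the packed planes Ar, Ai, As = Ar + Ai and Br, Bi, Bs = Br + Bi, the three real gemms compute Ci += As*Bs, Wr = Ar*Br and Wi = Ai*Bi, after which Cr += Wr - Wi and Ci -= Wr + Wi, leaving Cr = Ar*Br - Ai*Bi and Ci = Ar*Bi + Ai*Br (alpha and beta having already been folded in during packing). A scalar sketch of the same per-element update is given below; the function name is illustrative and not part of the patch:

/* Scalar illustration of the 3m update performed per element by the three real
   gemm calls above; cr/ci are assumed to already hold the beta-scaled C values. */
static void zgemm_3m_element( double ar, double ai, double br, double bi,
                              double* cr, double* ci )
{
    double as = ar + ai;   /* "sum" plane of A, packed as 'as' above  */
    double bs = br + bi;   /* "sum" plane of B, packed as 'bs' above  */
    double wr = ar * br;   /* operation 2: Wr = Ar*Br                 */
    double wi = ai * bi;   /* operation 3: Wi = Ai*Bi                 */
    *ci += as * bs;        /* operation 1: Ci += SA*SB                */
    *cr += wr - wi;        /* Cr = Cr + Wr - Wi -> Ar*Br - Ai*Bi      */
    *ci -= wr + wi;        /* Ci = Ci - Wr - Wi -> Ar*Bi + Ai*Br      */
}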