From f7bc37ea3257c6bba74641896c750609768deee3 Mon Sep 17 00:00:00 2001 From: managalv Date: Mon, 1 Jun 2020 21:04:00 +0530 Subject: [PATCH] CPUPL-929: Improve Complex GEMM performance - Support all storage formats and non Transpose/Conjugate Matrices Details Added Support of N SUP kernel for complex float and complex double Removed prefetching in M SUP kernels for complex float and complex double Removed all warnings Change-Id: I05ffde0f0613681927fe7576db7f5f1a4486fd05 --- config/zen/bli_cntx_init_zen.c | 14 +- config/zen2/bli_cntx_init_zen2.c | 14 +- frame/3/bli_l3_sup.c | 24 +- .../zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8.c | 35 - .../zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8m.c | 31 +- .../zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c | 1582 +++++++++++++++++ .../zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c | 44 - .../zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c | 37 +- .../zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c | 1196 +++++++++++++ kernels/zen/bli_kernels_zen.h | 14 + 10 files changed, 2823 insertions(+), 168 deletions(-) create mode 100644 kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c create mode 100644 kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index adcb964a6..52f6fb966 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -196,7 +196,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Initialize sup thresholds with architecture-appropriate values. // s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 256, 256, 128 ); bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, 220, 110 ); @@ -221,7 +221,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized small/unpacked gemm kernels. bli_cntx_set_l3_sup_kers ( - 26, + 28, //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, @@ -242,13 +242,15 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, cntx ); diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 3dccaf64e..550dacdf6 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -183,7 +183,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) ); // Initialize sup thresholds with architecture-appropriate values. 
s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 ); bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); @@ -208,7 +208,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized small/unpacked gemm kernels. bli_cntx_set_l3_sup_kers ( - 26, + 28, //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, @@ -229,13 +229,15 @@ void bli_cntx_init_zen2( cntx_t* cntx ) BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, cntx ); diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index 321cc2895..d7ce21671 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -67,27 +67,23 @@ err_t bli_gemmsup trans_t transa = bli_obj_conjtrans_status( a ); trans_t transb = bli_obj_conjtrans_status( b ); - //Don't use sup for currently unsupported storage types and dimension in cgemmsup + //Don't use sup for currently unsupported storage types in cgemmsup if(bli_obj_is_scomplex(c) && - ((!((stor_id == BLIS_RRR) || (stor_id == BLIS_CRR) - ||(stor_id == BLIS_CCR) || (stor_id == BLIS_RCR) - ||(stor_id == BLIS_CCC))) - || ((m/3) < (n/8)) - || (!((transa == BLIS_NO_TRANSPOSE)&&(transb == BLIS_NO_TRANSPOSE))) + (((stor_id == BLIS_RRC)||(stor_id == BLIS_CRC)) + || ((transa == BLIS_CONJ_NO_TRANSPOSE) || (transa == BLIS_CONJ_TRANSPOSE)) + || ((transb == BLIS_CONJ_NO_TRANSPOSE) || (transb == BLIS_CONJ_TRANSPOSE)) )){ - //printf(" gemmsup: Returning with for un-supported storage types,dimension and matrix property in cgemmsup \n"); + //printf(" gemmsup: Returning with for un-supported storage types and conjugate property in cgemmsup \n"); return BLIS_FAILURE; } - //Don't use sup for currently unsupported storage types and dimension in zgemmsup + //Don't use sup for currently unsupported storage types in zgemmsup if(bli_obj_is_dcomplex(c) && - ((!((stor_id == BLIS_RRR) || (stor_id == BLIS_CRR) - ||(stor_id == BLIS_CCR) || (stor_id == BLIS_RCR) - ||(stor_id == BLIS_CCC))) - || ((m/3) < (n/4)) - || (!((transa == BLIS_NO_TRANSPOSE)&&(transb == BLIS_NO_TRANSPOSE))) + (((stor_id == BLIS_RRC)||(stor_id == BLIS_CRC)) + || ((transa == BLIS_CONJ_NO_TRANSPOSE) || (transa == BLIS_CONJ_TRANSPOSE)) + || ((transb == BLIS_CONJ_NO_TRANSPOSE) || (transb == BLIS_CONJ_TRANSPOSE)) )){ - //printf(" gemmsup: Returning with for un-supported storage types,dimension and matrix property in zgemmsup \n"); + 
//printf(" gemmsup: Returning with for un-supported storage types and conjugate property in zgemmsup \n"); return BLIS_FAILURE; } diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8.c index e54c94ece..03c1627f1 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8.c @@ -153,8 +153,6 @@ void bli_cgemmsup_rv_zen_asm_2x8 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -163,15 +161,7 @@ void bli_cgemmsup_rv_zen_asm_2x8 lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 1*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; - prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c - prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -627,7 +617,6 @@ void bli_cgemmsup_rv_zen_asm_1x8 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -636,15 +625,7 @@ void bli_cgemmsup_rv_zen_asm_1x8 lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 0*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 0*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; - prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c - prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -1009,8 +990,6 @@ void bli_cgemmsup_rv_zen_asm_2x4 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -1019,10 +998,6 @@ void bli_cgemmsup_rv_zen_asm_2x4 lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 1*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -1394,7 +1369,6 @@ void bli_cgemmsup_rv_zen_asm_1x4 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c 
@@ -1418,7 +1392,6 @@ void bli_cgemmsup_rv_zen_asm_1x4 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovups(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1716,8 +1689,6 @@ void bli_cgemmsup_rv_zen_asm_2x2 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 3*8)) // prefetch c + 1*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -1726,8 +1697,6 @@ void bli_cgemmsup_rv_zen_asm_2x2 lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -2099,7 +2068,6 @@ void bli_cgemmsup_rv_zen_asm_1x2 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -2108,8 +2076,6 @@ void bli_cgemmsup_rv_zen_asm_1x2 lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 0*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 0*8)) // prefetch c + 1*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -2335,4 +2301,3 @@ void bli_cgemmsup_rv_zen_asm_1x2 "memory" ) } - diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8m.c index bb7296078..8d10406a0 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8m.c @@ -240,9 +240,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -251,9 +248,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -269,7 +263,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovups(mem(rbx, 0*32), ymm0) vmovups(mem(rbx, 1*32), ymm1) @@ -303,7 +296,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) vmovups(mem(rbx, 0*32), ymm0) vmovups(mem(rbx, 1*32), ymm1) @@ -336,7 +328,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) vmovups(mem(rbx, 0*32), ymm0) vmovups(mem(rbx, 1*32), ymm1) @@ -369,7 +360,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; 
vmovups(mem(rbx, 0*32), ymm0) @@ -868,9 +858,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -879,9 +866,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -897,7 +881,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovups(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -923,7 +906,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) vmovups(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -950,7 +932,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) vmovups(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -977,7 +958,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovups(mem(rbx, 0*32), ymm0) @@ -1367,9 +1347,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -1378,9 +1355,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -1396,7 +1370,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovups(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -1422,7 +1395,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) vmovups(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -1448,7 +1420,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) vmovups(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -1474,7 +1445,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovups(mem(rbx, 0*32), xmm0) @@ -1782,4 +1752,5 @@ void bli_cgemmsup_rv_zen_asm_3x2m return; } } + \ 
No newline at end of file diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c new file mode 100644 index 000000000..45f889554 --- /dev/null +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -0,0 +1,1582 @@ + +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" + +//GENTFUNC( scomplex, c, gemmsup_r_zen_ref_3x1, 3 ) +/* + rrr: + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + + rcr: + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... 
-------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : +*/ +void bli_cgemmsup_rv_zen_asm_3x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 3; + if ( m_left ) + { + cgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_cgemmsup_rv_zen_asm_1x8n, + bli_cgemmsup_rv_zen_asm_2x8n, + }; + cgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; + } + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + __m128 xmm0, xmm3; + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 8; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm11 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + ymm15 = _mm256_setzero_ps(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*8; + tC = c + n_iter*tc_inc_col*8; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_ps(ymm4, 0xb1); + ymm4 = _mm256_mul_ps(ymm0, ymm4); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_addsub_ps(ymm4, ymm3); + + ymm3 = _mm256_permute_ps(ymm5, 0xb1); + ymm5 = _mm256_mul_ps(ymm0, ymm5); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_addsub_ps(ymm5, ymm3); + + ymm3 = _mm256_permute_ps(ymm8, 0xb1); + ymm8 = _mm256_mul_ps(ymm0, ymm8); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_addsub_ps(ymm8, ymm3); + + ymm3 = _mm256_permute_ps(ymm9, 0xb1); + ymm9 = _mm256_mul_ps(ymm0, ymm9); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm9 = _mm256_addsub_ps(ymm9, ymm3); + + ymm3 = _mm256_permute_ps(ymm12, 0xb1); + ymm12 = _mm256_mul_ps(ymm0, ymm12); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm12 = _mm256_addsub_ps(ymm12, ymm3); + + ymm3 = _mm256_permute_ps(ymm13, 0xb1); + ymm13 = _mm256_mul_ps(ymm0, ymm13); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm13 = _mm256_addsub_ps(ymm13, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), 
_mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ) ,_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12, 1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm1,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12,1)); + + //transpose right 3x4 + tC += tc_inc_col; + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm5), _mm256_castps_pd(ymm9))); + _mm_storeu_ps((float *)(tC ),_mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm13)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(ymm5), _mm256_castps_pd(ymm9))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm13)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm13,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm1,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm13,1)); + + } + else{ + ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate + + //Multiply ymm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC) ); + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm4 = _mm256_add_ps(ymm4, ymm0); + + //Multiply ymm8 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm8 = _mm256_add_ps(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm12 = _mm256_add_ps(ymm12, ymm0); + + //transpose left 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + 
_mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm3)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12, 1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm3,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12,1)); + + //Multiply ymm5 with beta + tC += tc_inc_col; + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm5 = _mm256_add_ps(ymm5, ymm0); + + //Multiply ymm9 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC+ 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC+ 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC+ 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC+ 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm9 = _mm256_add_ps(ymm9, ymm0); + + //Multiply ymm13 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 2)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 2 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 2 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm13 = _mm256_add_ps(ymm13, ymm0); + + //transpose right 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm5), _mm256_castps_pd(ymm9))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm13)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(ymm5), _mm256_castps_pd(ymm9))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ), _mm256_castps256_ps128(ymm3)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm13)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm13,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm3,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm13,1)); + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + _mm256_storeu_ps((float*)(tC + tc_inc_row ), ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row + 4), ymm9); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2+ 4), ymm13); + } + else{ + /* (br + bi) C + (ar + ai) AB 
*/ + ymm0 = _mm256_broadcast_ss((float const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_ps((float const *)(tC)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_add_ps(ymm4, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_add_ps(ymm5, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_add_ps(ymm8, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row + 4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm9 = _mm256_add_ps(ymm9, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm12 = _mm256_add_ps(ymm12, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row*2 +4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm13 = _mm256_add_ps(ymm13, _mm256_addsub_ps(ymm2, ymm3)); + + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row + 4), ymm9); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2+ 4), ymm13); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + scomplex* restrict cij = c + j_edge*cs_c; + scomplex* restrict ai = a; + scomplex* restrict bj = b + n_iter*8; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_cgemmsup_rv_zen_asm_3x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_cgemmsup_rv_zen_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 1 == n_left ) + { + bli_cgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } + +} + +void bli_cgemmsup_rv_zen_asm_2x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ + uint64_t k_iter = 0; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m128 xmm0, xmm3; + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 8; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm11 = _mm256_setzero_ps(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*8; + tC = c + n_iter*tc_inc_col*8; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_ps(ymm4, 0xb1); + ymm4 = _mm256_mul_ps(ymm0, ymm4); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_addsub_ps(ymm4, ymm3); + + ymm3 = _mm256_permute_ps(ymm5, 0xb1); + ymm5 = _mm256_mul_ps(ymm0, ymm5); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_addsub_ps(ymm5, ymm3); + + ymm3 = _mm256_permute_ps(ymm8, 0xb1); + ymm8 = _mm256_mul_ps(ymm0, ymm8); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_addsub_ps(ymm8, ymm3); + + ymm3 = _mm256_permute_ps(ymm9, 0xb1); + ymm9 = _mm256_mul_ps(ymm0, ymm9); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm9 = _mm256_addsub_ps(ymm9, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 2x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm1,1)); + + //transpose right 2x4 + tC += tc_inc_col; + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm5), _mm256_castps_pd(ymm9))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(ymm5), _mm256_castps_pd(ymm9))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm1,1)); + + } + else{ + ymm1 = _mm256_broadcast_ss((float const *)beta); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_ss((float const *)&beta->imag); // load alpha_i and duplicate + + //Multiply ymm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = 
_mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm4 = _mm256_add_ps(ymm4, ymm0); + + //Multiply ymm8 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1); + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm8 = _mm256_add_ps(ymm8, ymm0); + + //transpose left 2x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm3)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm3,1)); + + //Multiply ymm5 with beta + tC += tc_inc_col; + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm5 = _mm256_add_ps(ymm5, ymm0); + + //Multiply ymm9 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC+ 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC+ 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC+ 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC+ 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm9 = _mm256_add_ps(ymm9, ymm0); + + //transpose right 2x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm5), _mm256_castps_pd(ymm9))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(ymm5), _mm256_castps_pd(ymm9))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm3)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm3,1)); + + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row + 4), ymm9); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_ss((float const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_ps((float const *)(tC)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = 
_mm256_mul_ps(ymm0, ymm2); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_add_ps(ymm4, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_add_ps(ymm5, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_add_ps(ymm8, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row + 4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm9 = _mm256_add_ps(ymm9, _mm256_addsub_ps(ymm2, ymm3)); + + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row + 4), ymm9); + } + } + } + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + scomplex* restrict cij = c + j_edge*cs_c; + scomplex* restrict ai = a; + scomplex* restrict bj = b + n_iter * 8 ; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_cgemmsup_rv_zen_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_cgemmsup_rv_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_cgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } +} + +void bli_cgemmsup_rv_zen_asm_1x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m128 xmm0, xmm3; + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 8; n_iter++) + { + // clear scratch registers. 
+ ymm4 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm7 = _mm256_setzero_ps(); + + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*8; + tC = c + n_iter*tc_inc_col*8; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_ps(ymm4, 0xb1); + ymm4 = _mm256_mul_ps(ymm0, ymm4); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_addsub_ps(ymm4, ymm3); + + ymm3 = _mm256_permute_ps(ymm5, 0xb1); + ymm5 = _mm256_mul_ps(ymm0, ymm5); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_addsub_ps(ymm5, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 1x4 + _mm_storel_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm4)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm4)); + + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm4,1)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm4,1)); + + //transpose right 1x4 + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm5)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm5)); + + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm5,1)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm5,1)); + + } + else{ + ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate + + //Multiply ymm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm4 = _mm256_add_ps(ymm4, ymm0); + + _mm_storel_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm4)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm4)); + + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm4,1)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm4,1)); + + //Multiply ymm5 with beta + tC += tc_inc_col; + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm5 = _mm256_add_ps(ymm5, ymm0); + + _mm_storel_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm5)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC), _mm256_castps256_ps128(ymm5)); + + tC += tc_inc_col; + _mm_storel_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm5,1)); + + tC += tc_inc_col; + _mm_storeh_pi((__m64 *)(tC) ,_mm256_extractf128_ps (ymm5,1)); + + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_ps((float*)(tC), ymm4); + 
_mm256_storeu_ps((float*)(tC + 4), ymm5); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_ss((float const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_ps((float const *)(tC)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_add_ps(ymm4, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+4)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm5 = _mm256_add_ps(ymm5, _mm256_addsub_ps(ymm2, ymm3)); + + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + 4), ymm5); + } + } + } + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + scomplex* restrict cij = c + j_edge*cs_c; + scomplex* restrict ai = a; + scomplex* restrict bj = b + n_iter * 8; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_cgemmsup_rv_zen_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_cgemmsup_rv_zen_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ){ + bli_cgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } +} + + +void bli_cgemmsup_rv_zen_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + uint64_t k_iter = 0; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + // ------------------------------------------------------------------------- + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm6; + __m256 ymm8, ymm10; + __m256 ymm12, ymm14; + __m128 xmm0, xmm3; + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + // clear scratch registers. 
+ ymm4 = _mm256_setzero_ps(); + ymm6 = _mm256_setzero_ps(); + ymm8 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); + ymm12 = _mm256_setzero_ps(); + ymm14 = _mm256_setzero_ps(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tc_inc_col = cs_c; + + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_ps(ymm4, 0xb1); + ymm4 = _mm256_mul_ps(ymm0, ymm4); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_addsub_ps(ymm4, ymm3); + + ymm3 = _mm256_permute_ps(ymm8, 0xb1); + ymm8 = _mm256_mul_ps(ymm0, ymm8); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_addsub_ps(ymm8, ymm3); + + ymm3 = _mm256_permute_ps(ymm12, 0xb1); + ymm12 = _mm256_mul_ps(ymm0, ymm12); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm12 = _mm256_addsub_ps(ymm12, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)tC, _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)tC+2, _mm256_castps256_ps128(ymm12)); + + ymm1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC),_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12, 1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ) ,_mm256_extractf128_ps (ymm1,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12,1)); + + } + else{ + ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate + + //Multiply ymm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *) (tC + tc_inc_col*2)); + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm4 = _mm256_add_ps(ymm4, ymm0); + + //Multiply ymm8 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)(tC + 1)) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)(tC + 1 + tc_inc_col)) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*2)) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)(tC + 1 + tc_inc_col*3)) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm8 = _mm256_add_ps(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)tC + 2) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)tC + 2 + tc_inc_col) ; + xmm3 = _mm_loadl_pi(xmm3, (__m64 const *)tC + 2 + tc_inc_col*2) ; + xmm3 = _mm_loadh_pi(xmm3, (__m64 const *)tC + 2 + tc_inc_col*3) ; + ymm0 = _mm256_insertf128_ps(_mm256_castps128_ps256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_ps(ymm0, 0xb1); + ymm0 = _mm256_mul_ps(ymm1, ymm0); + ymm3 = _mm256_mul_ps(ymm2, ymm3); + ymm0 = _mm256_addsub_ps(ymm0, ymm3); + ymm12 = _mm256_add_ps(ymm12, ymm0); + + 
//transpose 3x4 + ymm0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd (ymm4), _mm256_castps_pd (ymm8))); + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm0)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + ymm3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd (ymm4) , _mm256_castps_pd(ymm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC), _mm256_castps256_ps128(ymm3)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_castps256_ps128(ymm12)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC) ,_mm256_extractf128_ps (ymm0,1)); + _mm_storel_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12, 1)); + + tC += tc_inc_col; + _mm_storeu_ps((float *)(tC ),_mm256_extractf128_ps (ymm3,1)); + _mm_storeh_pi((__m64 *)(tC+2), _mm256_extractf128_ps(ymm12,1)); + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2), ymm12); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_ss((float const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_ps((float const *)(tC)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 =_mm256_mul_ps(ymm1, ymm3); + ymm4 = _mm256_add_ps(ymm4, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm8 = _mm256_add_ps(ymm8, _mm256_addsub_ps(ymm2, ymm3)); + + ymm2 = _mm256_loadu_ps((float const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_ps(ymm2, 0xb1); + ymm2 = _mm256_mul_ps(ymm0, ymm2); + ymm3 = _mm256_mul_ps(ymm1, ymm3); + ymm12 = _mm256_add_ps(ymm12, _mm256_addsub_ps(ymm2, ymm3)); + + _mm256_storeu_ps((float*)(tC), ymm4); + _mm256_storeu_ps((float*)(tC + tc_inc_row) , ymm8); + _mm256_storeu_ps((float*)(tC + tc_inc_row *2), ymm12);; + } + } +} + +void bli_cgemmsup_rv_zen_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + scomplex* restrict alpha, + scomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + scomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + scomplex* restrict beta, + scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + scomplex *tA = a; + float *tAimag = &a->imag; + scomplex *tB = b; + scomplex *tC = c; + // clear scratch registers. 
+ __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4 = _mm_setzero_ps(); + __m128 xmm6 = _mm_setzero_ps(); + __m128 xmm8 = _mm_setzero_ps(); + __m128 xmm10 = _mm_setzero_ps(); + __m128 xmm12 = _mm_setzero_ps(); + __m128 xmm14 = _mm_setzero_ps(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tc_inc_col = cs_c; + + for (k_iter = 0; k_iter imag); // load alpha_i and duplicate + + xmm3 = _mm_permute_ps(xmm4, 0xb1); + xmm4 = _mm_mul_ps(xmm0, xmm4); + xmm3 =_mm_mul_ps(xmm1, xmm3); + xmm4 = _mm_addsub_ps(xmm4, xmm3); + + xmm3 = _mm_permute_ps(xmm8, 0xb1); + xmm8 = _mm_mul_ps(xmm0, xmm8); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm8 = _mm_addsub_ps(xmm8, xmm3); + + xmm3 = _mm_permute_ps(xmm12, 0xb1); + xmm12 = _mm_mul_ps(xmm0, xmm12); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm12 = _mm_addsub_ps(xmm12, xmm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose 3x2 + xmm0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd (xmm4), _mm_castps_pd (xmm8))); + _mm_storeu_ps((float *)tC, xmm0); + _mm_storel_pi((__m64 *)tC+2, xmm12); + + xmm1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd (xmm4) , _mm_castps_pd(xmm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)tC, xmm1); + _mm_storeh_pi((__m64 *)tC+2, xmm12); + } + else{ + xmm1 = _mm_broadcast_ss((float const *)beta); // load alpha_r and duplicate + xmm2 = _mm_broadcast_ss((float const *)&beta->imag); // load alpha_i and duplicate + + //Multiply xmm4 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) tC) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) tC + tc_inc_col); + xmm3 = _mm_permute_ps(xmm0, 0xb1); + xmm0 = _mm_mul_ps(xmm1, xmm0); + xmm3 = _mm_mul_ps(xmm2, xmm3); + xmm0 = _mm_addsub_ps(xmm0, xmm3); + xmm4 = _mm_add_ps(xmm4, xmm0); + + //Multiply xmm8 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)tC + 1) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)tC + 1 + tc_inc_col) ; + xmm3 = _mm_permute_ps(xmm0, 0xb1); + xmm0 = _mm_mul_ps(xmm1, xmm0); + xmm3 = _mm_mul_ps(xmm2, xmm3); + xmm0 = _mm_addsub_ps(xmm0, xmm3); + xmm8 = _mm_add_ps(xmm8, xmm0); + + //Multiply xmm12 with beta + xmm0 = _mm_loadl_pi(xmm0, (__m64 const *)tC + 2) ; + xmm0 = _mm_loadh_pi(xmm0, (__m64 const *)tC + 2 + tc_inc_col) ; + xmm3 = _mm_permute_ps(xmm0, 0xb1); + xmm0 = _mm_mul_ps(xmm1, xmm0); + xmm3 = _mm_mul_ps(xmm2, xmm3); + xmm0 = _mm_addsub_ps(xmm0, xmm3); + xmm12 = _mm_add_ps(xmm12, xmm0); + + //transpose 3x2 + xmm0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd (xmm4), _mm_castps_pd (xmm8))); + _mm_storeu_ps((float *)tC, xmm0); + _mm_storel_pi((__m64 *)tC+2, xmm12); + + xmm3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd (xmm4) , _mm_castps_pd(xmm8))); + tC += tc_inc_col; + _mm_storeu_ps((float *)tC, xmm3); + _mm_storeh_pi((__m64 *)tC+2, xmm12); + + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm_storeu_ps((float *)tC, xmm4); + _mm_storeu_ps((float *)tC + tc_inc_row , xmm8); + _mm_storeu_ps((float *)tC + tc_inc_row *2, xmm12); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + xmm0 = _mm_broadcast_ss((float const *)beta); // load beta_r and duplicate + xmm1 = _mm_broadcast_ss((float const *)&beta->imag); // load beta_i and duplicate + + xmm2 = _mm_loadu_ps((float const *)tC); + xmm3 = _mm_permute_ps(xmm2, 0xb1); + xmm2 = _mm_mul_ps(xmm0, xmm2); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm4 = _mm_add_ps(xmm4, _mm_addsub_ps(xmm2, xmm3)); + + xmm2 = _mm_loadu_ps((float const *)tC+tc_inc_row); + xmm3 = _mm_permute_ps(xmm2, 0xb1); + xmm2 = _mm_mul_ps(xmm0, 
xmm2); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm8 = _mm_add_ps(xmm8, _mm_addsub_ps(xmm2, xmm3)); + + xmm2 = _mm_loadu_ps((float const *)tC+tc_inc_row*2); + xmm3 = _mm_permute_ps(xmm2, 0xb1); + xmm2 = _mm_mul_ps(xmm0, xmm2); + xmm3 = _mm_mul_ps(xmm1, xmm3); + xmm12 = _mm_add_ps(xmm12, _mm_addsub_ps(xmm2, xmm3)); + + _mm_storeu_ps((float *)tC, xmm4); + _mm_storeu_ps((float *)tC + tc_inc_row , xmm8); + _mm_storeu_ps((float *)tC + tc_inc_row *2, xmm12);; + } + } +} diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c index e7ac094bb..1638eaba0 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c @@ -96,7 +96,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 uint64_t k_left = k0 % 4; uint64_t m_iter = m0 / 3; - uint64_t m_left = m0 % 3; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -156,9 +155,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -166,9 +162,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -184,7 +177,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -210,7 +202,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -235,7 +226,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -261,7 +251,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovupd(mem(rbx, 0*32), ymm0) @@ -584,7 +573,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 uint64_t k_left = k0 % 4; uint64_t m_iter = m0 / 3; - uint64_t m_left = m0 % 3; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -644,9 +632,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -654,9 +639,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c 
label(.SPOSTPFETCH) // done prefetching c @@ -672,7 +654,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -691,7 +672,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -709,7 +689,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -728,7 +707,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovupd(mem(rbx, 0*32), ymm0) @@ -980,7 +958,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 uint64_t k_left = k0 % 4; uint64_t m_iter = m0 / 3; - uint64_t m_left = m0 % 3; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -1038,9 +1015,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -1049,9 +1023,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -1067,7 +1038,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1087,7 +1057,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1107,7 +1076,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1128,7 +1096,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovupd(mem(rbx, 0*32), ymm0) @@ -1379,7 +1346,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 uint64_t k_left = k0 % 4; uint64_t m_iter = m0 / 3; - uint64_t m_left = m0 % 3; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -1439,9 +1405,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -1450,9 +1413,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 lea(mem(, rsi, 8), rsi) // cs_c *= 
sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -1468,7 +1428,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1483,7 +1442,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1497,7 +1455,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1512,7 +1469,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovupd(mem(rbx, 0*32), ymm0) diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c index 25cf2a7bb..05e05dfec 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c @@ -224,9 +224,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -234,9 +231,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -252,7 +246,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -285,7 +278,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -318,7 +310,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -351,7 +342,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovupd(mem(rbx, 0*32), ymm0) @@ -727,14 +717,14 @@ void bli_zgemmsup_rv_zen_asm_3x4m dcomplex* ai = a + i_edge*rs_a; dcomplex* bj = b; - sgemmsup_ker_ft ker_fps[3] = + zgemmsup_ker_ft ker_fps[3] = { NULL, bli_zgemmsup_rv_zen_asm_1x4, bli_zgemmsup_rv_zen_asm_2x4, }; - sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; ker_fp ( @@ 
-837,12 +827,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c - //prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - //prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - //prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c jmp(.SPOSTPFETCH) // jump to end of pre-fetching c label(.SCOLPFETCH) // column-stored pre-fetching c @@ -851,15 +835,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - //prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c - //prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c - //prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c - //lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; - //prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c - //prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c label(.SPOSTPFETCH) // done prefetching c @@ -875,7 +850,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -901,7 +875,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -927,7 +900,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -953,7 +925,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovupd(mem(rbx, 0*32), ymm0) @@ -1238,14 +1209,14 @@ void bli_zgemmsup_rv_zen_asm_3x2m dcomplex* ai = a + i_edge*rs_a; dcomplex* bj = b; - sgemmsup_ker_ft ker_fps[3] = + zgemmsup_ker_ft ker_fps[3] = { NULL, bli_zgemmsup_rv_zen_asm_1x2, bli_zgemmsup_rv_zen_asm_2x2, }; - sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; ker_fp ( diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c new file mode 100644 index 000000000..7abb99181 --- /dev/null +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c @@ -0,0 +1,1196 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" + +/* + rrr: + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + + rcr: + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : +*/ +void bli_zgemmsup_rv_zen_asm_3x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t m_left = m0 % 3; + if ( m_left ) + { + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x4n, + bli_zgemmsup_rv_zen_asm_2x4n, + }; + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; + } + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
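+	// Descriptive note (added): this n-variant kernel walks C in full
+	// 4-column panels (n_iter = n0/4); any remaining 1-3 columns are handled
+	// after the main loop under consider_edge_cases (two columns via the
+	// 3x2 kernel, a single column via zgemv).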
+ + uint64_t k_iter = 0; + + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); + + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); + + ymm3 = _mm256_permute_pd(ymm9, 5); + ymm9 = _mm256_mul_pd(ymm0, ymm9); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_addsub_pd(ymm9, ymm3); + + ymm3 = _mm256_permute_pd(ymm12, 5); + ymm12 = _mm256_mul_pd(ymm0, ymm12); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_addsub_pd(ymm12, ymm3); + + ymm3 = _mm256_permute_pd(ymm13, 5); + ymm13 = _mm256_mul_pd(ymm0, ymm13); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm13 = _mm256_addsub_pd(ymm13, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + tC += tc_inc_col; + + //transpose right 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = 
_mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm12 = _mm256_add_pd(ymm12, ymm0); + + //transpose left 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + tC += tc_inc_col; + + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + //Multiply ymm9 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm9 = _mm256_add_pd(ymm9, ymm0); + + //Multiply ymm13 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm13 = _mm256_add_pd(ymm13, ymm0); + + //transpose right 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); + } + + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), 
ymm13); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm13 = _mm256_add_pd(ymm13, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. 
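+	// Descriptive note (added): for this n-variant kernel the leftover is in
+	// the n dimension: two remaining columns are computed with
+	// bli_zgemmsup_rv_zen_asm_3x2, a single remaining column with
+	// bli_zgemv_ex.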
+ if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } + +} + +void bli_zgemmsup_rv_zen_asm_2x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + + uint64_t k_iter = 0; + + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); + + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); + + ymm3 = _mm256_permute_pd(ymm9, 5); + ymm9 = _mm256_mul_pd(ymm0, ymm9); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_addsub_pd(ymm9, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 2x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + tC += tc_inc_col; + + //transpose right 2x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + 
_mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //transpose left 2x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + tC += tc_inc_col; + + + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + //Multiply ymm9 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm9 = _mm256_add_pd(ymm9, ymm0); + + //transpose right 2x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + } + + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, 
_mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } + +} + +void bli_zgemmsup_rv_zen_asm_1x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. 
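+	// Descriptive note (added): 1x4n micro-kernel, a single row of C per
+	// panel; ymm4 holds columns 0-1 and ymm5 holds columns 2-3 of that row
+	// (each ymm register fits two dcomplex values).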
+ ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 1x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm4,1)); + tC += tc_inc_col; + + //transpose right 1x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); + tC += tc_inc_col; + + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); + } + + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. 
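+	// Descriptive note (added): again the leftover is in the n dimension:
+	// two remaining columns go to bli_zgemmsup_rv_zen_asm_1x2, a single
+	// remaining column to bli_zgemv_ex.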
+ if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + bli_zgemmsup_rv_zen_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } +} + +void bli_zgemmsup_rv_zen_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t rs_a0, inc_t cs_a0, + dcomplex* restrict b, inc_t rs_b0, inc_t cs_b0, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = 0; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm6; + __m256d ymm8, ymm10; + __m256d ymm12, ymm14; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tc_inc_col = cs_c; + + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); + + ymm3 = _mm256_permute_pd(ymm12, 5); + ymm12 = _mm256_mul_pd(ymm0, ymm12); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_addsub_pd(ymm12, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 3x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = 
_mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm12 = _mm256_add_pd(ymm12, ymm0); + + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + tc_inc_row ), ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)tC+tc_inc_row); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + } + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 0d12d2942..243706c30 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -184,3 +184,17 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x2 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x2 ) + +// gemmsup_rv (mkernel in n dim) + + +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x8n ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_2x8n ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_1x8n ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x4 ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x2 ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x4n ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4n ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4n ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 )
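
For reference (not part of the patch): all of the alpha and beta scaling in the kernels above uses the same AVX permute/mul/addsub idiom for packed complex numbers. The snippet below is a minimal standalone sketch of that idiom for single precision; the helper name cscale_ps and the small driver are illustrative only and do not exist in BLIS. Compile with -mavx (or any AVX-enabled flags).

#include <immintrin.h>
#include <stdio.h>

/* Multiply four packed scomplex values (r0,i0, r1,i1, ...) by (ar + ai*i). */
static __m256 cscale_ps( __m256 x, float ar, float ai )
{
    __m256 vr = _mm256_set1_ps( ar );          /* ar broadcast to every lane    */
    __m256 vi = _mm256_set1_ps( ai );          /* ai broadcast to every lane    */
    __m256 xs = _mm256_permute_ps( x, 0xb1 );  /* swap re/im within each pair   */
    __m256 t0 = _mm256_mul_ps( vr, x );        /* ar*re | ar*im                 */
    __m256 t1 = _mm256_mul_ps( vi, xs );       /* ai*im | ai*re                 */
    return _mm256_addsub_ps( t0, t1 );         /* ar*re-ai*im | ar*im+ai*re     */
}

int main( void )
{
    float x[8] = { 1,2, 3,4, 5,6, 7,8 };       /* four complex numbers          */
    __m256 v  = _mm256_loadu_ps( x );
    v = cscale_ps( v, 2.0f, 0.5f );            /* scale each value by (2+0.5i)  */
    _mm256_storeu_ps( x, v );
    for ( int k = 0; k < 8; k += 2 )
        printf( "(%g, %g)\n", x[k], x[k+1] );  /* e.g. (1+2i)*(2+0.5i) = 1+4.5i */
    return 0;
}

The double-precision kernels follow the identical pattern with _mm256_permute_pd(x, 5) and _mm256_addsub_pd, as seen throughout the z3x4n/z3x2 code above.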