CPUPL-929: Improve Complex GEMM performance - Support all storage formats and non Transpose/Conjugate Matrices

Details
Added Support of N SUP kernel for complex float and complex double
Removed prefetching in M SUP kernels for complex float and complex double
Removed all warnings

Change-Id: I05ffde0f0613681927fe7576db7f5f1a4486fd05
This commit is contained in:
managalv
2020-06-01 21:04:00 +05:30
committed by Mangala V
parent c8f3cec5f7
commit f7bc37ea32
10 changed files with 2823 additions and 168 deletions

View File

@@ -196,7 +196,7 @@ void bli_cntx_init_zen( cntx_t* cntx )
// Initialize sup thresholds with architecture-appropriate values.
// s d c z
bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 256, 128 );
bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 );
bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 256, 256, 128 );
bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, 220, 110 );
@@ -221,7 +221,7 @@ void bli_cntx_init_zen( cntx_t* cntx )
// Update the context with optimized small/unpacked gemm kernels.
bli_cntx_set_l3_sup_kers
(
26,
28,
//BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref,
BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE,
BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE,
@@ -242,13 +242,15 @@ void bli_cntx_init_zen( cntx_t* cntx )
BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE,
BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE,
BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE,
cntx
);

View File

@@ -183,7 +183,7 @@ void bli_cntx_init_zen2( cntx_t* cntx )
);
// Initialize sup thresholds with architecture-appropriate values. s d c z
bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 256, 128 );
bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 );
bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 );
bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 );
@@ -208,7 +208,7 @@ void bli_cntx_init_zen2( cntx_t* cntx )
// Update the context with optimized small/unpacked gemm kernels.
bli_cntx_set_l3_sup_kers
(
26,
28,
//BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref,
BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE,
BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE,
@@ -229,13 +229,15 @@ void bli_cntx_init_zen2( cntx_t* cntx )
BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE,
BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE,
BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE,
BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE,
BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE,
BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE,
cntx
);

View File

@@ -67,27 +67,23 @@ err_t bli_gemmsup
trans_t transa = bli_obj_conjtrans_status( a );
trans_t transb = bli_obj_conjtrans_status( b );
//Don't use sup for currently unsupported storage types and dimension in cgemmsup
//Don't use sup for currently unsupported storage types in cgemmsup
if(bli_obj_is_scomplex(c) &&
((!((stor_id == BLIS_RRR) || (stor_id == BLIS_CRR)
||(stor_id == BLIS_CCR) || (stor_id == BLIS_RCR)
||(stor_id == BLIS_CCC)))
|| ((m/3) < (n/8))
|| (!((transa == BLIS_NO_TRANSPOSE)&&(transb == BLIS_NO_TRANSPOSE)))
(((stor_id == BLIS_RRC)||(stor_id == BLIS_CRC))
|| ((transa == BLIS_CONJ_NO_TRANSPOSE) || (transa == BLIS_CONJ_TRANSPOSE))
|| ((transb == BLIS_CONJ_NO_TRANSPOSE) || (transb == BLIS_CONJ_TRANSPOSE))
)){
//printf(" gemmsup: Returning with for un-supported storage types,dimension and matrix property in cgemmsup \n");
//printf(" gemmsup: Returning with for un-supported storage types and conjugate property in cgemmsup \n");
return BLIS_FAILURE;
}
//Don't use sup for currently unsupported storage types and dimension in zgemmsup
//Don't use sup for currently unsupported storage types in zgemmsup
if(bli_obj_is_dcomplex(c) &&
((!((stor_id == BLIS_RRR) || (stor_id == BLIS_CRR)
||(stor_id == BLIS_CCR) || (stor_id == BLIS_RCR)
||(stor_id == BLIS_CCC)))
|| ((m/3) < (n/4))
|| (!((transa == BLIS_NO_TRANSPOSE)&&(transb == BLIS_NO_TRANSPOSE)))
(((stor_id == BLIS_RRC)||(stor_id == BLIS_CRC))
|| ((transa == BLIS_CONJ_NO_TRANSPOSE) || (transa == BLIS_CONJ_TRANSPOSE))
|| ((transb == BLIS_CONJ_NO_TRANSPOSE) || (transb == BLIS_CONJ_TRANSPOSE))
)){
//printf(" gemmsup: Returning with for un-supported storage types,dimension and matrix property in zgemmsup \n");
//printf(" gemmsup: Returning with for un-supported storage types and conjugate property in zgemmsup \n");
return BLIS_FAILURE;
}

View File

@@ -153,8 +153,6 @@ void bli_cgemmsup_rv_zen_asm_2x8
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -163,15 +161,7 @@ void bli_cgemmsup_rv_zen_asm_2x8
lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 1*8)) // prefetch c + 2*cs_c
prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c
prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c
prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c
lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c;
prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c
prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -627,7 +617,6 @@ void bli_cgemmsup_rv_zen_asm_1x8
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -636,15 +625,7 @@ void bli_cgemmsup_rv_zen_asm_1x8
lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 0*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 0*8)) // prefetch c + 2*cs_c
prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c
prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c
prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c
lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c;
prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c
prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -1009,8 +990,6 @@ void bli_cgemmsup_rv_zen_asm_2x4
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -1019,10 +998,6 @@ void bli_cgemmsup_rv_zen_asm_2x4
lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 1*8)) // prefetch c + 2*cs_c
prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -1394,7 +1369,6 @@ void bli_cgemmsup_rv_zen_asm_1x4
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -1418,7 +1392,6 @@ void bli_cgemmsup_rv_zen_asm_1x4
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovups(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1716,8 +1689,6 @@ void bli_cgemmsup_rv_zen_asm_2x2
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 3*8)) // prefetch c + 1*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -1726,8 +1697,6 @@ void bli_cgemmsup_rv_zen_asm_2x2
lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -2099,7 +2068,6 @@ void bli_cgemmsup_rv_zen_asm_1x2
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -2108,8 +2076,6 @@ void bli_cgemmsup_rv_zen_asm_1x2
lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 0*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 0*8)) // prefetch c + 1*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -2335,4 +2301,3 @@ void bli_cgemmsup_rv_zen_asm_1x2
"memory"
)
}

View File

@@ -240,9 +240,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -251,9 +248,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m
lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -269,7 +263,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovups(mem(rbx, 0*32), ymm0)
vmovups(mem(rbx, 1*32), ymm1)
@@ -303,7 +296,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 1
prefetch(0, mem(rdx, r9, 1, 5*8))
vmovups(mem(rbx, 0*32), ymm0)
vmovups(mem(rbx, 1*32), ymm1)
@@ -336,7 +328,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 2
prefetch(0, mem(rdx, r9, 2, 5*8))
vmovups(mem(rbx, 0*32), ymm0)
vmovups(mem(rbx, 1*32), ymm1)
@@ -369,7 +360,6 @@ void bli_cgemmsup_rv_zen_asm_3x8m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
prefetch(0, mem(rdx, rcx, 1, 5*8))
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovups(mem(rbx, 0*32), ymm0)
@@ -868,9 +858,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -879,9 +866,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m
lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -897,7 +881,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovups(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -923,7 +906,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 1
prefetch(0, mem(rdx, r9, 1, 5*8))
vmovups(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -950,7 +932,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 2
prefetch(0, mem(rdx, r9, 2, 5*8))
vmovups(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -977,7 +958,6 @@ void bli_cgemmsup_rv_zen_asm_3x4m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
prefetch(0, mem(rdx, rcx, 1, 5*8))
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovups(mem(rbx, 0*32), ymm0)
@@ -1367,9 +1347,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -1378,9 +1355,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m
lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -1396,7 +1370,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovups(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
@@ -1422,7 +1395,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 1
prefetch(0, mem(rdx, r9, 1, 5*8))
vmovups(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
@@ -1448,7 +1420,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 2
prefetch(0, mem(rdx, r9, 2, 5*8))
vmovups(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
@@ -1474,7 +1445,6 @@ void bli_cgemmsup_rv_zen_asm_3x2m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
prefetch(0, mem(rdx, rcx, 1, 5*8))
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovups(mem(rbx, 0*32), xmm0)
@@ -1782,4 +1752,5 @@ void bli_cgemmsup_rv_zen_asm_3x2m
return;
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -96,7 +96,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
uint64_t k_left = k0 % 4;
uint64_t m_iter = m0 / 3;
uint64_t m_left = m0 % 3;
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
@@ -156,9 +155,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -166,9 +162,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -184,7 +177,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -210,7 +202,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 1
prefetch(0, mem(rdx, r9, 1, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -235,7 +226,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 2
prefetch(0, mem(rdx, r9, 2, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -261,7 +251,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
prefetch(0, mem(rdx, rcx, 1, 5*8))
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
@@ -584,7 +573,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
uint64_t k_left = k0 % 4;
uint64_t m_iter = m0 / 3;
uint64_t m_left = m0 % 3;
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
@@ -644,9 +632,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -654,9 +639,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -672,7 +654,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -691,7 +672,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 1
prefetch(0, mem(rdx, r9, 1, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -709,7 +689,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 2
prefetch(0, mem(rdx, r9, 2, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -728,7 +707,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
prefetch(0, mem(rdx, rcx, 1, 5*8))
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
@@ -980,7 +958,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
uint64_t k_left = k0 % 4;
uint64_t m_iter = m0 / 3;
uint64_t m_left = m0 % 3;
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
@@ -1038,9 +1015,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -1049,9 +1023,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -1067,7 +1038,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1087,7 +1057,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 1
prefetch(0, mem(rdx, r9, 1, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1107,7 +1076,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 2
prefetch(0, mem(rdx, r9, 2, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1128,7 +1096,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
prefetch(0, mem(rdx, rcx, 1, 5*8))
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
@@ -1379,7 +1346,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
uint64_t k_left = k0 % 4;
uint64_t m_iter = m0 / 3;
uint64_t m_left = m0 % 3;
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
@@ -1439,9 +1405,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -1450,9 +1413,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -1468,7 +1428,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1483,7 +1442,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 1
prefetch(0, mem(rdx, r9, 1, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1497,7 +1455,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 2
prefetch(0, mem(rdx, r9, 2, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1512,7 +1469,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
prefetch(0, mem(rdx, rcx, 1, 5*8))
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)

View File

@@ -224,9 +224,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -234,9 +231,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -252,7 +246,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -285,7 +278,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 1
prefetch(0, mem(rdx, r9, 1, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -318,7 +310,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 2
prefetch(0, mem(rdx, r9, 2, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -351,7 +342,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
prefetch(0, mem(rdx, rcx, 1, 5*8))
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
@@ -727,14 +717,14 @@ void bli_zgemmsup_rv_zen_asm_3x4m
dcomplex* ai = a + i_edge*rs_a;
dcomplex* bj = b;
sgemmsup_ker_ft ker_fps[3] =
zgemmsup_ker_ft ker_fps[3] =
{
NULL,
bli_zgemmsup_rv_zen_asm_1x4,
bli_zgemmsup_rv_zen_asm_2x4,
};
sgemmsup_ker_ft ker_fp = ker_fps[ m_left ];
zgemmsup_ker_ft ker_fp = ker_fps[ m_left ];
ker_fp
(
@@ -837,12 +827,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c
//prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c
//prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c
//prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
@@ -851,15 +835,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
//prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c
//prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c
//prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c
//lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c;
//prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c
//prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c
label(.SPOSTPFETCH) // done prefetching c
@@ -875,7 +850,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
prefetch(0, mem(rdx, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -901,7 +875,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 1
prefetch(0, mem(rdx, r9, 1, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -927,7 +900,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 2
prefetch(0, mem(rdx, r9, 2, 5*8))
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -953,7 +925,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
prefetch(0, mem(rdx, rcx, 1, 5*8))
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
@@ -1238,14 +1209,14 @@ void bli_zgemmsup_rv_zen_asm_3x2m
dcomplex* ai = a + i_edge*rs_a;
dcomplex* bj = b;
sgemmsup_ker_ft ker_fps[3] =
zgemmsup_ker_ft ker_fps[3] =
{
NULL,
bli_zgemmsup_rv_zen_asm_1x2,
bli_zgemmsup_rv_zen_asm_2x2,
};
sgemmsup_ker_ft ker_fp = ker_fps[ m_left ];
zgemmsup_ker_ft ker_fp = ker_fps[ m_left ];
ker_fp
(

File diff suppressed because it is too large Load Diff

View File

@@ -184,3 +184,17 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x2 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x2 )
// gemmsup_rv (mkernel in n dim)
GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x8n )
GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_2x8n )
GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_1x8n )
GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x4 )
GEMMSUP_KER_PROT( scomplex, c, gemmsup_rv_zen_asm_3x2 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x4n )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4n )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4n )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 )