diff --git a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.c b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.c index 98c3bb3e2..53bedf2fa 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.c @@ -61,7 +61,6 @@ void bli_sgemmsup_rd_zen_asm_5x64_avx512 uint64_t k_left1 = k_left32 % 8; uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -487,7 +486,6 @@ void bli_sgemmsup_rd_zen_asm_4x64_avx512 uint64_t k_left1 = k_left32 % 8; uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -888,7 +886,6 @@ void bli_sgemmsup_rd_zen_asm_3x64_avx512 uint64_t k_left1 = k_left32 % 8; uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -1987,7 +1984,6 @@ void bli_sgemmsup_rd_zen_asm_5x48_avx512 uint64_t k_left1 = k_left32 % 8; uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -2406,7 +2402,6 @@ void bli_sgemmsup_rd_zen_asm_4x48_avx512 uint64_t k_left1 = k_left32 % 8; uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -2814,7 +2809,6 @@ void bli_sgemmsup_rd_zen_asm_3x48_avx512 uint64_t k_left1 = k_left32 % 8; uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -3913,7 +3907,6 @@ void bli_sgemmsup_rd_zen_asm_5x32_avx512 uint64_t k_left1 = k_left32 % 8; uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -4331,7 +4324,6 @@ void bli_sgemmsup_rd_zen_asm_4x32_avx512 uint64_t k_left1 = k_left32 % 8; uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -4738,7 +4730,6 @@ void bli_sgemmsup_rd_zen_asm_3x32_avx512 uint64_t k_left1 = k_left32 % 8; uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; uint64_t rs_a = rs_a0; uint64_t cs_a = 
cs_a0; diff --git a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.h b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.h index 03063f058..80e43843c 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.h +++ b/kernels/zen4/3/sup/bli_gemmsup_rd_zen_s6x64.h @@ -198,4 +198,4 @@ mov( var( rs_c ), rdi ) \ lea( mem( , rdi, 4 ), rdi ) \ vmovups( xmm4, mem( rcx ) ) \ - add( rdi, rcx ) \ No newline at end of file + add( rdi, rcx ) diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.c index 30a1947a4..e47bbd6a7 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.c @@ -53,7 +53,8 @@ void bli_sgemmsup_rv_zen_asm_5x48_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -80,23 +81,121 @@ void bli_sgemmsup_rv_zen_asm_5x48_avx512 lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) => rs_c *= 4 lea( mem( r8, r8, 2 ), r13 ) // r13 = 3 * rs_a lea( mem( r8, r8, 4 ), r15 ) // r15 = 5 * rs_a - + INIT_REG - + mov( var( abuf ), rax ) // load address of a mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) - + // ITER 0 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA3( 4, 20, 21, 22 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA3( 5, 24, 25, 26 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA3( 4, 20, 21, 22 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA3( 5, 24, 25, 26 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA3( 4, 20, 21, 22 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA3( 5, 24, 25, 26 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA3( 4, 20, 21, 22 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA3( 5, 24, 25, 26 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) + // Load 3 rows from B matrix. vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) @@ -117,7 +216,9 @@ void bli_sgemmsup_rv_zen_asm_5x48_avx512 add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE3( 7, 8, 9, 10 ) @@ -125,7 +226,7 @@ void bli_sgemmsup_rv_zen_asm_5x48_avx512 ALPHA_SCALE3( 7, 16, 17, 18 ) ALPHA_SCALE3( 7, 20, 21, 22 ) ALPHA_SCALE3( 7, 24, 25, 26 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -233,6 +334,7 @@ void bli_sgemmsup_rv_zen_asm_5x48_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -278,7 +380,8 @@ void bli_sgemmsup_rv_zen_asm_5x32_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -305,27 +408,28 @@ void bli_sgemmsup_rv_zen_asm_5x32_avx512 lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) => rs_c *= 4 lea( mem( r8, r8, 2 ), r13 ) // r13 = 3 * rs_a lea( mem( r8, r8, 4 ), r15 ) // r15 = 5 * rs_a - + INIT_REG - + mov( var( abuf ), rax ) // load address of a mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) - + // ITER 0 // Load 2 rows from B matrix. vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) - + // Broadcast 5 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA2( 4, 8, 9 ) @@ -337,11 +441,107 @@ void bli_sgemmsup_rv_zen_asm_5x32_avx512 VFMA2( 4, 20, 21 ) vbroadcastss( mem( rax, r8, 4 ), zmm5 ) VFMA2( 5, 24, 25 ) - + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA2( 4, 20, 21 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA2( 5, 24, 25 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA2( 4, 20, 21 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA2( 5, 24, 25 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA2( 4, 20, 21 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA2( 5, 24, 25 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) + + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA2( 4, 20, 21 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA2( 5, 24, 25 ) + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE2( 7, 8, 9 ) @@ -349,7 +549,7 @@ void bli_sgemmsup_rv_zen_asm_5x32_avx512 ALPHA_SCALE2( 7, 16, 17 ) ALPHA_SCALE2( 7, 20, 21 ) ALPHA_SCALE2( 7, 24, 25 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -448,6 +648,7 @@ void bli_sgemmsup_rv_zen_asm_5x32_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -493,7 +694,8 @@ void bli_sgemmsup_rv_zen_asm_5x16_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -520,23 +722,24 @@ void bli_sgemmsup_rv_zen_asm_5x16_avx512 lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) => rs_c *= 4 lea( mem( r8, r8, 2 ), r13 ) // r13 = 3 * rs_a lea( mem( r8, r8, 4 ), r15 ) // r15 = 5 * rs_a - + INIT_REG - + mov( var( abuf ), rax ) // load address of a mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) - + // ITER 0 // Load 1 row from B matrix. 
vmovups( ( rbx ), zmm0 ) @@ -551,11 +754,103 @@ void bli_sgemmsup_rv_zen_asm_5x16_avx512 VFMA1( 4, 20 ) vbroadcastss( mem( rax, r8, 4 ), zmm5 ) VFMA1( 5, 24 ) - + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA1( 4, 20 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA1( 5, 24 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA1( 4, 20 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA1( 5, 24 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA1( 4, 20 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA1( 5, 24 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) + + // Load 1 row from B matrix. 
+ vmovups( ( rbx ), zmm0 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA1( 4, 20 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA1( 5, 24 ) + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE1( 7, 8 ) @@ -563,7 +858,7 @@ void bli_sgemmsup_rv_zen_asm_5x16_avx512 ALPHA_SCALE1( 7, 16 ) ALPHA_SCALE1( 7, 20 ) ALPHA_SCALE1( 7, 24 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -658,6 +953,7 @@ void bli_sgemmsup_rv_zen_asm_5x16_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -703,7 +999,8 @@ void bli_sgemmsup_rv_zen_asm_3x48_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -730,23 +1027,24 @@ void bli_sgemmsup_rv_zen_asm_3x48_avx512 lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) => rs_c *= 4 lea( mem( r8, r8, 2 ), r13 ) // r13 = 3 * rs_a lea( mem( r8, r8, 4 ), r15 ) // r15 = 5 * rs_a - + INIT_REG - + mov( var( abuf ), rax ) // load address of a mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) - + // ITER 0 // Load 3 rows from B matrix. 
vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) @@ -759,17 +1057,101 @@ void bli_sgemmsup_rv_zen_asm_3x48_avx512 VFMA3( 5, 12, 13, 14 ) vbroadcastss( mem( rax, r8, 2 ), zmm6 ) VFMA3( 6, 16, 17, 18 ) - + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) + + // Load 3 rows from B matrix. 
+ vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE3( 7, 8, 9, 10 ) ALPHA_SCALE3( 7, 12, 13, 14 ) ALPHA_SCALE3( 7, 16, 17, 18 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -871,6 +1253,7 @@ void bli_sgemmsup_rv_zen_asm_3x48_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -916,7 +1299,8 @@ void bli_sgemmsup_rv_zen_asm_3x32_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -943,23 +1327,24 @@ void bli_sgemmsup_rv_zen_asm_3x32_avx512 lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) => rs_c *= 4 lea( mem( r8, r8, 2 ), r13 ) // r13 = 3 * rs_a lea( mem( r8, r8, 4 ), r15 ) // r15 = 5 * rs_a - + INIT_REG - + mov( var( abuf ), rax ) // load address of a mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) - + // ITER 0 // Load 2 rows from B matrix. 
vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) @@ -971,17 +1356,97 @@ void bli_sgemmsup_rv_zen_asm_3x32_avx512 VFMA2( 5, 12, 13 ) vbroadcastss( mem( rax, r8, 2 ), zmm6 ) VFMA2( 6, 16, 17 ) - + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) + + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE2( 7, 8, 9 ) ALPHA_SCALE2( 7, 12, 13 ) ALPHA_SCALE2( 7, 16, 17 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -1064,7 +1529,7 @@ void bli_sgemmsup_rv_zen_asm_3x32_avx512 mov( var( cs_c ), rdi ) // load cs_c lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c *= sizeof(dt) => cs_c *= 4 lea( mem( rdi, rdi, 2 ), r12 ) - + UPDATE_C_1X16_BZ( 16 ) UPDATE_C_1X16_BZ( 17 ) @@ -1077,6 +1542,7 @@ void bli_sgemmsup_rv_zen_asm_3x32_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -1122,7 +1588,8 @@ void bli_sgemmsup_rv_zen_asm_3x16_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -1149,23 +1616,24 @@ void bli_sgemmsup_rv_zen_asm_3x16_avx512 lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) => rs_c *= 4 lea( mem( r8, r8, 2 ), r13 ) // r13 = 3 * rs_a lea( mem( r8, r8, 4 ), r15 ) // r15 = 5 * rs_a - + INIT_REG - + mov( var( abuf ), rax ) // load address of a mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) - + // ITER 0 // Load 1 row from B matrix. 
vmovups( ( rbx ), zmm0 ) @@ -1176,17 +1644,93 @@ void bli_sgemmsup_rv_zen_asm_3x16_avx512 VFMA1( 5, 12 ) vbroadcastss( mem( rax, r8, 2 ), zmm6 ) VFMA1( 6, 16 ) - + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) + + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE1( 7, 8 ) ALPHA_SCALE1( 7, 12 ) ALPHA_SCALE1( 7, 16 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -1271,11 +1815,12 @@ void bli_sgemmsup_rv_zen_asm_3x16_avx512 label( .SDONE ) - + end_asm( : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.h b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.h index 8ee3d6c2d..ae5023c40 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.h +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64.h @@ -354,4 +354,4 @@ vmovss( xmm6, (rcx, rdi, 1) ) \ vmovss( xmm7, (rcx, rdi, 2) ) \ vmovss( xmm12, (rcx, r12, 1) ) \ - lea( (rcx, rdi, 4), rcx ) \ No newline at end of file + lea( (rcx, rdi, 4), rcx ) diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c index 284430237..9fe45581a 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c @@ -41,12 +41,12 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... 
-------- + -------- ------ -------- + -------- ------ : + -------- ------ : Assumptions: - B is row-stored; - A is row-stored; @@ -195,20 +195,20 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 else { const dim_t mr = 6; - + // Since A is packed into row panels, // we must use a loop over gemv. dim_t m_iter = ( m0 + mr - 1 ) / mr; dim_t m_left = m0 % mr; - + float* restrict ai_ii = ai; float* restrict cij_ii = cij; - + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) { dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? mr : m_left ); - + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, @@ -217,7 +217,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 ); cij_ii += mr_cur * rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -274,14 +274,14 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 mov( var( cbuf ), rcx ) // load address of c // C Prefetch - lea( mem( rcx, rdi, 2 ), rdx ) - lea( mem( rdx, rdi, 1 ), rdx ) - cmp( imm( 4 ), rdi ) jz( .SPOSTPFETCH ) // haven't added col-prefetch cases label( .SROWPFETCH ) + lea( mem( rcx, rdi, 2 ), rdx ) + lea( mem( rdx, rdi, 1 ), rdx ) + prefetch( 0, mem( rcx, 7*8 ) ) prefetch( 0, mem( rcx, rdi, 1, 7*8 ) ) prefetch( 0, mem( rcx, rdi, 2, 7*8 ) ) @@ -321,7 +321,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 VFMA4( 5, 24, 25, 26, 27 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA4( 6, 28, 29, 30, 31 ) - + add( r9, rbx ) add( r10, rax ) @@ -345,7 +345,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 VFMA4( 5, 24, 25, 26, 27 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA4( 6, 28, 29, 30, 31 ) - + add( r9, rbx ) add( r10, rax ) @@ -369,7 +369,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 VFMA4( 5, 24, 25, 26, 27 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA4( 6, 28, 29, 30, 31 ) - + add( r9, rbx ) add( r10, rax ) @@ -393,7 +393,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 VFMA4( 5, 24, 25, 26, 27 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA4( 6, 28, 29, 30, 31 ) - + add( r9, rbx ) add( r10, rax ) @@ -401,7 +401,7 @@ void 
bli_sgemmsup_rv_zen_asm_6x64m_avx512 jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop label( .CONSID_K_LEFT ) - + mov( var( k_left ), rsi ) // i = k_left; test( rsi, rsi ) // check i via logical AND. je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. @@ -427,13 +427,13 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 VFMA4( 5, 24, 25, 26, 27 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA4( 6, 28, 29, 30, 31 ) - + add( r9, rbx ) add( r10, rax ) - dec(rsi) // i -= 1; + dec( rsi ) // i -= 1; jne( .K_LEFT_LOOP ) // iterate again if i != 0. - label(.SPOSTACCUM) + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE4( 7, 8, 9, 10, 11 ) @@ -442,7 +442,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 ALPHA_SCALE4( 7, 20, 21, 22, 23 ) ALPHA_SCALE4( 7, 24, 25, 26, 27 ) ALPHA_SCALE4( 7, 28, 29, 30, 31 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -554,7 +554,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 mov( var( cs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) lea( mem( rdi, rdi, 2 ), r12 ) - + TRANSPOSE_4X16_BZ( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 9, 13, 17, 21 ) @@ -585,7 +585,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 add( rdx, rax ) // a += rs_a * 6(MR) mov( rax, var( abuf ) ) // store updated a - mov( var( rs_c ), rdi ) + mov( var( rs_c ), rdi ) lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 lea( mem( , rdi, 2 ), rdx ) // rdx = rs_c * 2 lea( mem( rdx, rdi, 4 ), rdx ) // rdx = rdi * 4 => rdx = rs_c * 6 @@ -657,7 +657,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 ai += mr_cur * rs_a; m_left -= mr_cur; } - + if ( 2 <= m_left ) { const dim_t mr_cur = 2; @@ -788,7 +788,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) vmovups( 0x80( rbx ), zmm2 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. 
vbroadcastss( ( rax ), zmm4 ) VFMA3( 4, 8, 9, 10 ) @@ -802,7 +802,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 VFMA3( 5, 24, 25, 26 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA3( 6, 28, 29, 30 ) - + add( r9, rbx ) add( r10, rax ) @@ -811,7 +811,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) vmovups( 0x80( rbx ), zmm2 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA3( 4, 8, 9, 10 ) @@ -825,7 +825,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 VFMA3( 5, 24, 25, 26 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA3( 6, 28, 29, 30 ) - + add( r9, rbx ) add( r10, rax ) @@ -834,7 +834,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) vmovups( 0x80( rbx ), zmm2 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA3( 4, 8, 9, 10 ) @@ -848,7 +848,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 VFMA3( 5, 24, 25, 26 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA3( 6, 28, 29, 30 ) - + add( r9, rbx ) add( r10, rax ) @@ -857,7 +857,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) vmovups( 0x80( rbx ), zmm2 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA3( 4, 8, 9, 10 ) @@ -871,7 +871,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 VFMA3( 5, 24, 25, 26 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA3( 6, 28, 29, 30 ) - + add( r9, rbx ) add( r10, rax ) @@ -880,7 +880,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 label( .CONSID_K_LEFT ) - + mov( var( k_left ), rsi ) // i = k_left; test( rsi, rsi ) // check i via logical AND. je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. 
@@ -891,7 +891,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) vmovups( 0x80( rbx ), zmm2 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA3( 4, 8, 9, 10 ) @@ -905,7 +905,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 VFMA3( 5, 24, 25, 26 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA3( 6, 28, 29, 30 ) - + add( r9, rbx ) add( r10, rax ) @@ -921,7 +921,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 ALPHA_SCALE3( 7, 20, 21, 22 ) ALPHA_SCALE3( 7, 24, 25, 26 ) ALPHA_SCALE3( 7, 28, 29, 30 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -972,7 +972,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 TRANSPOSE_2X16( 25, 29 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16( 26, 30 ) - + jmp( .SDONE ) // jump to the end @@ -981,7 +981,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 cmp( imm( 4 ), rdi ) // set ZF if (4*rs_c) == 4. jz( .SCOLSTORBZ ) // jump to column storage case - + label( .SROWSTORBZ ) UPDATE_C3_BZ( 8, 9, 10 ) @@ -1023,7 +1023,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 TRANSPOSE_2X16_BZ( 25, 29 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16_BZ( 26, 30 ) - + jmp( .SDONE ) // jump to the end @@ -1035,7 +1035,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 add( rdx, rax ) // a += rs_a * 6(MR) mov( rax, var( abuf ) ) // store updated a - mov( var( rs_c ), rdi ) + mov( var( rs_c ), rdi ) lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 lea( mem( , rdi, 2 ), rdx ) // rdx = rs_c * 2 lea( mem( rdx, rdi, 4 ), rdx ) // rdx = rdi * 4 => rdx = rs_c * 6 @@ -1107,7 +1107,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 ai += mr_cur * rs_a; m_left -= mr_cur; } - + if ( 2 <= m_left ) { const dim_t mr_cur = 2; @@ -1223,7 +1223,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 label( .SPOSTPFETCH ) - + mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) @@ -1237,7 +1237,7 @@ void 
bli_sgemmsup_rv_zen_asm_6x32m_avx512 // Load 2 rows from B matrix. vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA2( 4, 8, 9 ) @@ -1251,7 +1251,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 VFMA2( 5, 24, 25 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA2( 6, 28, 29 ) - + add( r9, rbx ) add( r10, rax ) @@ -1259,7 +1259,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 // Load 2 rows from B matrix. vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA2( 4, 8, 9 ) @@ -1273,7 +1273,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 VFMA2( 5, 24, 25 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA2( 6, 28, 29 ) - + add( r9, rbx ) add( r10, rax ) @@ -1281,7 +1281,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 // Load 2 rows from B matrix. vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA2( 4, 8, 9 ) @@ -1295,7 +1295,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 VFMA2( 5, 24, 25 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA2( 6, 28, 29 ) - + add( r9, rbx ) add( r10, rax ) @@ -1303,7 +1303,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 // Load 2 rows from B matrix. vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA2( 4, 8, 9 ) @@ -1317,14 +1317,14 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 VFMA2( 5, 24, 25 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA2( 6, 28, 29 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop label( .CONSID_K_LEFT ) - + mov( var( k_left ), rsi ) // i = k_left; test( rsi, rsi ) // check i via logical AND. je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. 
@@ -1334,7 +1334,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 // Load 2 rows from B matrix. vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) - + // Broadcast 6 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA2( 4, 8, 9 ) @@ -1348,7 +1348,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 VFMA2( 5, 24, 25 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA2( 6, 28, 29 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) @@ -1362,7 +1362,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 ALPHA_SCALE2( 7, 20, 21 ) ALPHA_SCALE2( 7, 24, 25 ) ALPHA_SCALE2( 7, 28, 29 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -1384,7 +1384,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 UPDATE_C2( 4, 28, 29 ) jmp( .SDONE ) // jump to the end - + label( .SCOLSTORED ) /** @@ -1411,12 +1411,12 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 TRANSPOSE_2X16( 24, 28 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16( 25, 29 ) - + jmp( .SDONE ) // jump to the end label(.SBETAZERO) - + cmp( imm( 4 ), rdi ) // set ZF if (4*rs_c) == 4. 
jz( .SCOLSTORBZ ) // jump to column storage case @@ -1429,7 +1429,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 UPDATE_C2_BZ( 20, 21 ) UPDATE_C2_BZ( 24, 25 ) UPDATE_C2_BZ( 28, 29 ) - + jmp( .SDONE ) // jump to the end @@ -1457,7 +1457,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 TRANSPOSE_2X16_BZ( 24, 28 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16_BZ( 25, 29 ) - + jmp( .SDONE ) // jump to the end @@ -1469,7 +1469,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 add( rdx, rax ) // a += rs_a * 6(MR) mov( rax, var( abuf ) ) // store updated a - mov( var( rs_c ), rdi ) + mov( var( rs_c ), rdi ) lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 lea( mem( , rdi, 2 ), rdx ) // rdx = rs_c * 2 lea( mem( rdx, rdi, 4 ), rdx ) // rdx = rdi * 4 => rdx = rs_c * 6 @@ -1541,7 +1541,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 ai += mr_cur * rs_a; m_left -= mr_cur; } - + if ( 2 <= m_left ) { const dim_t mr_cur = 2; @@ -1685,7 +1685,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 VFMA1( 5, 24 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA1( 6, 28 ) - + add( r9, rbx ) add( r10, rax ) @@ -1706,7 +1706,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 VFMA1( 5, 24 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA1( 6, 28 ) - + add( r9, rbx ) add( r10, rax ) @@ -1727,7 +1727,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 VFMA1( 5, 24 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA1( 6, 28 ) - + add( r9, rbx ) add( r10, rax ) @@ -1748,14 +1748,14 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 VFMA1( 5, 24 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA1( 6, 28 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop label( .CONSID_K_LEFT ) - + mov( var( k_left ), rsi ) // i = k_left; test( rsi, rsi ) // check i via logical AND. je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. 
@@ -1778,7 +1778,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 VFMA1( 5, 24 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA1( 6, 28 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) @@ -1792,7 +1792,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 ALPHA_SCALE1( 7, 20 ) ALPHA_SCALE1( 7, 24 ) ALPHA_SCALE1( 7, 28 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -1838,7 +1838,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 mov( var( cbuf ), rcx ) // load address of c lea( mem( rcx, r10, 4 ), rcx ) TRANSPOSE_2X16( 24, 28 ) - + jmp( .SDONE ) // jump to the end label( .SBETAZERO ) @@ -1855,7 +1855,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 UPDATE_C1_BZ( 20 ) UPDATE_C1_BZ( 24 ) UPDATE_C1_BZ( 28 ) - + jmp( .SDONE ) // jump to the end @@ -1880,7 +1880,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 mov( var( cbuf ), rcx ) // load address of c lea( mem( rcx, r10, 4 ), rcx ) TRANSPOSE_2X16_BZ( 24, 28 ) - + jmp( .SDONE ) // jump to the end @@ -1892,7 +1892,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 add( rdx, rax ) // a += rs_a * 6(MR) mov( rax, var( abuf ) ) // store updated a - mov( var( rs_c ), rdi ) + mov( var( rs_c ), rdi ) lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 lea( mem( , rdi, 2 ), rdx ) // rdx = rs_c * 2 lea( mem( rdx, rdi, 4 ), rdx ) // rdx = rdi * 4 => rdx = rs_c * 6 @@ -2017,7 +2017,8 @@ void bli_sgemmsup_rv_zen_asm_4x64m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -2051,15 +2052,110 @@ void bli_sgemmsup_rv_zen_asm_4x64m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, 
and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. + label( .K_LEFT_LOOP ) // Load 4 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -2076,18 +2172,21 @@ void bli_sgemmsup_rv_zen_asm_4x64m_avx512 VFMA4( 6, 16, 17, 18, 19 ) vbroadcastss( mem( rax, r13, 1 ), zmm4 ) VFMA4( 4, 20, 21, 22, 23 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE4( 7, 8, 9, 10, 11 ) ALPHA_SCALE4( 7, 12, 13, 14, 15 ) ALPHA_SCALE4( 7, 16, 17, 18, 19 ) ALPHA_SCALE4( 7, 20, 21, 22, 23 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -2183,6 +2282,7 @@ void bli_sgemmsup_rv_zen_asm_4x64m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -2228,7 +2328,8 @@ void bli_sgemmsup_rv_zen_asm_4x48m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -2263,21 +2364,22 @@ void bli_sgemmsup_rv_zen_asm_4x48m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) - + // ITER 0 // Load 3 rows from B matrix. vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) vmovups( 0x80( rbx ), zmm2 ) - + // Broadcast 4 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA3( 4, 8, 9, 10 ) @@ -2287,18 +2389,110 @@ void bli_sgemmsup_rv_zen_asm_4x48m_avx512 VFMA3( 6, 16, 17, 18 ) vbroadcastss( mem( rax, r13, 1 ), zmm4 ) VFMA3( 4, 20, 21, 22 ) - + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA3( 4, 20, 21, 22 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA3( 4, 20, 21, 22 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA3( 4, 20, 21, 22 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) + + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA3( 6, 16, 17, 18 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA3( 4, 20, 21, 22 ) + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE3( 7, 8, 9, 10 ) ALPHA_SCALE3( 7, 12, 13, 14 ) ALPHA_SCALE3( 7, 16, 17, 18 ) ALPHA_SCALE3( 7, 20, 21, 22 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -2340,7 +2534,7 @@ void bli_sgemmsup_rv_zen_asm_4x48m_avx512 lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16( 10, 14, 18, 22 ) lea( mem( rcx, r12, 4 ), rcx ) - + jmp( .SDONE ) // jump to the end @@ -2380,7 +2574,7 @@ void bli_sgemmsup_rv_zen_asm_4x48m_avx512 TRANSPOSE_4X16_BZ( 9, 13, 17, 21 ) lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 10, 14, 18, 22 ) - + jmp( .SDONE ) // jump to the end @@ -2390,6 +2584,7 @@ void bli_sgemmsup_rv_zen_asm_4x48m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -2435,7 +2630,8 @@ void bli_sgemmsup_rv_zen_asm_4x32m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -2469,15 +2665,101 @@ void bli_sgemmsup_rv_zen_asm_4x32m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. 
label( .K_LOOP_ITER ) + // ITER 0 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA2( 4, 20, 21 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA2( 4, 20, 21 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA2( 4, 20, 21 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA2( 6, 16, 17 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA2( 4, 20, 21 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 2 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -2492,18 +2774,21 @@ void bli_sgemmsup_rv_zen_asm_4x32m_avx512 VFMA2( 6, 16, 17 ) vbroadcastss( mem( rax, r13, 1 ), zmm4 ) VFMA2( 4, 20, 21 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE2( 7, 8, 9 ) ALPHA_SCALE2( 7, 12, 13 ) ALPHA_SCALE2( 7, 16, 17 ) ALPHA_SCALE2( 7, 20, 21 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -2592,6 +2877,7 @@ void bli_sgemmsup_rv_zen_asm_4x32m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -2637,7 +2923,8 @@ void bli_sgemmsup_rv_zen_asm_4x16m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -2672,15 +2959,97 @@ void bli_sgemmsup_rv_zen_asm_4x16m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA1( 4, 20 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA1( 4, 20 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 1 row from B matrix. 
+ vmovups( ( rbx ), zmm0 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA1( 4, 20 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA1( 6, 16 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA1( 4, 20 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 1 row from B matrix. vmovups( ( rbx ), zmm0 ) @@ -2694,18 +3063,21 @@ void bli_sgemmsup_rv_zen_asm_4x16m_avx512 VFMA1( 6, 16 ) vbroadcastss( mem( rax, r13, 1 ), zmm4 ) VFMA1( 4, 20 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE1( 7, 8 ) ALPHA_SCALE1( 7, 12 ) ALPHA_SCALE1( 7, 16 ) ALPHA_SCALE1( 7, 20 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -2790,6 +3162,7 @@ void bli_sgemmsup_rv_zen_asm_4x16m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -2835,7 +3208,8 @@ void bli_sgemmsup_rv_zen_asm_2x64m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -2869,15 +3243,93 @@ void bli_sgemmsup_rv_zen_asm_2x64m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 4 rows from B matrix. 
+ vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 4 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -2890,16 +3342,19 @@ void bli_sgemmsup_rv_zen_asm_2x64m_avx512 VFMA4( 4, 8, 9, 10, 11 ) vbroadcastss( mem( rax, r8, 1 ), zmm5 ) VFMA4( 5, 12, 13, 14, 15 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE4( 7, 8, 9, 10, 11 ) ALPHA_SCALE4( 7, 12, 13, 14, 15 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -2991,6 +3446,7 @@ void bli_sgemmsup_rv_zen_asm_2x64m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -3036,7 +3492,8 @@ void bli_sgemmsup_rv_zen_asm_2x48m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -3070,15 +3527,89 @@ void bli_sgemmsup_rv_zen_asm_2x48m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA3( 5, 12, 13, 14 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 3 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -3090,16 +3621,19 @@ void bli_sgemmsup_rv_zen_asm_2x48m_avx512 VFMA3( 4, 8, 9, 10 ) vbroadcastss( mem( rax, r8, 1 ), zmm5 ) VFMA3( 5, 12, 13, 14 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE3( 7, 8, 9, 10 ) ALPHA_SCALE3( 7, 12, 13, 14 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -3187,6 +3721,7 @@ void bli_sgemmsup_rv_zen_asm_2x48m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -3232,7 +3767,8 @@ void bli_sgemmsup_rv_zen_asm_2x32m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -3266,35 +3802,108 @@ void bli_sgemmsup_rv_zen_asm_2x32m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) - + // ITER 0 // Load 2 rows from B matrix. vmovups( ( rbx ), zmm0 ) vmovups( 0x40( rbx ), zmm1 ) - + // Broadcast 2 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA2( 4, 8, 9 ) vbroadcastss( mem( rax, r8, 1 ), zmm5 ) VFMA2( 5, 12, 13 ) - + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) + + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA2( 5, 12, 13 ) + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE2( 7, 8, 9 ) ALPHA_SCALE2( 7, 12, 13 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -3378,6 +3987,7 @@ void bli_sgemmsup_rv_zen_asm_2x32m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -3423,7 +4033,8 @@ void bli_sgemmsup_rv_zen_asm_2x16m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -3457,15 +4068,81 @@ void bli_sgemmsup_rv_zen_asm_2x16m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 1 row from B matrix. 
+ vmovups( ( rbx ), zmm0 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA1( 5, 12 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 1 row from B matrix. vmovups( ( rbx ), zmm0 ) @@ -3475,16 +4152,19 @@ void bli_sgemmsup_rv_zen_asm_2x16m_avx512 VFMA1( 4, 8 ) vbroadcastss( mem( rax, r8, 1 ), zmm5 ) VFMA1( 5, 12 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE1( 7, 8 ) ALPHA_SCALE1( 7, 12 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -3563,6 +4243,7 @@ void bli_sgemmsup_rv_zen_asm_2x16m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -3608,7 +4289,8 @@ void bli_sgemmsup_rv_zen_asm_1x64m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -3642,15 +4324,84 @@ void bli_sgemmsup_rv_zen_asm_1x64m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. 
label( .K_LOOP_ITER ) + // ITER 0 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label( .K_LEFT_LOOP ) // Load 4 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -3661,15 +4412,18 @@ void bli_sgemmsup_rv_zen_asm_1x64m_avx512 // Broadcast 1 element from a row of A & do VFMA with rows of B. 
vbroadcastss( ( rax ), zmm4 ) VFMA4( 4, 8, 9, 10, 11 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE4( 7, 8, 9, 10, 11 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -3739,6 +4493,7 @@ void bli_sgemmsup_rv_zen_asm_1x64m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -3784,7 +4539,8 @@ void bli_sgemmsup_rv_zen_asm_1x48m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -3818,15 +4574,80 @@ void bli_sgemmsup_rv_zen_asm_1x48m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 3 rows from B matrix. 
+ vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 3 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA3( 4, 8, 9, 10 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label( .K_LEFT_LOOP ) // Load 3 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -3836,15 +4657,18 @@ void bli_sgemmsup_rv_zen_asm_1x48m_avx512 // Broadcast 1 element from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA3( 4, 8, 9, 10 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE3( 7, 8, 9, 10 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -3912,6 +4736,7 @@ void bli_sgemmsup_rv_zen_asm_1x48m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -3957,7 +4782,8 @@ void bli_sgemmsup_rv_zen_asm_1x32m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -3991,15 +4817,77 @@ void bli_sgemmsup_rv_zen_asm_1x32m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 2 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA2( 4, 8, 9 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 2 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -4008,15 +4896,18 @@ void bli_sgemmsup_rv_zen_asm_1x32m_avx512 // Broadcast 1 element from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA2( 4, 8, 9 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE2( 7, 8, 9 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -4080,6 +4971,7 @@ void bli_sgemmsup_rv_zen_asm_1x32m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -4125,7 +5017,8 @@ void bli_sgemmsup_rv_zen_asm_1x16m_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; @@ -4159,31 +5052,92 @@ void bli_sgemmsup_rv_zen_asm_1x16m_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) - + // ITER 0 // Load 1 row from B matrix. vmovups( ( rbx ), zmm0 ) - + // Broadcast 1 element from a row of A & do VFMA with rows of B. 
vbroadcastss( ( rax ), zmm4 ) VFMA1( 4, 8 ) - + add( r9, rbx ) add( r10, rax ) + + // ITER 1 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + + add( r9, rbx ) + add( r10, rax ) + dec( rsi ) jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) + + // Load 1 row from B matrix. + vmovups( ( rbx ), zmm0 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA1( 4, 8 ) + + add( r9, rbx ) + add( r10, rax ) + dec( rsi ) + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) + // Scaling A * B with alpha. 
ALPHA_SCALE1( 7, 8 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -4247,6 +5201,7 @@ void bli_sgemmsup_rv_zen_asm_1x16m_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c index eaa025a5e..106e06dea 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c @@ -41,12 +41,12 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : Assumptions: - B is row-stored; - A is row-stored; @@ -126,7 +126,8 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 } } - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t n_iter = n0 / 64; uint64_t n_left = n0 % 64; @@ -160,7 +161,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 mov( var( n_iter ), r11 ) // load n_iter label( .N_LOOP_ITER ) - + mov( var( rs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) @@ -170,14 +171,124 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 4 rows from B matrix. 
+ vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 6 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA4( 5, 24, 25, 26, 27 ) + vbroadcastss( mem( rax, r15, 1 ), zmm6 ) + VFMA4( 6, 28, 29, 30, 31 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 6 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA4( 5, 24, 25, 26, 27 ) + vbroadcastss( mem( rax, r15, 1 ), zmm6 ) + VFMA4( 6, 28, 29, 30, 31 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 6 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA4( 5, 24, 25, 26, 27 ) + vbroadcastss( mem( rax, r15, 1 ), zmm6 ) + VFMA4( 6, 28, 29, 30, 31 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 6 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA4( 5, 24, 25, 26, 27 ) + vbroadcastss( mem( rax, r15, 1 ), zmm6 ) + VFMA4( 6, 28, 29, 30, 31 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 4 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -198,11 +309,14 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 VFMA4( 5, 24, 25, 26, 27 ) vbroadcastss( mem( rax, r15, 1 ), zmm6 ) VFMA4( 6, 28, 29, 30, 31 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE4( 7, 8, 9, 10, 11 ) @@ -211,7 +325,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 ALPHA_SCALE4( 7, 20, 21, 22, 23 ) ALPHA_SCALE4( 7, 24, 25, 26, 27 ) ALPHA_SCALE4( 7, 28, 29, 30, 31 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -304,7 +418,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 mov( var( cs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) lea( mem( rdi, rdi, 2 ), r12 ) - + TRANSPOSE_4X16_BZ( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 9, 13, 17, 21 ) @@ -352,6 +466,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -390,7 +505,6 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 const dim_t j_edge = n0 - ( dim_t )n_left; uint64_t ps_b = bli_auxinfo_ps_b( data ); - uint64_t ps_b4 = ps_b * sizeof( float ); float* restrict cij = c + j_edge*cs_c; float* restrict bj = b + n_iter * ps_b; @@ -508,20 +622,20 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 else { const dim_t mr = 6; - + // Since A is packed into row panels, we must use a loop over // gemv. dim_t m_iter = ( m0 + mr - 1 ) / mr; dim_t m_left = m0 % mr; - + float* restrict ai_ii = ai; float* restrict cij_ii = cij; - + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) { dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? 
mr : m_left ); - + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, @@ -530,7 +644,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -553,7 +667,8 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t n_iter = n0 / 64; uint64_t n_left = n0 % 64; @@ -587,7 +702,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 mov( var( n_iter ), r11 ) // load n_iter label( .N_LOOP_ITER ) - + mov( var( rs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) @@ -597,14 +712,117 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA4( 5, 24, 25, 26, 27 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA4( 5, 24, 25, 26, 27 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA4( 5, 24, 25, 26, 27 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 5 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + vbroadcastss( mem( rax, r8, 4 ), zmm5 ) + VFMA4( 5, 24, 25, 26, 27 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 4 rows from B matrix. 
vmovups( ( rbx ), zmm0 ) @@ -627,7 +845,10 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) + jne( .K_LEFT_LOOP ) + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE4( 7, 8, 9, 10, 11 ) @@ -635,7 +856,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 ALPHA_SCALE4( 7, 16, 17, 18, 19 ) ALPHA_SCALE4( 7, 20, 21, 22, 23 ) ALPHA_SCALE4( 7, 24, 25, 26, 27 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -654,7 +875,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 UPDATE_C4( 4, 16, 17, 18, 19 ) UPDATE_C4( 4, 20, 21, 22, 23 ) UPDATE_C4( 4, 24, 25, 26, 27 ) - + jmp( .SDONE ) // jump to the end @@ -729,7 +950,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 mov( var( cs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) lea( mem( rdi, rdi, 2 ), r12 ) - + TRANSPOSE_4X16_BZ( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 9, 13, 17, 21 ) @@ -780,6 +1001,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -818,7 +1040,6 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 const dim_t j_edge = n0 - ( dim_t )n_left; uint64_t ps_b = bli_auxinfo_ps_b( data ); - uint64_t ps_b4 = ps_b * sizeof( float ); float* restrict cij = c + j_edge*cs_c; float* restrict bj = b + n_iter * ps_b; @@ -936,20 +1157,20 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 else { const dim_t mr = 5; - + // Since A is packed into row panels, we must use a loop over // gemv. dim_t m_iter = ( m0 + mr - 1 ) / mr; dim_t m_left = m0 % mr; - + float* restrict ai_ii = ai; float* restrict cij_ii = cij; - + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) { dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? 
mr : m_left ); - + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, @@ -958,7 +1179,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -981,7 +1202,8 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t n_iter = n0 / 64; uint64_t n_left = n0 % 64; @@ -1015,7 +1237,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 mov( var( n_iter ), r11 ) // load n_iter label( .N_LOOP_ITER ) - + mov( var( rs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) @@ -1025,14 +1247,108 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 4 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + vbroadcastss( mem( rax, r13, 1 ), zmm4 ) + VFMA4( 4, 20, 21, 22, 23 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 4 rows from B matrix. 
vmovups( ( rbx ), zmm0 ) @@ -1049,18 +1365,21 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 VFMA4( 6, 16, 17, 18, 19 ) vbroadcastss( mem( rax, r13, 1 ), zmm4 ) VFMA4( 4, 20, 21, 22, 23 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. ALPHA_SCALE4( 7, 8, 9, 10, 11 ) ALPHA_SCALE4( 7, 12, 13, 14, 15 ) ALPHA_SCALE4( 7, 16, 17, 18, 19 ) ALPHA_SCALE4( 7, 20, 21, 22, 23 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -1138,7 +1457,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 mov( var( cs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) lea( mem( rdi, rdi, 2 ), r12 ) - + TRANSPOSE_4X16_BZ( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 9, 13, 17, 21 ) @@ -1175,6 +1494,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -1213,7 +1533,6 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 const dim_t j_edge = n0 - ( dim_t )n_left; uint64_t ps_b = bli_auxinfo_ps_b( data ); - uint64_t ps_b4 = ps_b * sizeof( float ); float* restrict cij = c + j_edge*cs_c; float* restrict bj = b + n_iter * ps_b; @@ -1331,20 +1650,20 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 else { const dim_t mr = 4; - + // Since A is packed into row panels, we must use a loop over // gemv. dim_t m_iter = ( m0 + mr - 1 ) / mr; dim_t m_left = m0 % mr; - + float* restrict ai_ii = ai; float* restrict cij_ii = cij; - + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) { dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? 
mr : m_left ); - + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, @@ -1353,7 +1672,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -1376,7 +1695,8 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t n_iter = n0 / 64; uint64_t n_left = n0 % 64; @@ -1410,7 +1730,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 mov( var( n_iter ), r11 ) // load n_iter label( .N_LOOP_ITER ) - + mov( var( rs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) @@ -1420,14 +1740,100 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. 
+ vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 3 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + vbroadcastss( mem( rax, r8, 2 ), zmm6 ) + VFMA4( 6, 16, 17, 18, 19 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 4 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -1442,17 +1848,20 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 VFMA4( 5, 12, 13, 14, 15 ) vbroadcastss( mem( rax, r8, 2 ), zmm6 ) VFMA4( 6, 16, 17, 18, 19 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha. 
ALPHA_SCALE4( 7, 8, 9, 10, 11 ) ALPHA_SCALE4( 7, 12, 13, 14, 15 ) ALPHA_SCALE4( 7, 16, 17, 18, 19 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -1474,7 +1883,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 label( .SCOLSTORED ) - + /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c mov( var( cs_c ), rdi ) // load rs_c @@ -1497,7 +1906,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 mov( var( cs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) lea( mem( rdi, rdi, 2 ), r12 ) - + UPDATE_C_1X16( 16 ) UPDATE_C_1X16( 17 ) UPDATE_C_1X16( 18 ) @@ -1545,7 +1954,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 mov( var( cs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) lea( mem( rdi, rdi, 2 ), r12 ) - + UPDATE_C_1X16_BZ( 16 ) UPDATE_C_1X16_BZ( 17 ) UPDATE_C_1X16_BZ( 18 ) @@ -1579,6 +1988,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -1617,7 +2027,6 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 const dim_t j_edge = n0 - ( dim_t )n_left; uint64_t ps_b = bli_auxinfo_ps_b( data ); - uint64_t ps_b4 = ps_b * sizeof( float ); float* restrict cij = c + j_edge*cs_c; float* restrict bj = b + n_iter * ps_b; @@ -1735,20 +2144,20 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 else { const dim_t mr = 3; - + // Since A is packed into row panels, we must use a loop over // gemv. dim_t m_iter = ( m0 + mr - 1 ) / mr; dim_t m_left = m0 % mr; - + float* restrict ai_ii = ai; float* restrict cij_ii = cij; - + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) { dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? 
mr : m_left ); - + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, @@ -1757,7 +2166,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -1780,7 +2189,8 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t n_iter = n0 / 64; uint64_t n_left = n0 % 64; @@ -1814,7 +2224,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 mov( var( n_iter ), r11 ) // load n_iter label( .N_LOOP_ITER ) - + mov( var( rs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) @@ -1824,14 +2234,92 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 4 rows from B matrix.
+ vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 2 elements from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + vbroadcastss( mem( rax, r8, 1 ), zmm5 ) + VFMA4( 5, 12, 13, 14, 15 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 4 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -1848,12 +2336,15 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha.
ALPHA_SCALE4( 7, 8, 9, 10, 11 ) ALPHA_SCALE4( 7, 12, 13, 14, 15 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -1874,7 +2365,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 label( .SCOLSTORED ) - + /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c mov( var( cs_c ), rdi ) // load rs_c @@ -1950,6 +2441,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -1988,7 +2480,6 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 const dim_t j_edge = n0 - ( dim_t )n_left; uint64_t ps_b = bli_auxinfo_ps_b( data ); - uint64_t ps_b4 = ps_b * sizeof( float ); float* restrict cij = c + j_edge*cs_c; float* restrict bj = b + n_iter * ps_b; @@ -2106,20 +2597,20 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 else { const dim_t mr = 2; - + // Since A is packed into row panels, we must use a loop over // gemv. dim_t m_iter = ( m0 + mr - 1 ) / mr; dim_t m_left = m0 % mr; - + float* restrict ai_ii = ai; float* restrict cij_ii = cij; - + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) { dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? 
mr : m_left ); - + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, @@ -2128,7 +2619,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -2151,7 +2642,8 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 cntx_t* restrict cntx ) { - uint64_t k_iter = k0; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; uint64_t n_iter = n0 / 64; uint64_t n_left = n0 % 64; @@ -2185,7 +2677,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 mov( var( n_iter ), r11 ) // load n_iter label( .N_LOOP_ITER ) - + mov( var( rs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) @@ -2195,14 +2687,84 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 mov( var( bbuf ), rbx ) // load address of b mov( var( cbuf ), rcx ) // load address of c - mov( var( k_iter ), rsi ) // load k_iter - test( rsi, rsi ) - mov( var( alpha ), rdx ) // load address of alpha vbroadcastss( ( rdx ), zmm7 ) + mov( var( k_iter ), rsi ) // load k_iter + test( rsi, rsi ) + je( .CONSID_K_LEFT ) + // The k-loop iterates over 4 rows of B, and broadcasts of each row of A. label( .K_LOOP_ITER ) + // ITER 0 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 1 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 2 + // Load 4 rows from B matrix.
+ vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + + add( r9, rbx ) + add( r10, rax ) + + // ITER 3 + // Load 4 rows from B matrix. + vmovups( ( rbx ), zmm0 ) + vmovups( 0x40( rbx ), zmm1 ) + vmovups( 0x80( rbx ), zmm2 ) + vmovups( 0xc0( rbx ), zmm3 ) + + // Broadcast 1 element from a row of A & do VFMA with rows of B. + vbroadcastss( ( rax ), zmm4 ) + VFMA4( 4, 8, 9, 10, 11 ) + + add( r9, rbx ) + add( r10, rax ) + + dec( rsi ) + jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + + + label( .CONSID_K_LEFT ) + + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label( .K_LEFT_LOOP ) // Load 4 rows from B matrix. vmovups( ( rbx ), zmm0 ) @@ -2213,15 +2775,18 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 // Broadcast 3 elements from a row of A & do VFMA with rows of B. vbroadcastss( ( rax ), zmm4 ) VFMA4( 4, 8, 9, 10, 11 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) - jne( .K_LOOP_ITER ) // if rsi != 0, repeat k-loop + jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop + + + label( .SPOSTACCUM ) // Scaling A * B with alpha.
ALPHA_SCALE4( 7, 8, 9, 10, 11 ) - + mov( var( beta ), rdx ) // load address of beta vbroadcastss( ( rdx ), zmm4 ) @@ -2241,7 +2806,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 label( .SCOLSTORED ) - + /* Transposing 1x16 tiles to 16x1 tiles */ mov( var( cbuf ), rcx ) // load address of c mov( var( cs_c ), rdi ) // load rs_c @@ -2276,7 +2841,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 mov( var( cs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) lea( mem( rdi, rdi, 2 ), r12 ) - + UPDATE_C_1X16_BZ( 8 ) UPDATE_C_1X16_BZ( 9 ) UPDATE_C_1X16_BZ( 10 ) @@ -2310,6 +2875,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 : // output operands (none) : // input operands [k_iter] "m" (k_iter), + [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), @@ -2348,7 +2914,6 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 const dim_t j_edge = n0 - ( dim_t )n_left; uint64_t ps_b = bli_auxinfo_ps_b( data ); - uint64_t ps_b4 = ps_b * sizeof( float ); float* restrict cij = c + j_edge*cs_c; float* restrict bj = b + n_iter * ps_b; @@ -2466,20 +3031,20 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 else { const dim_t mr = 2; - + // Since A is packed into row panels, we must use a loop over // gemv. dim_t m_iter = ( m0 + mr - 1 ) / mr; dim_t m_left = m0 % mr; - + float* restrict ai_ii = ai; float* restrict cij_ii = cij; - + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) { dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? mr : m_left ); - + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, @@ -2488,7 +3053,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; }