Improving sgemm rd kernel on zen4/zen5 (#292)

Fixes some inefficiencies in the zen4 SUP RD kernel for SGEMM.
The 8-iteration and 1-iteration remainder loops of the K-loop performed loads into ymm/xmm registers while computing on zmm registers.
This caused multiple unnecessary iterations in the kernel for matrices with certain k values.
Fixed by introducing masked loads and computations for these cases.
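
For reference, a minimal sketch of the new remainder decomposition (variable names mirror the kernel setup in the diff below; k0 = 31 is just the illustrative worst case):

uint64_t k0 = 31;                          // example k value
uint64_t k_iter64 = k0 / 64;               // 0 iterations of 64 elements
uint64_t k_left64 = k0 % 64;               // 31
uint64_t k_iter32 = k_left64 / 32;         // 0 iterations of 32 elements
uint64_t k_left32 = k_left64 % 32;         // 31
uint64_t k_iter16 = k_left32 / 16;         // 1 iteration of 16 elements
uint64_t k_left1  = k_left32 % 16;         // 15 elements left for one masked op
int32_t iter_1_mask = (1 << k_left1) - 1;  // 0x7fff: low 15 lanes enabled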

AMD-Internal: https://amd.atlassian.net/browse/CPUPL-7762
Co-authored-by: Rohan Rayan <rohrayan@amd.com>
Author: Rayan, Rohan (committed via GitHub)
Date: 2025-12-17 18:48:50 +05:30
Parent: 504ac9d8a2
Commit: 9cbb1c45d8
3 changed files with 1567 additions and 1147 deletions



@@ -168,8 +168,9 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
uint64_t k_left64 = k0 % 64;
uint64_t k_iter32 = k_left64 / 32;
uint64_t k_left32 = k_left64 % 32;
uint64_t k_iter8 = k_left32 / 8;
uint64_t k_left1 = k_left32 % 8;
uint64_t k_iter16 = k_left32 / 16;
uint64_t k_left1 = k_left32 % 16;
int32_t iter_1_mask = (1 << k_left1) - 1;
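// e.g. k_left1 = 15 gives iter_1_mask = 0x7fff (low 15 lanes enabled)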
uint64_t m_iter = m0 / 6;
uint64_t m_left = m0 % 6;
@@ -181,6 +182,8 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
float *abuf = a;
float *bbuf = b;
float *cbuf = c;
@@ -199,6 +202,8 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
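// k1 now holds the lane mask consumed by MASK_KZ(1) in the remainder code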
mov( imm( 0 ), r15 ) // jj = 0;
label( .SLOOP3X4J ) // LOOP OVER jj = [ 0 1 ... ]
@@ -340,10 +345,10 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_8 )
je( .CONSIDER_K_ITER_16 )
label( .K_LOOP_ITER32 )
// ITER 0
// load row from A
@@ -394,93 +399,108 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
add( imm( 16*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER32 )
label( .CONSIDER_K_ITER_8 )
mov( var( k_iter8 ), rsi )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
je( .CONSIDER_K_LEFT_1 )
label( .K_LOOP_ITER8 )
// The previous k-loop decomposition used iterations of 64, 32, and 8 elements, which is inefficient for k values below 32.
// For example, when k=31, that scheme requires three 8-element loops (processing 24 elements) followed
// by seven scalar 1-element loops (processing the remaining 7 elements), totaling 10 loop iterations.
// By redesigning the decomposition to use 64, 32, 16, and masked operations instead, the same k=31 case would
// require only one 16-element loop followed by a single masked operation to process the remaining 15 elements.
// This reduces the total iterations from 10 down to 2, significantly improving efficiency for k values in the range of 16-31.
// ITER 0
// Load row from A using ymm registers
// Upper 256-bit lanes are cleared for the
// zmm counterpart
vmovups( ( rax ), ymm0 )
vmovups( ( rax, r8, 1 ), ymm1 )
vmovups( ( rax, r8, 2 ), ymm2 )
vmovups( ( rax, r10, 1 ), ymm3 )
vmovups( ( rax, r8, 4 ), ymm4 )
vmovups( ( rax, rdi, 1 ), ymm5 )
add( imm( 8*4 ), rax )
vmovups( ( rax ), zmm0 )
vmovups( ( rax, r8, 1 ), zmm1 )
vmovups( ( rax, r8, 2 ), zmm2 )
vmovups( ( rax, r10, 1 ), zmm3 )
vmovups( ( rax, r8, 4 ), zmm4 )
vmovups( ( rax, rdi, 1 ), zmm5 )
add( imm( 16*4 ), rax )
// Load column from B using ymm registers
// Upper 256-bit lane is cleared for the
// zmm counterpart
// Thus, we can re-use the VFMA6 macro
vmovups( ( rbx ), ymm6 )
vmovups( ( rbx ), zmm6 )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovups( ( rbx, r9, 1 ), ymm6 )
vmovups( ( rbx, r9, 1 ), zmm6 )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovups( ( rbx, r9, 2 ), ymm6 )
vmovups( ( rbx, r9, 2 ), zmm6 )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovups( ( rbx, r13, 1 ), ymm6 )
vmovups( ( rbx, r13, 1 ), zmm6 )
VFMA6( 17, 18, 19, 29, 30, 31 )
add( imm( 8*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER8 )
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_LEFT_1 )
mov( var( k_left1 ), rsi )
test( rsi, rsi )
je( .POST_ACCUM )
// When the remaining floats fit in a ymm register,
// it is better to operate on masked ymm registers,
// because on Zen4 the throughput of masked loads
// on zmm is 0.5 while on ymm/xmm it is 1
cmp( imm(8), rsi )
jle( .K_FLOATS_LEFT_LE_8 )
label( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_GT_8 )
// Instead of looping element by element and
// performing a VFMA per element, which is wasteful,
// perform a single masked FMA on the remaining elements
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), ZMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), ZMM(2 MASK_KZ(1) ) )
vmovups( mem(rax, r10, 1), ZMM(3 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 4), ZMM(4 MASK_KZ(1) ) )
vmovups( mem(rax, rdi, 1), ZMM(5 MASK_KZ(1) ) )
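// zero-masked lanes load 0 and therefore contribute nothing to the VFMA6 accumulation below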
// Load row from A using xmm registers
// Upper 256-bit lanes and the upper 224
// bits of the lower 256-bit lane are cleared
// for the zmm counterpart
vmovss( ( rax ), xmm0 )
vmovss( ( rax, r8, 1 ), xmm1 )
vmovss( ( rax, r8, 2 ), xmm2 )
vmovss( ( rax, r10, 1 ), xmm3 )
vmovss( ( rax, r8, 4 ), xmm4 )
vmovss( ( rax, rdi, 1 ), xmm5 )
add( imm( 1*4 ), rax )
// Load column from B using xmm registers
// Upper 256-bit lanes and the upper 224
// bits of the lower 256-bit lane are cleared
// for the zmm counterpart
// Thus, we can re-use the VFMA6 macro
vmovss( ( rbx ), xmm6 )
vmovups( mem(rbx), ZMM(6 MASK_KZ(1) ) )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovss( ( rbx, r9, 1 ), xmm6 )
vmovups( mem(rbx, r9, 1), ZMM(6 MASK_KZ(1) ) )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovss( ( rbx, r9, 2 ), xmm6 )
vmovups( mem(rbx, r9, 2), ZMM(6 MASK_KZ(1) ) )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovss( ( rbx, r13, 1 ), xmm6 )
vmovups( mem(rbx, r13, 1), ZMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
add( imm( 1*4 ), rbx )
// unconditionally branch to .POST_ACCUM after
// handling the >8-floats case
jmp( .POST_ACCUM )
dec( rsi )
jne( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_LE_8 )
// When 8 or fewer elements remain, it is better to operate
// on masked YMM registers on Zen4, because vmovups on
// masked YMM registers has a throughput of 1 while
// the same operation on ZMM has a throughput of 0.5.
// Instead of looping element by element and
// performing a VFMA per element, which is wasteful,
// perform a single masked FMA on the remaining elements
vmovups( mem(rax), YMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), YMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), YMM(2 MASK_KZ(1) ) )
vmovups( mem(rax, r10, 1), YMM(3 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 4), YMM(4 MASK_KZ(1) ) )
vmovups( mem(rax, rdi, 1), YMM(5 MASK_KZ(1) ) )
vmovups( mem(rbx), YMM(6 MASK_KZ(1) ) )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovups( mem(rbx, r9, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovups( mem(rbx, r9, 2), YMM(6 MASK_KZ(1) ) )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
label( .POST_ACCUM )
@@ -580,11 +600,12 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
end_asm(
: // output operands (none)
: // input operands
[iter_1_mask] "m" (iter_1_mask),
[k_iter64] "m" (k_iter64),
[k_left64] "m" (k_left64),
[k_iter32] "m" (k_iter32),
[k_left32] "m" (k_left32),
[k_iter8] "m" (k_iter8),
[k_iter16] "m" (k_iter16),
[k_left1] "m" (k_left1),
[a] "m" (a),
[rs_a] "m" (rs_a),
@@ -710,8 +731,9 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
uint64_t k_left64 = k0 % 64;
uint64_t k_iter32 = k_left64 / 32;
uint64_t k_left32 = k_left64 % 32;
uint64_t k_iter8 = k_left32 / 8;
uint64_t k_left1 = k_left32 % 8;
uint64_t k_iter16 = k_left32 / 16;
uint64_t k_left1 = k_left32 % 16;
int32_t iter_1_mask = (1 << k_left1) - 1;
uint64_t m_iter = m0 / 6;
uint64_t m_left = m0 % 6;
@@ -741,6 +763,9 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( imm( 0 ), r15 ) // jj = 0;
label( .SLOOP3X4J ) // LOOP OVER jj = [ 0 1 ... ]
@@ -881,10 +906,10 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_8 )
je( .CONSIDER_K_ITER_16 )
label( .K_LOOP_ITER32 )
// ITER 0
// load row from A
@@ -935,92 +960,108 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
add( imm( 16*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER32 )
label( .CONSIDER_K_ITER_8 )
mov( var( k_iter8 ), rsi )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
je( .CONSIDER_K_LEFT_1 )
label( .K_LOOP_ITER8 )
// The previous k-loop decomposition used iterations of 64, 32, and 8 elements, which is inefficient for k values below 32.
// For example, when k=31, that scheme requires three 8-element loops (processing 24 elements) followed
// by seven scalar 1-element loops (processing the remaining 7 elements), totaling 10 loop iterations.
// By redesigning the decomposition to use 64, 32, 16, and masked operations instead, the same k=31 case would
// require only one 16-element loop followed by a single masked operation to process the remaining 15 elements.
// This reduces the total iterations from 10 down to 2, significantly improving efficiency for k values in the range of 16-31.
// ITER 0
// Load row from A using ymm registers
// Upper 256-bit lanes are cleared for the
// zmm counterpart
vmovups( ( rax ), ymm0 )
vmovups( ( rax, r8, 1 ), ymm1 )
vmovups( ( rax, r8, 2 ), ymm2 )
vmovups( ( rax, r10, 1 ), ymm3 )
vmovups( ( rax, r8, 4 ), ymm4 )
vmovups( ( rax, rdi, 1 ), ymm5 )
add( imm( 8*4 ), rax )
vmovups( ( rax ), zmm0 )
vmovups( ( rax, r8, 1 ), zmm1 )
vmovups( ( rax, r8, 2 ), zmm2 )
vmovups( ( rax, r10, 1 ), zmm3 )
vmovups( ( rax, r8, 4 ), zmm4 )
vmovups( ( rax, rdi, 1 ), zmm5 )
add( imm( 16*4 ), rax )
// Load column from B using ymm registers
// Upper 256-bit lane is cleared for the
// zmm counterpart
// Thus, we can re-use the VFMA6 macro
vmovups( ( rbx ), ymm6 )
vmovups( ( rbx ), zmm6 )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovups( ( rbx, r9, 1 ), ymm6 )
vmovups( ( rbx, r9, 1 ), zmm6 )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovups( ( rbx, r9, 2 ), ymm6 )
vmovups( ( rbx, r9, 2 ), zmm6 )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovups( ( rbx, r13, 1 ), ymm6 )
vmovups( ( rbx, r13, 1 ), zmm6 )
VFMA6( 17, 18, 19, 29, 30, 31 )
add( imm( 8*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER8 )
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_LEFT_1 )
mov( var( k_left1 ), rsi )
test( rsi, rsi )
je( .POST_ACCUM )
// When the remaining floats fit in a ymm register,
// it is better to operate on masked ymm registers,
// because on Zen4 the throughput of masked loads
// on zmm is 0.5 while on ymm/xmm it is 1
cmp( imm(8), rsi )
jle( .K_FLOATS_LEFT_LE_8 )
label( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_GT_8 )
// Instead of looping element by element and
// performing a VFMA per element, which is wasteful,
// perform a single masked FMA on the remaining elements
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), ZMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), ZMM(2 MASK_KZ(1) ) )
vmovups( mem(rax, r10, 1), ZMM(3 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 4), ZMM(4 MASK_KZ(1) ) )
vmovups( mem(rax, rdi, 1), ZMM(5 MASK_KZ(1) ) )
// Load row from A using xmm registers
// Upper 256-bit lanes and the upper 224
// bits of the lower 256-bit lane are cleared
// for the zmm counterpart
vmovss( ( rax ), xmm0 )
vmovss( ( rax, r8, 1 ), xmm1 )
vmovss( ( rax, r8, 2 ), xmm2 )
vmovss( ( rax, r10, 1 ), xmm3 )
vmovss( ( rax, r8, 4 ), xmm4 )
vmovss( ( rax, rdi, 1 ), xmm5 )
add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4;
// Load column from B using xmm registers
// Upper 256-bit lanes and the upper 224
// bits of the lower 256-bit lane are cleared
// for the zmm counterpart
// Thus, we can re-use the VFMA6 macro
vmovss( ( rbx ), xmm6 )
vmovups( mem(rbx), ZMM(6 MASK_KZ(1) ) )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovss( ( rbx, r9, 1 ), xmm6 )
vmovups( mem(rbx, r9, 1), ZMM(6 MASK_KZ(1) ) )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovss( ( rbx, r9, 2 ), xmm6 )
vmovups( mem(rbx, r9, 2), ZMM(6 MASK_KZ(1) ) )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovss( ( rbx, r13, 1 ), xmm6 )
vmovups( mem(rbx, r13, 1), ZMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
add( imm( 1*4 ), rbx ) // b += 1*rs_b = 1*4;
// unconditionally branch to .POST_ACCUM after
// handling the >8-floats case
jmp( .POST_ACCUM )
dec( rsi )
jne( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_LE_8 )
// When 8 or fewer elements remain, it is better to operate
// on masked YMM registers on Zen4, because vmovups on
// masked YMM registers has a throughput of 1 while
// the same operation on ZMM has a throughput of 0.5.
// Instead of looping element by element and
// performing a VFMA per element, which is wasteful,
// perform a single masked FMA on the remaining elements
vmovups( mem(rax), YMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), YMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), YMM(2 MASK_KZ(1) ) )
vmovups( mem(rax, r10, 1), YMM(3 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 4), YMM(4 MASK_KZ(1) ) )
vmovups( mem(rax, rdi, 1), YMM(5 MASK_KZ(1) ) )
vmovups( mem(rbx), YMM(6 MASK_KZ(1) ) )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovups( mem(rbx, r9, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovups( mem(rbx, r9, 2), YMM(6 MASK_KZ(1) ) )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
label( .POST_ACCUM )
@@ -1121,11 +1162,12 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
end_asm(
: // output operands (none)
: // input operands
[iter_1_mask] "m" (iter_1_mask),
[k_iter64] "m" (k_iter64),
[k_left64] "m" (k_left64),
[k_iter32] "m" (k_iter32),
[k_left32] "m" (k_left32),
[k_iter8] "m" (k_iter8),
[k_iter16] "m" (k_iter16),
[k_left1] "m" (k_left1),
[a] "m" (a),
[rs_a] "m" (rs_a),
@@ -1251,8 +1293,9 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
uint64_t k_left64 = k0 % 64;
uint64_t k_iter32 = k_left64 / 32;
uint64_t k_left32 = k_left64 % 32;
uint64_t k_iter8 = k_left32 / 8;
uint64_t k_left1 = k_left32 % 8;
uint64_t k_iter16 = k_left32 / 16;
uint64_t k_left1 = k_left32 % 16;
int32_t iter_1_mask = (1 << k_left1) - 1;
uint64_t m_iter = m0 / 6;
uint64_t m_left = m0 % 6;
@@ -1282,6 +1325,9 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( imm(0), r15 ) // jj = 0;
label( .SLOOP3X4J ) // LOOP OVER jj = [ 0 1 ... ]
@@ -1423,9 +1469,9 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_8 )
je( .CONSIDER_K_ITER_16 )
label( .K_LOOP_ITER32 )
// ITER 0
// load row from A
@@ -1476,96 +1522,112 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
add( imm( 16*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER32 )
label( .CONSIDER_K_ITER_8 )
mov( var( k_iter8 ), rsi )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
je( .CONSIDER_K_LEFT_1 )
label( .K_LOOP_ITER8 )
// The previous k-loop decomposition used iterations of 64, 32, and 8 elements, which is inefficient for k values below 32.
// For example, when k=31, that scheme requires three 8-element loops (processing 24 elements) followed
// by seven scalar 1-element loops (processing the remaining 7 elements), totaling 10 loop iterations.
// By redesigning the decomposition to use 64, 32, 16, and masked operations instead, the same k=31 case would
// require only one 16-element loop followed by a single masked operation to process the remaining 15 elements.
// This reduces the total iterations from 10 down to 2, significantly improving efficiency for k values in the range of 16-31.
// ITER 0
// Load row from A using ymm registers
// Upper 256-bit lanes are cleared for the
// zmm counterpart
vmovups( ( rax ), ymm0 )
vmovups( ( rax, r8, 1 ), ymm1 )
vmovups( ( rax, r8, 2 ), ymm2 )
vmovups( ( rax, r10, 1 ), ymm3 )
vmovups( ( rax, r8, 4 ), ymm4 )
vmovups( ( rax, rdi, 1 ), ymm5 )
add( imm( 8*4 ), rax )
vmovups( ( rax ), zmm0 )
vmovups( ( rax, r8, 1 ), zmm1 )
vmovups( ( rax, r8, 2 ), zmm2 )
vmovups( ( rax, r10, 1 ), zmm3 )
vmovups( ( rax, r8, 4 ), zmm4 )
vmovups( ( rax, rdi, 1 ), zmm5 )
add( imm( 16*4 ), rax )
// Load column from B using ymm registers
// Upper 256-bit lane is cleared for the
// zmm counterpart
// Thus, we can re-use the VFMA6 macro
vmovups( ( rbx ), ymm6 )
vmovups( ( rbx ), zmm6 )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovups( ( rbx, r9, 1 ), ymm6 )
vmovups( ( rbx, r9, 1 ), zmm6 )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovups( ( rbx, r9, 2 ), ymm6 )
vmovups( ( rbx, r9, 2 ), zmm6 )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovups( ( rbx, r13, 1 ), ymm6 )
vmovups( ( rbx, r13, 1 ), zmm6 )
VFMA6( 17, 18, 19, 29, 30, 31 )
add( imm( 8*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER8 )
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_LEFT_1 )
mov( var( k_left1 ), rsi )
test( rsi, rsi )
je( .POST_ACCUM )
// When the remaining floats fit in a ymm register,
// it is better to operate on masked ymm registers,
// because on Zen4 the throughput of masked loads
// on zmm is 0.5 while on ymm/xmm it is 1
cmp( imm(8), rsi )
jle( .K_FLOATS_LEFT_LE_8 )
label( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_GT_8 )
// Instead of looping element by element and
// performing a VFMA per element, which is wasteful,
// perform a single masked FMA on the remaining elements
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), ZMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), ZMM(2 MASK_KZ(1) ) )
vmovups( mem(rax, r10, 1), ZMM(3 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 4), ZMM(4 MASK_KZ(1) ) )
vmovups( mem(rax, rdi, 1), ZMM(5 MASK_KZ(1) ) )
// Load row from A using xmm registers
// Upper 256-bit lanes and the upper 224
// bits of the lower 256-bit lane are cleared
// for the zmm counterpart
vmovss( ( rax ), xmm0 )
vmovss( ( rax, r8, 1 ), xmm1 )
vmovss( ( rax, r8, 2 ), xmm2 )
vmovss( ( rax, r10, 1 ), xmm3 )
vmovss( ( rax, r8, 4 ), xmm4 )
vmovss( ( rax, rdi, 1 ), xmm5 )
add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4;
// Load column from B using xmm registers
// Upper 256-bit lanes and the upper 224
// bits of the lower 256-bit lane are cleared
// for the zmm counterpart
// Thus, we can re-use the VFMA6 macro
vmovss( ( rbx ), xmm6 )
vmovups( mem(rbx), ZMM(6 MASK_KZ(1) ) )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovss( ( rbx, r9, 1 ), xmm6 )
vmovups( mem(rbx, r9, 1), ZMM(6 MASK_KZ(1) ) )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovss( ( rbx, r9, 2 ), xmm6 )
vmovups( mem(rbx, r9, 2), ZMM(6 MASK_KZ(1) ) )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovss( ( rbx, r13, 1 ), xmm6 )
vmovups( mem(rbx, r13, 1), ZMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
add( imm( 1*4 ), rbx ) // b += 1*rs_b = 1*4;
// unconditionally branch to .POST_ACCUM after
// handling the >8-floats case
jmp( .POST_ACCUM )
dec( rsi )
jne( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_LE_8 )
// When 8 or fewer elements remain, it is better to operate
// on masked YMM registers on Zen4, because vmovups on
// masked YMM registers has a throughput of 1 while
// the same operation on ZMM has a throughput of 0.5.
// Instead of looping element by element and
// performing a VFMA per element, which is wasteful,
// perform a single masked FMA on the remaining elements
vmovups( mem(rax), YMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), YMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), YMM(2 MASK_KZ(1) ) )
vmovups( mem(rax, r10, 1), YMM(3 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 4), YMM(4 MASK_KZ(1) ) )
vmovups( mem(rax, rdi, 1), YMM(5 MASK_KZ(1) ) )
vmovups( mem(rbx), YMM(6 MASK_KZ(1) ) )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovups( mem(rbx, r9, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovups( mem(rbx, r9, 2), YMM(6 MASK_KZ(1) ) )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
label( .POST_ACCUM )
mov( var( beta ), rax ) // load address of beta
vbroadcastss( ( rax ), xmm0 )
vxorps( xmm1, xmm1, xmm1 )
@@ -1662,11 +1724,12 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
end_asm(
: // output operands (none)
: // input operands
[iter_1_mask] "m" (iter_1_mask),
[k_iter64] "m" (k_iter64),
[k_left64] "m" (k_left64),
[k_iter32] "m" (k_iter32),
[k_left32] "m" (k_left32),
[k_iter8] "m" (k_iter8),
[k_iter16] "m" (k_iter16),
[k_left1] "m" (k_left1),
[a] "m" (a),
[rs_a] "m" (rs_a),
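
The same remainder pattern is applied throughout the next file. A rough C-intrinsics sketch of the two masked paths follows (not the BLIS assembly itself; the pointer names a_row/b_col and the accumulator arguments are hypothetical stand-ins):

#include <immintrin.h>
#include <stdint.h>

// Masked-off lanes are zeroed by the maskz loads, so they add 0 to the FMA.
static void remainder_fma( const float *a_row, const float *b_col,
                           __m512 *acc512, __m256 *acc256, uint64_t k_left1 )
{
    if ( k_left1 > 8 )
    {
        // >8 floats left: one masked zmm load + FMA replaces up to 15 scalar steps
        __mmask16 m = (__mmask16)( ( 1u << k_left1 ) - 1 );
        __m512 a = _mm512_maskz_loadu_ps( m, a_row );
        __m512 b = _mm512_maskz_loadu_ps( m, b_col );
        *acc512 = _mm512_fmadd_ps( a, b, *acc512 );
    }
    else
    {
        // <=8 floats left: masked ymm ops (load throughput 1 on Zen4 vs 0.5 on zmm)
        __mmask8 m = (__mmask8)( ( 1u << k_left1 ) - 1 );
        __m256 a = _mm256_maskz_loadu_ps( m, a_row );
        __m256 b = _mm256_maskz_loadu_ps( m, b_col );
        *acc256 = _mm256_fmadd_ps( a, b, *acc256 );
    }
}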


@@ -113,8 +113,9 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
uint64_t k_left64 = k0 % 64;
uint64_t k_iter32 = k_left64 / 32;
uint64_t k_left32 = k_left64 % 32;
uint64_t k_iter8 = k_left32 / 8;
uint64_t k_left1 = k_left32 % 8;
uint64_t k_iter16 = k_left32 / 16;
uint64_t k_left1 = k_left32 % 16;
int32_t iter_1_mask = (1 << k_left1) - 1;
uint64_t n_iter = n0 / 4;
uint64_t n_left = n0 % 4;
@@ -144,6 +145,9 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( imm( 0 ), r11 ) // ii = 0;
label( .SLOOP3X4I ) // LOOP OVER ii = [ 0 1 ... ]
@@ -288,10 +292,10 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_8 )
je( .CONSIDER_K_ITER_16 )
label( .K_LOOP_ITER32 )
// ITER 0
// load row from A
@@ -342,78 +346,109 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
add( imm( 16*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER32 )
label( .CONSIDER_K_ITER_8 )
mov( var(k_iter8), rsi )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
je( .CONSIDER_K_LEFT_1 )
label( .K_LOOP_ITER8 )
// The previous k-loop decomposition used iterations of 64, 32, and 8 elements, which is inefficient for k values below 32.
// For example, when k=31, that scheme requires three 8-element loops (processing 24 elements) followed
// by seven scalar 1-element loops (processing the remaining 7 elements), totaling 10 loop iterations.
// By redesigning the decomposition to use 64, 32, 16, and masked operations instead, the same k=31 case would
// require only one 16-element loop followed by a single masked operation to process the remaining 15 elements.
// This reduces the total iterations from 10 down to 2, significantly improving efficiency for k values in the range of 16-31.
// ITER 0
// load row from A
vmovups( ( rax ), ymm0 )
vmovups( ( rax, r8, 1 ), ymm1 )
vmovups( ( rax, r8, 2 ), ymm2 )
vmovups( ( rax, r10, 1 ), ymm3 )
vmovups( ( rax, r8, 4 ), ymm4 )
vmovups( ( rax, rdi, 1 ), ymm5 )
add( imm( 8*4 ), rax )
vmovups( ( rax ), zmm0 )
vmovups( ( rax, r8, 1 ), zmm1 )
vmovups( ( rax, r8, 2 ), zmm2 )
vmovups( ( rax, r10, 1 ), zmm3 )
vmovups( ( rax, r8, 4 ), zmm4 )
vmovups( ( rax, rdi, 1 ), zmm5 )
add( imm( 16*4 ), rax )
// load column from B
vmovups( ( rbx ), ymm6 )
vmovups( ( rbx ), zmm6 )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovups( ( rbx, r9, 1 ), ymm6 )
vmovups( ( rbx, r9, 1 ), zmm6 )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovups( ( rbx, r9, 2 ), ymm6 )
vmovups( ( rbx, r9, 2 ), zmm6 )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovups( ( rbx, r13, 1 ), ymm6 )
vmovups( ( rbx, r13, 1 ), zmm6 )
VFMA6( 17, 18, 19, 29, 30, 31 )
add( imm( 8*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER8 )
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_LEFT_1 )
mov( var(k_left1), rsi )
test( rsi, rsi )
je( .POST_ACCUM )
// When the remaining floats fit in a ymm register,
// it is better to operate on masked ymm registers,
// because on Zen4 the throughput of masked loads
// on zmm is 0.5 while on ymm/xmm it is 1
cmp( imm(8), rsi )
jle( .K_FLOATS_LEFT_LE_8 )
label( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_GT_8 )
vmovss( ( rax ), xmm0 )
vmovss( ( rax, r8, 1 ), xmm1 )
vmovss( ( rax, r8, 2 ), xmm2 )
vmovss( ( rax, r10, 1 ), xmm3 )
vmovss( ( rax, r8, 4 ), xmm4 )
vmovss( ( rax, rdi, 1 ), xmm5 )
add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4;
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), ZMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), ZMM(2 MASK_KZ(1) ) )
vmovups( mem(rax, r10, 1), ZMM(3 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 4), ZMM(4 MASK_KZ(1) ) )
vmovups( mem(rax, rdi, 1), ZMM(5 MASK_KZ(1) ) )
vmovss( ( rbx ), xmm6 )
vmovups( mem(rbx), ZMM(6 MASK_KZ(1) ) )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovss( ( rbx, r9, 1 ), xmm6 )
vmovups( mem(rbx, r9, 1), ZMM(6 MASK_KZ(1) ) )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovss( ( rbx, r9, 2 ), xmm6 )
vmovups( mem(rbx, r9, 2), ZMM(6 MASK_KZ(1) ) )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovss( ( rbx, r13, 1 ), xmm6 )
vmovups( mem(rbx, r13, 1), ZMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
// unconditionally branch to .POST_ACCUM after
// handling the >8-floats case
jmp( .POST_ACCUM )
add( imm( 1*4 ), rbx ) // b += 1*rs_b = 1*4;
label( .K_FLOATS_LEFT_LE_8 )
// When 8 or fewer elements remain, it is better to operate
// on masked YMM registers on Zen4, because vmovups on
// masked YMM registers has a throughput of 1 while
// the same operation on ZMM has a throughput of 0.5.
// Instead of looping element by element and
// performing a VFMA per element, which is wasteful,
// perform a single masked FMA on the remaining elements
vmovups( mem(rax), YMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), YMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), YMM(2 MASK_KZ(1) ) )
vmovups( mem(rax, r10, 1), YMM(3 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 4), YMM(4 MASK_KZ(1) ) )
vmovups( mem(rax, rdi, 1), YMM(5 MASK_KZ(1) ) )
dec( rsi )
jne( .K_LOOP_LEFT1 )
vmovups( mem(rbx), YMM(6 MASK_KZ(1) ) )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovups( mem(rbx, r9, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 11, 12, 13, 23, 24, 25 )
vmovups( mem(rbx, r9, 2), YMM(6 MASK_KZ(1) ) )
VFMA6( 14, 15, 16, 26, 27, 28 )
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
label( .POST_ACCUM )
@@ -506,11 +541,12 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
end_asm(
: // output operands (none)
: // input operands
[iter_1_mask] "m" (iter_1_mask),
[k_iter64] "m" (k_iter64),
[k_left64] "m" (k_left64),
[k_iter32] "m" (k_iter32),
[k_left32] "m" (k_left32),
[k_iter8] "m" (k_iter8),
[k_iter16] "m" (k_iter16),
[k_left1] "m" (k_left1),
[a] "m" (a),
[rs_a] "m" (rs_a),
@@ -605,8 +641,9 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
uint64_t k_left64 = k0 % 64;
uint64_t k_iter32 = k_left64 / 32;
uint64_t k_left32 = k_left64 % 32;
uint64_t k_iter8 = k_left32 / 8;
uint64_t k_left1 = k_left32 % 8;
uint64_t k_iter16 = k_left32 / 16;
uint64_t k_left1 = k_left32 % 16;
int32_t iter_1_mask = (1 << k_left1) - 1;
uint64_t n_iter = n0 / 4;
uint64_t n_left = n0 % 4;
@@ -634,6 +671,9 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( var( abuf ), rdx ) // load address of a
mov( var( bbuf ), r14 ) // load address of b
mov( var( cbuf ), r12 ) // load address of c
@@ -752,10 +792,10 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_8 )
je( .CONSIDER_K_ITER_16 )
label( .K_LOOP_ITER32 )
// ITER 0
// load row from A
@@ -800,72 +840,100 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
add( imm( 16*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER32 )
label( .CONSIDER_K_ITER_8 )
mov( var(k_iter8), rsi )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
je( .CONSIDER_K_LEFT_1 )
label( .K_LOOP_ITER8 )
// The previous k-loop decomposition used iterations of 64, 32, and 8 elements, which is inefficient for k values below 32.
// For example, when k=31, that scheme requires three 8-element loops (processing 24 elements) followed
// by seven scalar 1-element loops (processing the remaining 7 elements), totaling 10 loop iterations.
// By redesigning the decomposition to use 64, 32, 16, and masked operations instead, the same k=31 case would
// require only one 16-element loop followed by a single masked operation to process the remaining 15 elements.
// This reduces the total iterations from 10 down to 2, significantly improving efficiency for k values in the range of 16-31.
// ITER 0
// load row from A
vmovups( ( rax ), ymm0 )
vmovups( ( rax, r8, 1 ), ymm1 )
vmovups( ( rax, r8, 2 ), ymm2 )
add( imm( 8*4 ), rax )
vmovups( ( rax ), zmm0 )
vmovups( ( rax, r8, 1 ), zmm1 )
vmovups( ( rax, r8, 2 ), zmm2 )
add( imm( 16*4 ), rax )
// load column from B
vmovups( ( rbx ), ymm6 )
vmovups( ( rbx ), zmm6 )
VFMA3( 8, 9, 10 )
vmovups( ( rbx, r9, 1 ), ymm6 )
vmovups( ( rbx, r9, 1 ), zmm6 )
VFMA3( 11, 12, 13 )
vmovups( ( rbx, r9, 2 ), ymm6 )
vmovups( ( rbx, r9, 2 ), zmm6 )
VFMA3( 14, 15, 16 )
vmovups( ( rbx, r13, 1 ), ymm6 )
vmovups( ( rbx, r13, 1 ), zmm6 )
VFMA3( 17, 18, 19 )
add( imm( 8*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER8 )
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_LEFT_1 )
mov( var(k_left1), rsi )
test( rsi, rsi )
je( .POST_ACCUM )
// When the remaining floats fit in a ymm register,
// it is better to operate on masked ymm registers,
// because on Zen4 the throughput of masked loads
// on zmm is 0.5 while on ymm/xmm it is 1
cmp( imm(8), rsi )
jle( .K_FLOATS_LEFT_LE_8 )
label( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_GT_8 )
vmovss( ( rax ), xmm0 )
vmovss( ( rax, r8, 1 ), xmm1 )
vmovss( ( rax, r8, 2 ), xmm2 )
add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4;
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), ZMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), ZMM(2 MASK_KZ(1) ) )
vmovss( ( rbx ), xmm6 )
vmovups( mem(rbx), ZMM(6 MASK_KZ(1) ) )
VFMA3( 8, 9, 10 )
vmovss( ( rbx, r9, 1 ), xmm6 )
vmovups( mem(rbx, r9, 1), ZMM(6 MASK_KZ(1) ) )
VFMA3( 11, 12, 13 )
vmovss( ( rbx, r9, 2 ), xmm6 )
vmovups( mem(rbx, r9, 2), ZMM(6 MASK_KZ(1) ) )
VFMA3( 14, 15, 16 )
vmovss( ( rbx, r13, 1 ), xmm6 )
vmovups( mem(rbx, r13, 1), ZMM(6 MASK_KZ(1) ) )
VFMA3( 17, 18, 19 )
add( imm( 1*4 ), rbx ) // b += 1*rs_b = 1*4;
// unconditionally branch to .POST_ACCUM after
// handling the >8-floats case
jmp( .POST_ACCUM )
dec( rsi )
jne( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_LE_8 )
// When 8 or fewer elements remain, it is better to operate
// on masked YMM registers on Zen4, because vmovups on
// masked YMM registers has a throughput of 1 while
// the same operation on ZMM has a throughput of 0.5.
// Instead of looping element by element and
// performing a VFMA per element, which is wasteful,
// perform a single masked FMA on the remaining elements
vmovups( mem(rax), YMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), YMM(1 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 2), YMM(2 MASK_KZ(1) ) )
vmovups( mem(rbx), YMM(6 MASK_KZ(1) ) )
VFMA3( 8, 9, 10 )
vmovups( mem(rbx, r9, 1), YMM(6 MASK_KZ(1) ) )
VFMA3( 11, 12, 13 )
vmovups( mem(rbx, r9, 2), YMM(6 MASK_KZ(1) ) )
VFMA3( 14, 15, 16 )
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA3( 17, 18, 19 )
label( .POST_ACCUM )
@@ -922,11 +990,12 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
end_asm(
: // output operands (none)
: // input operands
[iter_1_mask] "m" (iter_1_mask),
[k_iter64] "m" (k_iter64),
[k_left64] "m" (k_left64),
[k_iter32] "m" (k_iter32),
[k_left32] "m" (k_left32),
[k_iter8] "m" (k_iter8),
[k_iter16] "m" (k_iter16),
[k_left1] "m" (k_left1),
[a] "m" (a),
[rs_a] "m" (rs_a),
@@ -1019,8 +1088,9 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
uint64_t k_left64 = k0 % 64;
uint64_t k_iter32 = k_left64 / 32;
uint64_t k_left32 = k_left64 % 32;
uint64_t k_iter8 = k_left32 / 8;
uint64_t k_left1 = k_left32 % 8;
uint64_t k_iter16 = k_left32 / 16;
uint64_t k_left1 = k_left32 % 16;
int32_t iter_1_mask = (1 << k_left1) - 1;
uint64_t n_iter = n0 / 4;
uint64_t n_left = n0 % 4;
@@ -1050,6 +1120,9 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( var( abuf ), rdx ) // load address of a
mov( var( bbuf ), r14 ) // load address of b
mov( var( cbuf ), r12 ) // load address of c
@@ -1164,10 +1237,10 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_8 )
je( .CONSIDER_K_ITER_16 )
label( .K_LOOP_ITER32 )
// ITER 0
// load row from A
@@ -1210,70 +1283,97 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
add( imm( 16*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER32 )
label( .CONSIDER_K_ITER_8 )
mov( var(k_iter8), rsi )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
je( .CONSIDER_K_LEFT_1 )
label( .K_LOOP_ITER8 )
// The previous k-loop decomposition used iterations of 64, 32, and 8 elements, which is inefficient for k values below 32.
// For example, when k=31, that scheme requires three 8-element loops (processing 24 elements) followed
// by seven scalar 1-element loops (processing the remaining 7 elements), totaling 10 loop iterations.
// By redesigning the decomposition to use 64, 32, 16, and masked operations instead, the same k=31 case would
// require only one 16-element loop followed by a single masked operation to process the remaining 15 elements.
// This reduces the total iterations from 10 down to 2, significantly improving efficiency for k values in the range of 16-31.
// ITER 0
// load row from A
vmovups( ( rax ), ymm0 )
vmovups( ( rax, r8, 1 ), ymm1 )
add( imm( 8*4 ), rax )
vmovups( ( rax ), zmm0 )
vmovups( ( rax, r8, 1 ), zmm1 )
add( imm( 16*4 ), rax )
// load column from B
vmovups( ( rbx ), ymm6 )
vmovups( ( rbx ), zmm6 )
VFMA2( 8, 9 )
vmovups( ( rbx, r9, 1 ), ymm6 )
vmovups( ( rbx, r9, 1 ), zmm6 )
VFMA2( 11, 12 )
vmovups( ( rbx, r9, 2 ), ymm6 )
vmovups( ( rbx, r9, 2 ), zmm6 )
VFMA2( 14, 15 )
vmovups( ( rbx, r13, 1 ), ymm6 )
vmovups( ( rbx, r13, 1 ), zmm6 )
VFMA2( 17, 18 )
add( imm( 8*4 ), rbx )
dec( rsi )
jne( .K_LOOP_ITER8 )
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_LEFT_1 )
mov( var(k_left1), rsi )
test( rsi, rsi )
je( .POST_ACCUM )
// When the remaining floats fit in a ymm register,
// it is better to operate on masked ymm registers,
// because on Zen4 the throughput of masked loads
// on zmm is 0.5 while on ymm/xmm it is 1
cmp( imm(8), rsi )
jle( .K_FLOATS_LEFT_LE_8 )
label( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_GT_8 )
vmovss( ( rax ), xmm0 )
vmovss( ( rax, r8, 1 ), xmm1 )
add( imm( 1*4 ), rax ) // a += 1*cs_b = 1*4;
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), ZMM(1 MASK_KZ(1) ) )
vmovss( ( rbx ), xmm6 )
vmovups( mem(rbx), ZMM(6 MASK_KZ(1) ) )
VFMA2( 8, 9 )
vmovss( ( rbx, r9, 1 ), xmm6 )
vmovups( mem(rbx, r9, 1), ZMM(6 MASK_KZ(1) ) )
VFMA2( 11, 12 )
vmovss( ( rbx, r9, 2 ), xmm6 )
vmovups( mem(rbx, r9, 2), ZMM(6 MASK_KZ(1) ) )
VFMA2( 14, 15 )
vmovss( ( rbx, r13, 1 ), xmm6 )
vmovups( mem(rbx, r13, 1), ZMM(6 MASK_KZ(1) ) )
VFMA2( 17, 18 )
add( imm( 1*4 ), rbx ) // b += 1*rs_b = 1*4;
// unconditionally branch to .POST_ACCUM after
// handling the >8-floats case
jmp( .POST_ACCUM )
dec( rsi )
jne( .K_LOOP_LEFT1 )
label( .K_FLOATS_LEFT_LE_8 )
// When 8 or fewer elements remain, it is better to operate
// on masked YMM registers on Zen4, because vmovups on
// masked YMM registers has a throughput of 1 while
// the same operation on ZMM has a throughput of 0.5.
// Instead of looping element by element and
// performing a VFMA per element, which is wasteful,
// perform a single masked FMA on the remaining elements
vmovups( mem(rax), YMM(0 MASK_KZ(1) ) )
vmovups( mem(rax, r8, 1), YMM(1 MASK_KZ(1) ) )
vmovups( mem(rbx), YMM(6 MASK_KZ(1) ) )
VFMA2( 8, 9 )
vmovups( mem(rbx, r9, 1), YMM(6 MASK_KZ(1) ) )
VFMA2( 11, 12 )
vmovups( mem(rbx, r9, 2), YMM(6 MASK_KZ(1) ) )
VFMA2( 14, 15 )
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA2( 17, 18 )
label( .POST_ACCUM )
@@ -1329,11 +1429,12 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
end_asm(
: // output operands (none)
: // input operands
[iter_1_mask] "m" (iter_1_mask),
[k_iter64] "m" (k_iter64),
[k_left64] "m" (k_left64),
[k_iter32] "m" (k_iter32),
[k_left32] "m" (k_left32),
[k_iter8] "m" (k_iter8),
[k_iter16] "m" (k_iter16),
[k_left1] "m" (k_left1),
[a] "m" (a),
[rs_a] "m" (rs_a),