Optimizing sgemm rd kernels on zen3 (#293)

Fixing some inefficiencies on the zen (AVX2) SUP RD kernel for SGEMM.
After completing an iteration of the k == 8 loop, execution fell through directly to the k == 1 loop in the k-direction.
This caused many unnecessary single-element iterations whenever the remainder of k was < 8.
This has been fixed by introducing masked operations for the k < 8 remainder.
When the remainder of k == 1, we still handle it with the original non-masked code (via a branch), since the masked path incurs extra overhead from the masking operation.
There were also some unnecessary instructions in the zen4 kernels which have been removed.

AMD-Internal: https://amd.atlassian.net/browse/CPUPL-7775
Co-authored-by: rohrayan@amd.com
This commit is contained in:
Rayan, Rohan
2026-02-04 09:08:11 +05:30
committed by GitHub
parent 50ae5a05ef
commit ebf8721a5c
6 changed files with 4540 additions and 4094 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2023 - 2026, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -41,7 +41,6 @@
#define NR 64
void bli_sgemmsup_rd_zen4_asm_6x64m
(
conj_t conja,
@@ -182,8 +181,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
float *abuf = a;
float *bbuf = b;
float *cbuf = c;
@@ -202,8 +199,8 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( var(iter_1_mask), esi ) // Load mask values for the last loop
kmovw( esi, K(1) )
mov( imm( 0 ), r15 ) // jj = 0;
label( .SLOOP3X4J ) // LOOP OVER jj = [ 0 1 ... ]
@@ -216,11 +213,10 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
imul( imm( 1*4 ), rsi )
lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c
lea(mem( , r15, 1 ), rsi) // rsi = r15 = 4*jj;
lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj;
imul( r9, rsi ) // rsi *= cs_b;
lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b;
mov( var( m_iter ), r11 ) // ii = m_iter;
label( .SLOOP3X4I ) // LOOP OVER ii = [ m_iter ... 1 0 ]
@@ -233,11 +229,10 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
INIT_REG
mov( var( k_iter64 ), rsi ) // load k_iter
mov( var( k_iter64 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_32 )
label( .K_LOOP_ITER64 )
// ITER 0
@@ -340,16 +335,12 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
dec( rsi )
jne( .K_LOOP_ITER64 )
label( .CONSIDER_K_ITER_32 )
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_16 )
// ITER 0
// load row from A
vmovups( ( rax ), zmm0 )
@@ -399,7 +390,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
@@ -510,7 +500,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
vucomiss( xmm1, xmm0 ) // check if beta = 0
je( .POST_ACCUM_STOR_BZ )
// Accumulating & storing the results when beta != 0
label( .POST_ACCUM_STOR )
@@ -536,7 +525,7 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
ALPHA_SCALE // Scaling the result of A*B with alpha
C_STOR // Storing result to C
C_STOR // Storing result to C
ZMM_TO_YMM( 20, 21, 22, 23, 4, 5, 6, 7 )
ZMM_TO_YMM( 24, 25, 26, 27, 8, 9, 10, 11 )
@@ -548,11 +537,10 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
ALPHA_SCALE // Scaling the result of A*B with alpha
C_STOR // Storing result to C
C_STOR // Storing result to C
jmp( .SDONE )
// Accumulating & storing the results when beta == 0
label( .POST_ACCUM_STOR_BZ )
@@ -639,7 +627,7 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
"zmm16", "zmm17", "zmm18", "zmm19",
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
"memory"
"memory", "k1"
)
consider_edge_cases:
@@ -763,13 +751,11 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
kmovw( esi, K(1) )
mov( imm( 0 ), r15 ) // jj = 0;
label( .SLOOP3X4J ) // LOOP OVER jj = [ 0 1 ... ]
mov( var( abuf ), r14 ) // load address of a
mov( var( bbuf ), rdx ) // load address of b
mov( var( cbuf ), r12 ) // load address of c
@@ -782,24 +768,22 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
imul( r9, rsi ) // rsi *= cs_b;
lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b;
mov( var( m_iter ), r11 ) // ii = m_iter;
label( .SLOOP3X4I ) // LOOP OVER ii = [ m_iter ... 1 0 ]
lea( mem( r14 ), rax ) // load c to rcx
lea( mem( r12 ), rcx ) // load a to rax
lea( mem( r14 ), rax ) // load a to rax
lea( mem( r12 ), rcx ) // load c to rcx
lea( mem( rdx ), rbx ) // load b to rbx
lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b
lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b
lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_a
lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_a
INIT_REG
mov( var( k_iter64 ), rsi ) // load k_iter
mov( var( k_iter64 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_32 )
label( .K_LOOP_ITER64 )
// ITER 0
@@ -908,9 +892,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
test( rsi, rsi )
je( .CONSIDER_K_ITER_16 )
// ITER 0
// load row from A
vmovups( ( rax ), zmm0 )
@@ -960,7 +941,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
@@ -1063,7 +1043,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
label( .POST_ACCUM )
mov( var( beta ), rax ) // load address of beta
@@ -1072,7 +1051,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
vucomiss( xmm1, xmm0 ) // check if beta = 0
je( .POST_ACCUM_STOR_BZ )
// Accumulating & storing the results when beta != 0
label( .POST_ACCUM_STOR )
@@ -1114,7 +1092,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
jmp( .SDONE )
// Accumulating & storing the results when beta == 0
label( .POST_ACCUM_STOR_BZ )
@@ -1153,7 +1130,7 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
lea( mem( r14, r8, 4 ), r14 ) // a_ii = r14 += 6*rs_a
dec( r11 )
jne( .SLOOP3X4I ) // iterate again if ii != 0.
jne( .SLOOP3X4I ) // iterate again if ii != 0.
add( imm( 4 ), r15 )
cmp( imm( 48 ), r15 )
@@ -1201,7 +1178,7 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
"zmm16", "zmm17", "zmm18", "zmm19",
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
"memory"
"memory", "k1"
)
consider_edge_cases:
@@ -1325,9 +1302,8 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
kmovw( esi, K(1) )
mov( imm(0), r15 ) // jj = 0;
label( .SLOOP3X4J ) // LOOP OVER jj = [ 0 1 ... ]
@@ -1344,24 +1320,22 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
imul( r9, rsi ) // rsi *= cs_b;
lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b;
mov( var( m_iter ), r11 ) // ii = m_iter;
label( .SLOOP3X4I ) // LOOP OVER ii = [ m_iter ... 1 0 ]
lea( mem( r14 ), rax ) // load c to rcx
lea( mem( r12 ), rcx ) // load a to rax
lea( mem( r14 ), rax ) // load a to rax
lea( mem( r12 ), rcx ) // load c to rcx
lea( mem( rdx ), rbx ) // load b to rbx
lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b
lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b
lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_a
lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_a
INIT_REG
mov( var( k_iter64 ), rsi ) // load k_iter
mov( var( k_iter64 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_32 )
label( .K_LOOP_ITER64 )
// ITER 0
@@ -1464,15 +1438,12 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
dec( rsi )
jne( .K_LOOP_ITER64 )
label( .CONSIDER_K_ITER_32 )
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_16 )
// ITER 0
// load row from A
vmovups( ( rax ), zmm0 )
@@ -1522,7 +1493,6 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
@@ -1625,7 +1595,6 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
label( .POST_ACCUM )
mov( var( beta ), rax ) // load address of beta
@@ -1675,7 +1644,6 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
jmp( .SDONE )
// Accumulating & storing the results when beta == 0
label( .POST_ACCUM_STOR_BZ )
@@ -1714,13 +1682,12 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
lea( mem( r14, r8, 4 ), r14 ) // a_ii = r14 += 6*rs_a
dec( r11 )
jne( .SLOOP3X4I ) // iterate again if ii != 0.
jne( .SLOOP3X4I ) // iterate again if ii != 0.
add( imm( 4 ), r15 )
cmp( imm( 32 ), r15 )
jl( .SLOOP3X4J )
end_asm(
: // output operands (none)
: // input operands
@@ -1763,7 +1730,7 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
"zmm16", "zmm17", "zmm18", "zmm19",
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
"memory"
"memory", "k1"
)
consider_edge_cases:

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2023 - 2026, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -145,8 +145,8 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
kmovw( esi, K(1) )
mov( imm( 0 ), r11 ) // ii = 0;
label( .SLOOP3X4I ) // LOOP OVER ii = [ 0 1 ... ]
@@ -166,7 +166,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
imul( r8, rsi ) // rsi *= cs_b;
lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*ii*cs_b;
mov( var( n_iter ), r15 ) // jj = n_iter;
label( .SLOOP3X4J ) // LOOP OVER jj = [ n_iter ... 1 0 ]
@@ -183,7 +182,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
test( rsi, rsi )
je( .CONSIDER_K_ITER_32 )
label( .K_LOOP_ITER64 )
// ITER 0
@@ -246,7 +244,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
// load column from B
vmovups( ( rbx ), zmm6 )
vmovups( ( rbx ), zmm6 )
VFMA6( 8, 9, 10, 20, 21, 22 )
vmovups( ( rbx, r9, 1 ), zmm6 )
@@ -287,16 +284,12 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
dec( rsi )
jne( .K_LOOP_ITER64 )
label( .CONSIDER_K_ITER_32 )
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_16 )
// ITER 0
// load row from A
vmovups( ( rax ), zmm0 )
@@ -346,7 +339,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
@@ -397,7 +389,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
cmp( imm(8), rsi )
jle( .K_FLOATS_LEFT_LE_8 )
label( .K_FLOATS_LEFT_GT_8 )
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
@@ -450,7 +441,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA6( 17, 18, 19, 29, 30, 31 )
label( .POST_ACCUM )
mov( var( beta ), rax ) // load address of beta
@@ -459,7 +449,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
vucomiss( xmm1, xmm0 ) // check if beta = 0
je( .POST_ACCUM_STOR_BZ )
// Accumulating & storing the results when beta != 0
label( .POST_ACCUM_STOR )
@@ -501,7 +490,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
jmp( .SDONE )
// Accumulating & storing the results when beta == 0
label( .POST_ACCUM_STOR_BZ )
@@ -529,14 +517,13 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
C_STOR_BZ // Storing result to C
label( .SDONE )
add( imm(4*4), r12 )
lea(mem(r14, r9, 4), r14) // a_ii = r14 += 3*rs_a
lea( mem(r14, r9, 4), r14 ) // a_ii = r14 += 3*rs_a
dec( r15 )
jne( .SLOOP3X4J ) // iterate again if ii != 0.
jne( .SLOOP3X4J ) // iterate again if ii != 0.
end_asm(
: // output operands (none)
@@ -580,7 +567,7 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
"zmm16", "zmm17", "zmm18", "zmm19",
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
"memory"
"memory", "k1"
)
consider_edge_cases:
@@ -671,14 +658,13 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
kmovw( esi, K(1) )
mov( var( abuf ), rdx ) // load address of a
mov( var( bbuf ), r14 ) // load address of b
mov( var( cbuf ), r12 ) // load address of c
mov( var( n_iter ), r15 ) // jj = m_iter;
label( .SLOOP3X4J ) // LOOP OVER jj = [ m_iter ... 1 0 ]
@@ -691,11 +677,10 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
INIT_REG
mov( var( k_iter64 ), rsi ) // load k_iter
mov( var( k_iter64 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_32 )
label( .K_LOOP_ITER64 )
// ITER 0
@@ -749,7 +734,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
// load column from B
vmovups( ( rbx ), zmm6 )
vmovups( ( rbx ), zmm6 )
VFMA3( 8, 9, 10 )
vmovups( ( rbx, r9, 1 ), zmm6 )
@@ -787,16 +771,12 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
dec( rsi )
jne( .K_LOOP_ITER64 )
label( .CONSIDER_K_ITER_32 )
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_16 )
// ITER 0
// load row from A
vmovups( ( rax ), zmm0 )
@@ -840,7 +820,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
@@ -888,7 +867,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
cmp( imm(8), rsi )
jle( .K_FLOATS_LEFT_LE_8 )
label( .K_FLOATS_LEFT_GT_8 )
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
@@ -935,7 +913,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA3( 17, 18, 19 )
label( .POST_ACCUM )
mov( var( beta ), rax ) // load address of beta
@@ -944,7 +921,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
vucomiss( xmm1, xmm0 ) // check if beta = 0
je( .POST_ACCUM_STOR_BZ )
// Accumulating & storing the results when beta != 0
label( .POST_ACCUM_STOR )
@@ -962,7 +938,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
jmp( .SDONE )
// Accumulating & storing the results when beta == 0
label( .POST_ACCUM_STOR_BZ )
@@ -1027,7 +1002,7 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
"zmm16", "zmm17", "zmm18", "zmm19",
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
"memory"
"memory", "k1"
)
consider_edge_cases:
@@ -1120,14 +1095,13 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
mov(var(iter_1_mask), esi) // Load mask values for the last loop
kmovw(esi, K(1))
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
kmovw( esi, K(1) )
mov( var( abuf ), rdx ) // load address of a
mov( var( bbuf ), r14 ) // load address of b
mov( var( cbuf ), r12 ) // load address of c
mov( var( n_iter ), r15 ) // jj = m_iter;
label( .SLOOP3X4J ) // LOOP OVER ii = [ m_iter ... 1 0 ]
@@ -1140,11 +1114,10 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
INIT_REG
mov( var( k_iter64 ), rsi ) // load k_iter
mov( var( k_iter64 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_32 )
label( .K_LOOP_ITER64 )
// ITER 0
@@ -1195,7 +1168,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
// load column from B
vmovups( ( rbx ), zmm6 )
vmovups( ( rbx ), zmm6 )
VFMA2( 8, 9 )
vmovups( ( rbx, r9, 1 ), zmm6 )
@@ -1232,16 +1204,12 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
dec( rsi )
jne( .K_LOOP_ITER64 )
label( .CONSIDER_K_ITER_32 )
mov( var( k_iter32 ), rsi ) // load k_iter
test( rsi, rsi )
je( .CONSIDER_K_ITER_16 )
// ITER 0
// load row from A
vmovups( ( rax ), zmm0 )
@@ -1283,7 +1251,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_ITER_16 )
mov( var( k_iter16 ), rsi )
test( rsi, rsi )
@@ -1317,7 +1284,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
add( imm( 16*4 ), rbx )
label( .CONSIDER_K_LEFT_1 )
mov( var(k_left1), rsi )
test( rsi, rsi )
@@ -1375,7 +1341,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
VFMA2( 17, 18 )
label( .POST_ACCUM )
mov( var( beta ), rax ) // load address of beta
@@ -1385,7 +1350,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
vucomiss( xmm1, xmm0 ) // check if beta = 0
je( .POST_ACCUM_STOR_BZ )
// Accumulating & storing the results when beta != 0
label( .POST_ACCUM_STOR )
@@ -1401,11 +1365,9 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
jmp( .SDONE )
// Accumulating & storing the results when beta == 0
label( .POST_ACCUM_STOR_BZ )
ZMM_TO_YMM( 8, 9, 11, 12, 4, 5, 7, 8 )
ZMM_TO_YMM( 14, 15, 17, 18, 10, 11, 13, 14 )
@@ -1416,16 +1378,14 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
C_STOR_BZ2 // Storing result to C
label( .SDONE )
add( imm(4*4), r12 )
lea(mem(r14, r9, 4), r14) // a_ii = r14 += 3*rs_a
lea( mem(r14, r9, 4), r14 ) // a_ii = r14 += 3*rs_a
dec( r15 )
jne( .SLOOP3X4J ) // iterate again if jj != 0.
end_asm(
: // output operands (none)
: // input operands
@@ -1466,7 +1426,7 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
"zmm16", "zmm17", "zmm18", "zmm19",
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
"memory"
"memory", "k1"
)
consider_edge_cases: