mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Optimizing sgemm rd kernels on zen3 (#293)
Fixes some inefficiencies in the zen (AVX2) SUP RD kernel for SGEMM. After the 8-wide k-loop finished, the next loop executed in the k-direction was the 1-wide loop, which caused many unnecessary iterations whenever the remainder of k was less than 8. This has been fixed by introducing masked operations for a k remainder < 8. When the remainder of k == 1, it is still handled by the original non-masked code (selected with a branch), because the masked code incurs a higher penalty from the masking operation. Some unnecessary instructions in the zen4 kernels have also been removed. AMD-Internal: https://amd.atlassian.net/browse/CPUPL-7775. Co-authored-by: rohrayan@amd.com
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2023 - 2026, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -41,7 +41,6 @@
|
||||
|
||||
#define NR 64
|
||||
|
||||
|
||||
void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
(
|
||||
conj_t conja,
|
||||
@@ -182,8 +181,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
|
||||
|
||||
float *abuf = a;
|
||||
float *bbuf = b;
|
||||
float *cbuf = c;
|
||||
@@ -202,8 +199,8 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
|
||||
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
|
||||
|
||||
mov(var(iter_1_mask), esi) // Load mask values for the last loop
|
||||
kmovw(esi, K(1))
|
||||
mov( var(iter_1_mask), esi ) // Load mask values for the last loop
|
||||
kmovw( esi, K(1) )
|
||||
|
||||
mov( imm( 0 ), r15 ) // jj = 0;
|
||||
label( .SLOOP3X4J ) // LOOP OVER jj = [ 0 1 ... ]
|
||||
@@ -216,11 +213,10 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
imul( imm( 1*4 ), rsi )
|
||||
lea( mem( r12, rsi, 1 ), r12 ) // c += r15 * cs_c
|
||||
|
||||
lea(mem( , r15, 1 ), rsi) // rsi = r15 = 4*jj;
|
||||
lea( mem( , r15, 1 ), rsi ) // rsi = r15 = 4*jj;
|
||||
imul( r9, rsi ) // rsi *= cs_b;
|
||||
lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b;
|
||||
|
||||
|
||||
mov( var( m_iter ), r11 ) // ii = m_iter;
|
||||
label( .SLOOP3X4I ) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
@@ -233,11 +229,10 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
|
||||
INIT_REG
|
||||
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_32 )
|
||||
|
||||
|
||||
label( .K_LOOP_ITER64 )
|
||||
|
||||
// ITER 0
|
||||
@@ -340,16 +335,12 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
dec( rsi )
|
||||
jne( .K_LOOP_ITER64 )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_32 )
|
||||
|
||||
mov( var( k_iter32 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_16 )
|
||||
|
||||
|
||||
|
||||
|
||||
// ITER 0
|
||||
// load row from A
|
||||
vmovups( ( rax ), zmm0 )
|
||||
@@ -399,7 +390,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
|
||||
add( imm( 16*4 ), rbx )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_16 )
|
||||
mov( var( k_iter16 ), rsi )
|
||||
test( rsi, rsi )
|
||||
@@ -510,7 +500,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
vucomiss( xmm1, xmm0 ) // check if beta = 0
|
||||
je( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta != 0
|
||||
label( .POST_ACCUM_STOR )
|
||||
|
||||
@@ -536,7 +525,7 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
|
||||
ALPHA_SCALE // Scaling the result of A*B with alpha
|
||||
|
||||
C_STOR // Storing result to C
|
||||
C_STOR // Storing result to C
|
||||
|
||||
ZMM_TO_YMM( 20, 21, 22, 23, 4, 5, 6, 7 )
|
||||
ZMM_TO_YMM( 24, 25, 26, 27, 8, 9, 10, 11 )
|
||||
@@ -548,11 +537,10 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
|
||||
ALPHA_SCALE // Scaling the result of A*B with alpha
|
||||
|
||||
C_STOR // Storing result to C
|
||||
C_STOR // Storing result to C
|
||||
|
||||
jmp( .SDONE )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta == 0
|
||||
label( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
@@ -639,7 +627,7 @@ void bli_sgemmsup_rd_zen4_asm_6x64m
|
||||
"zmm16", "zmm17", "zmm18", "zmm19",
|
||||
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
|
||||
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
|
||||
"memory"
|
||||
"memory", "k1"
|
||||
)
|
||||
|
||||
consider_edge_cases:
|
||||
@@ -763,13 +751,11 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
|
||||
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
|
||||
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
|
||||
|
||||
mov(var(iter_1_mask), esi) // Load mask values for the last loop
|
||||
kmovw(esi, K(1))
|
||||
|
||||
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
|
||||
kmovw( esi, K(1) )
|
||||
|
||||
mov( imm( 0 ), r15 ) // jj = 0;
|
||||
label( .SLOOP3X4J ) // LOOP OVER jj = [ 0 1 ... ]
|
||||
|
||||
mov( var( abuf ), r14 ) // load address of a
|
||||
mov( var( bbuf ), rdx ) // load address of b
|
||||
mov( var( cbuf ), r12 ) // load address of c
|
||||
@@ -782,24 +768,22 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
|
||||
imul( r9, rsi ) // rsi *= cs_b;
|
||||
lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b;
|
||||
|
||||
|
||||
mov( var( m_iter ), r11 ) // ii = m_iter;
|
||||
label( .SLOOP3X4I ) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
lea( mem( r14 ), rax ) // load c to rcx
|
||||
lea( mem( r12 ), rcx ) // load a to rax
|
||||
lea( mem( r14 ), rax ) // load a to rax
|
||||
lea( mem( r12 ), rcx ) // load c to rcx
|
||||
lea( mem( rdx ), rbx ) // load b to rbx
|
||||
|
||||
lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b
|
||||
lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b
|
||||
lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_a
|
||||
lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_a
|
||||
|
||||
INIT_REG
|
||||
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_32 )
|
||||
|
||||
|
||||
label( .K_LOOP_ITER64 )
|
||||
|
||||
// ITER 0
|
||||
@@ -908,9 +892,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_16 )
|
||||
|
||||
|
||||
|
||||
|
||||
// ITER 0
|
||||
// load row from A
|
||||
vmovups( ( rax ), zmm0 )
|
||||
@@ -960,7 +941,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
|
||||
|
||||
add( imm( 16*4 ), rbx )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_16 )
|
||||
mov( var( k_iter16 ), rsi )
|
||||
test( rsi, rsi )
|
||||
@@ -1063,7 +1043,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
|
||||
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
|
||||
VFMA6( 17, 18, 19, 29, 30, 31 )
|
||||
|
||||
|
||||
label( .POST_ACCUM )
|
||||
|
||||
mov( var( beta ), rax ) // load address of beta
|
||||
@@ -1072,7 +1051,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
|
||||
vucomiss( xmm1, xmm0 ) // check if beta = 0
|
||||
je( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta != 0
|
||||
label( .POST_ACCUM_STOR )
|
||||
|
||||
@@ -1114,7 +1092,6 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
|
||||
|
||||
jmp( .SDONE )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta == 0
|
||||
label( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
@@ -1153,7 +1130,7 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
|
||||
lea( mem( r14, r8, 4 ), r14 ) // a_ii = r14 += 6*rs_a
|
||||
|
||||
dec( r11 )
|
||||
jne( .SLOOP3X4I ) // iterate again if ii != 0.
|
||||
jne( .SLOOP3X4I ) // iterate again if ii != 0.
|
||||
|
||||
add( imm( 4 ), r15 )
|
||||
cmp( imm( 48 ), r15 )
|
||||
@@ -1201,7 +1178,7 @@ void bli_sgemmsup_rd_zen4_asm_6x48m
|
||||
"zmm16", "zmm17", "zmm18", "zmm19",
|
||||
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
|
||||
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
|
||||
"memory"
|
||||
"memory", "k1"
|
||||
)
|
||||
|
||||
consider_edge_cases:
|
||||
@@ -1325,9 +1302,8 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
|
||||
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
|
||||
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
|
||||
|
||||
mov(var(iter_1_mask), esi) // Load mask values for the last loop
|
||||
kmovw(esi, K(1))
|
||||
|
||||
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
|
||||
kmovw( esi, K(1) )
|
||||
|
||||
mov( imm(0), r15 ) // jj = 0;
|
||||
label( .SLOOP3X4J ) // LOOP OVER jj = [ 0 1 ... ]
|
||||
@@ -1344,24 +1320,22 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
|
||||
imul( r9, rsi ) // rsi *= cs_b;
|
||||
lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*jj*cs_b;
|
||||
|
||||
|
||||
mov( var( m_iter ), r11 ) // ii = m_iter;
|
||||
label( .SLOOP3X4I ) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
lea( mem( r14 ), rax ) // load c to rcx
|
||||
lea( mem( r12 ), rcx ) // load a to rax
|
||||
lea( mem( r14 ), rax ) // load a to rax
|
||||
lea( mem( r12 ), rcx ) // load c to rcx
|
||||
lea( mem( rdx ), rbx ) // load b to rbx
|
||||
|
||||
lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_b
|
||||
lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_b
|
||||
lea( mem( r8, r8, 2 ), r10 ) // r10 = 3 * rs_a
|
||||
lea( mem( r10, r8, 2 ), rdi ) // rdi = 5 * rs_a
|
||||
|
||||
INIT_REG
|
||||
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_32 )
|
||||
|
||||
|
||||
label( .K_LOOP_ITER64 )
|
||||
|
||||
// ITER 0
|
||||
@@ -1464,15 +1438,12 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
|
||||
dec( rsi )
|
||||
jne( .K_LOOP_ITER64 )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_32 )
|
||||
|
||||
mov( var( k_iter32 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_16 )
|
||||
|
||||
|
||||
|
||||
// ITER 0
|
||||
// load row from A
|
||||
vmovups( ( rax ), zmm0 )
|
||||
@@ -1522,7 +1493,6 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
|
||||
|
||||
add( imm( 16*4 ), rbx )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_16 )
|
||||
mov( var( k_iter16 ), rsi )
|
||||
test( rsi, rsi )
|
||||
@@ -1625,7 +1595,6 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
|
||||
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
|
||||
VFMA6( 17, 18, 19, 29, 30, 31 )
|
||||
|
||||
|
||||
label( .POST_ACCUM )
|
||||
|
||||
mov( var( beta ), rax ) // load address of beta
|
||||
@@ -1675,7 +1644,6 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
|
||||
|
||||
jmp( .SDONE )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta == 0
|
||||
label( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
@@ -1714,13 +1682,12 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
|
||||
lea( mem( r14, r8, 4 ), r14 ) // a_ii = r14 += 6*rs_a
|
||||
|
||||
dec( r11 )
|
||||
jne( .SLOOP3X4I ) // iterate again if ii != 0.
|
||||
jne( .SLOOP3X4I ) // iterate again if ii != 0.
|
||||
|
||||
add( imm( 4 ), r15 )
|
||||
cmp( imm( 32 ), r15 )
|
||||
jl( .SLOOP3X4J )
|
||||
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
: // input operands
|
||||
@@ -1763,7 +1730,7 @@ void bli_sgemmsup_rd_zen4_asm_6x32m
|
||||
"zmm16", "zmm17", "zmm18", "zmm19",
|
||||
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
|
||||
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
|
||||
"memory"
|
||||
"memory", "k1"
|
||||
)
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2023 - 2026, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -145,8 +145,8 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
|
||||
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
|
||||
|
||||
mov(var(iter_1_mask), esi) // Load mask values for the last loop
|
||||
kmovw(esi, K(1))
|
||||
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
|
||||
kmovw( esi, K(1) )
|
||||
|
||||
mov( imm( 0 ), r11 ) // ii = 0;
|
||||
label( .SLOOP3X4I ) // LOOP OVER ii = [ 0 1 ... ]
|
||||
@@ -166,7 +166,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
imul( r8, rsi ) // rsi *= cs_b;
|
||||
lea( mem( rdx, rsi, 1 ), rdx ) // rbx = b + 4*ii*cs_b;
|
||||
|
||||
|
||||
mov( var( n_iter ), r15 ) // jj = n_iter;
|
||||
label( .SLOOP3X4J ) // LOOP OVER jj = [ n_iter ... 1 0 ]
|
||||
|
||||
@@ -183,7 +182,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_32 )
|
||||
|
||||
|
||||
label( .K_LOOP_ITER64 )
|
||||
|
||||
// ITER 0
|
||||
@@ -246,7 +244,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
|
||||
// load column from B
|
||||
vmovups( ( rbx ), zmm6 )
|
||||
vmovups( ( rbx ), zmm6 )
|
||||
VFMA6( 8, 9, 10, 20, 21, 22 )
|
||||
|
||||
vmovups( ( rbx, r9, 1 ), zmm6 )
|
||||
@@ -287,16 +284,12 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
dec( rsi )
|
||||
jne( .K_LOOP_ITER64 )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_32 )
|
||||
|
||||
mov( var( k_iter32 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_16 )
|
||||
|
||||
|
||||
|
||||
|
||||
// ITER 0
|
||||
// load row from A
|
||||
vmovups( ( rax ), zmm0 )
|
||||
@@ -346,7 +339,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
|
||||
add( imm( 16*4 ), rbx )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_16 )
|
||||
mov( var( k_iter16 ), rsi )
|
||||
test( rsi, rsi )
|
||||
@@ -397,7 +389,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
cmp( imm(8), rsi )
|
||||
jle( .K_FLOATS_LEFT_LE_8 )
|
||||
|
||||
|
||||
label( .K_FLOATS_LEFT_GT_8 )
|
||||
|
||||
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
|
||||
@@ -450,7 +441,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
|
||||
VFMA6( 17, 18, 19, 29, 30, 31 )
|
||||
|
||||
|
||||
label( .POST_ACCUM )
|
||||
|
||||
mov( var( beta ), rax ) // load address of beta
|
||||
@@ -459,7 +449,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
vucomiss( xmm1, xmm0 ) // check if beta = 0
|
||||
je( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta != 0
|
||||
label( .POST_ACCUM_STOR )
|
||||
|
||||
@@ -501,7 +490,6 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
|
||||
jmp( .SDONE )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta == 0
|
||||
label( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
@@ -529,14 +517,13 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
|
||||
C_STOR_BZ // Storing result to C
|
||||
|
||||
|
||||
label( .SDONE )
|
||||
|
||||
add( imm(4*4), r12 )
|
||||
lea(mem(r14, r9, 4), r14) // a_ii = r14 += 3*rs_a
|
||||
lea( mem(r14, r9, 4), r14 ) // a_ii = r14 += 3*rs_a
|
||||
|
||||
dec( r15 )
|
||||
jne( .SLOOP3X4J ) // iterate again if ii != 0.
|
||||
jne( .SLOOP3X4J ) // iterate again if ii != 0.
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
@@ -580,7 +567,7 @@ void bli_sgemmsup_rd_zen4_asm_6x64n
|
||||
"zmm16", "zmm17", "zmm18", "zmm19",
|
||||
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
|
||||
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
|
||||
"memory"
|
||||
"memory", "k1"
|
||||
)
|
||||
|
||||
consider_edge_cases:
|
||||
@@ -671,14 +658,13 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
|
||||
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
|
||||
|
||||
mov(var(iter_1_mask), esi) // Load mask values for the last loop
|
||||
kmovw(esi, K(1))
|
||||
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
|
||||
kmovw( esi, K(1) )
|
||||
|
||||
mov( var( abuf ), rdx ) // load address of a
|
||||
mov( var( bbuf ), r14 ) // load address of b
|
||||
mov( var( cbuf ), r12 ) // load address of c
|
||||
|
||||
|
||||
mov( var( n_iter ), r15 ) // jj = m_iter;
|
||||
label( .SLOOP3X4J ) // LOOP OVER jj = [ m_iter ... 1 0 ]
|
||||
|
||||
@@ -691,11 +677,10 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
|
||||
INIT_REG
|
||||
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_32 )
|
||||
|
||||
|
||||
label( .K_LOOP_ITER64 )
|
||||
|
||||
// ITER 0
|
||||
@@ -749,7 +734,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
|
||||
// load column from B
|
||||
vmovups( ( rbx ), zmm6 )
|
||||
vmovups( ( rbx ), zmm6 )
|
||||
VFMA3( 8, 9, 10 )
|
||||
|
||||
vmovups( ( rbx, r9, 1 ), zmm6 )
|
||||
@@ -787,16 +771,12 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
dec( rsi )
|
||||
jne( .K_LOOP_ITER64 )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_32 )
|
||||
|
||||
mov( var( k_iter32 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_16 )
|
||||
|
||||
|
||||
|
||||
|
||||
// ITER 0
|
||||
// load row from A
|
||||
vmovups( ( rax ), zmm0 )
|
||||
@@ -840,7 +820,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
|
||||
add( imm( 16*4 ), rbx )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_16 )
|
||||
mov( var( k_iter16 ), rsi )
|
||||
test( rsi, rsi )
|
||||
@@ -888,7 +867,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
cmp( imm(8), rsi )
|
||||
jle( .K_FLOATS_LEFT_LE_8 )
|
||||
|
||||
|
||||
label( .K_FLOATS_LEFT_GT_8 )
|
||||
|
||||
vmovups( mem(rax), ZMM(0 MASK_KZ(1) ) )
|
||||
@@ -935,7 +913,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
|
||||
VFMA3( 17, 18, 19 )
|
||||
|
||||
|
||||
label( .POST_ACCUM )
|
||||
|
||||
mov( var( beta ), rax ) // load address of beta
|
||||
@@ -944,7 +921,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
vucomiss( xmm1, xmm0 ) // check if beta = 0
|
||||
je( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta != 0
|
||||
label( .POST_ACCUM_STOR )
|
||||
|
||||
@@ -962,7 +938,6 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
|
||||
jmp( .SDONE )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta == 0
|
||||
label( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
@@ -1027,7 +1002,7 @@ void bli_sgemmsup_rd_zen4_asm_3x64n
|
||||
"zmm16", "zmm17", "zmm18", "zmm19",
|
||||
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
|
||||
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
|
||||
"memory"
|
||||
"memory", "k1"
|
||||
)
|
||||
|
||||
consider_edge_cases:
|
||||
@@ -1120,14 +1095,13 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
lea( mem( , r10, 4 ), r10 ) // cs_a *= sizeof(dt) => cs_a *= 4
|
||||
lea( mem( r9, r9, 2 ), r13 ) // r13 = 3 * rs_b
|
||||
|
||||
mov(var(iter_1_mask), esi) // Load mask values for the last loop
|
||||
kmovw(esi, K(1))
|
||||
mov( var( iter_1_mask ), esi ) // Load mask values for the last loop
|
||||
kmovw( esi, K(1) )
|
||||
|
||||
mov( var( abuf ), rdx ) // load address of a
|
||||
mov( var( bbuf ), r14 ) // load address of b
|
||||
mov( var( cbuf ), r12 ) // load address of c
|
||||
|
||||
|
||||
mov( var( n_iter ), r15 ) // jj = m_iter;
|
||||
label( .SLOOP3X4J ) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
@@ -1140,11 +1114,10 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
|
||||
INIT_REG
|
||||
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
mov( var( k_iter64 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_32 )
|
||||
|
||||
|
||||
label( .K_LOOP_ITER64 )
|
||||
|
||||
// ITER 0
|
||||
@@ -1195,7 +1168,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
|
||||
// load column from B
|
||||
vmovups( ( rbx ), zmm6 )
|
||||
vmovups( ( rbx ), zmm6 )
|
||||
VFMA2( 8, 9 )
|
||||
|
||||
vmovups( ( rbx, r9, 1 ), zmm6 )
|
||||
@@ -1232,16 +1204,12 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
dec( rsi )
|
||||
jne( .K_LOOP_ITER64 )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_32 )
|
||||
|
||||
mov( var( k_iter32 ), rsi ) // load k_iter
|
||||
test( rsi, rsi )
|
||||
je( .CONSIDER_K_ITER_16 )
|
||||
|
||||
|
||||
|
||||
|
||||
// ITER 0
|
||||
// load row from A
|
||||
vmovups( ( rax ), zmm0 )
|
||||
@@ -1283,7 +1251,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
|
||||
add( imm( 16*4 ), rbx )
|
||||
|
||||
|
||||
label( .CONSIDER_K_ITER_16 )
|
||||
mov( var( k_iter16 ), rsi )
|
||||
test( rsi, rsi )
|
||||
@@ -1317,7 +1284,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
|
||||
add( imm( 16*4 ), rbx )
|
||||
|
||||
|
||||
label( .CONSIDER_K_LEFT_1 )
|
||||
mov( var(k_left1), rsi )
|
||||
test( rsi, rsi )
|
||||
@@ -1375,7 +1341,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
vmovups( mem(rbx, r13, 1), YMM(6 MASK_KZ(1) ) )
|
||||
VFMA2( 17, 18 )
|
||||
|
||||
|
||||
label( .POST_ACCUM )
|
||||
|
||||
mov( var( beta ), rax ) // load address of beta
|
||||
@@ -1385,7 +1350,6 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
vucomiss( xmm1, xmm0 ) // check if beta = 0
|
||||
je( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta != 0
|
||||
label( .POST_ACCUM_STOR )
|
||||
|
||||
@@ -1401,11 +1365,9 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
|
||||
jmp( .SDONE )
|
||||
|
||||
|
||||
// Accumulating & storing the results when beta == 0
|
||||
label( .POST_ACCUM_STOR_BZ )
|
||||
|
||||
|
||||
ZMM_TO_YMM( 8, 9, 11, 12, 4, 5, 7, 8 )
|
||||
ZMM_TO_YMM( 14, 15, 17, 18, 10, 11, 13, 14 )
|
||||
|
||||
@@ -1416,16 +1378,14 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
|
||||
C_STOR_BZ2 // Storing result to C
|
||||
|
||||
|
||||
label( .SDONE )
|
||||
|
||||
add( imm(4*4), r12 )
|
||||
lea(mem(r14, r9, 4), r14) // a_ii = r14 += 3*rs_a
|
||||
lea( mem(r14, r9, 4), r14 ) // a_ii = r14 += 3*rs_a
|
||||
|
||||
dec( r15 )
|
||||
jne( .SLOOP3X4J ) // iterate again if jj != 0.
|
||||
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
: // input operands
|
||||
@@ -1466,7 +1426,7 @@ void bli_sgemmsup_rd_zen4_asm_2x64n
|
||||
"zmm16", "zmm17", "zmm18", "zmm19",
|
||||
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
|
||||
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31",
|
||||
"memory"
|
||||
"memory", "k1"
|
||||
)
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
Reference in New Issue
Block a user