Bugfix: Disable A Packing for FP32 RD kernels and Post-Ops Fix

- For single-threaded configuration of BLIS, packing of A and B matrices
  are enabled by default. But, packing of A is only supported for RV
  kernels where elements from matrix A are being broadcasted. Since
  elements are being loaded in RD kernels, packing of A results in
  failures. Hence, disabled packing of matrix A for RD kernels.

- Fixed the issue where c_i index pointer was incorrectly being reset
  when exceeding MC block thus, resulting in failures for certain
  Post-Ops.

- Fixed the FP32 reoder case were for n == 1 and rs_b == 1 condition, it
  was incorrectly using sizeof(BLIS_FLOAT) instead of sizeof(float).

AMD-Internal: [SWLCSG-3497]
Change-Id: I6d18afa996c253d79f666ea9789270bb59b629dd
This commit is contained in:
Arnav Sharma
2025-04-17 16:51:29 +05:30
parent 1ff96343f1
commit 87c9230cac
4 changed files with 45 additions and 20 deletions

View File

@@ -179,7 +179,7 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
{
if(rs_b == 1)
{
memcpy(reorder_buf_addr, input_buf_addr, (k * sizeof(BLIS_FLOAT)));
memcpy(reorder_buf_addr, input_buf_addr, (k * sizeof(float)));
}else
{
for(dim_t k0 = 0; k0 < k; k0++)

View File

@@ -423,10 +423,11 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
// Avoid packing of B in transb cases where rd kernels performs
// better than rv + pack. rv kernel calls rd when rs_b==1.
if( ( rs_b == 1 ) && ( mtag_b == PACK ) && ( mtag_a == UNPACKED ) )
if( ( n < 48 ) && ( rs_b == 1 ) && ( mtag_b == PACK ) &&
( mtag_a == UNPACKED ) )
{
if ( n < 48 ) mtag_b = UNPACKED;
else if ( m < 25 ) mtag_b = UNPACKED;
mtag_b = UNPACKED;
should_pack_A = FALSE;
}
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )

View File

@@ -173,6 +173,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m_rd)
// Save c_j index for restoring later.
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
// Save c_i index for restoring later.
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
dim_t jj, ii;
for ( jj = 0; jj < 16; jj += 4 ) // LOOP_6x16J
{
@@ -884,14 +887,14 @@ POST_OPS_6x16F_DISABLE:
} // END LOOP_3x4I
post_ops_attr.post_op_c_j += 4;
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_i = post_op_c_i_save;
} // END LOOP_6x16J
// Reset the value of post_op_c_j to point to the beginning.
post_ops_attr.post_op_c_j = post_op_c_j_save;
// Update the post_op_c_i value to account for the number of rows.
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
// -------------------------------------------------------------------------
consider_edge_cases:
@@ -979,6 +982,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x8m_rd)
// Save c_j index for restoring later.
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
// Save c_i index for restoring later.
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
dim_t jj, ii;
for ( jj = 0; jj < 8; jj += 4 ) // LOOP_6x8J
{
@@ -1689,14 +1695,14 @@ POST_OPS_6x8F_DISABLE:
} // END LOOP_3x4I
post_ops_attr.post_op_c_j += 4;
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_i = post_op_c_i_save;
} // END LOOP_6x8J
// Reset the value of post_op_c_j to point to the beginning.
post_ops_attr.post_op_c_j = post_op_c_j_save;
// Update the post_op_c_i value to account for the number of rows.
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
consider_edge_cases:
@@ -1783,6 +1789,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x4m_rd)
// Save c_j index for restoring later.
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
// Save c_i index for restoring later.
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
dim_t jj, ii;
for ( jj = 0; jj < 4; jj += 4 ) // LOOP_6x4J
{
@@ -2493,14 +2502,14 @@ POST_OPS_6x4F_DISABLE:
} // END LOOP_3x4I
post_ops_attr.post_op_c_j += 4;
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_i = post_op_c_i_save;
} // END LOOP_6x4J
// Reset the value of post_op_c_j to point to the beginning.
post_ops_attr.post_op_c_j = post_op_c_j_save;
// Update the post_op_c_i value to account for the number of rows.
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
consider_edge_cases:
@@ -2589,6 +2598,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x2m_rd)
// Save c_j index for restoring later.
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
// Save c_i index for restoring later.
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
dim_t jj, ii;
for ( jj = 0; jj < 4; jj += 4 ) // LOOP_6x2J
{
@@ -3262,14 +3274,14 @@ POST_OPS_6x2F_DISABLE:
} // END LOOP_3x2I
post_ops_attr.post_op_c_j += 4;
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_i = post_op_c_i_save;
} // END LOOP_6x2J
// Reset the value of post_op_c_j to point to the beginning.
post_ops_attr.post_op_c_j = post_op_c_j_save;
// Update the post_op_c_i value to account for the number of rows.
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
consider_edge_cases:
@@ -3357,6 +3369,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x1m_rd)
// Save c_j index for restoring later.
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
// Save c_i index for restoring later.
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
dim_t jj, ii;
for ( jj = 0; jj < 4; jj += 4 ) // LOOP_6x1J
{
@@ -4011,14 +4026,14 @@ POST_OPS_6x1F_DISABLE:
} // END LOOP_3x1I
post_ops_attr.post_op_c_j += 4;
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_i = post_op_c_i_save;
} // END LOOP_6x1J
// Reset the value of post_op_c_j to point to the beginning.
post_ops_attr.post_op_c_j = post_op_c_j_save;
// Update the post_op_c_i value to account for the number of rows.
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
consider_edge_cases:

View File

@@ -227,6 +227,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m_rd)
// Save c_j index for restoring later.
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
// Save c_i index for restoring later.
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
dim_t jj, ii;
for ( jj = 0; jj < 64; jj += 4 ) // LOOP_6x64J
{
@@ -1526,14 +1529,14 @@ POST_OPS_6x64F_DISABLE:
} // END LOOP_6x4I
post_ops_attr.post_op_c_j += 4;
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_i = post_op_c_i_save;
} // END LOOP_6x64J
// Reset the value of post_op_c_j to point to the beginning.
post_ops_attr.post_op_c_j = post_op_c_j_save;
// Update the post_op_c_i value to account for the number of rows.
post_ops_attr.post_op_c_i = MR * m_iter;
post_ops_attr.post_op_c_i = post_op_c_i_save + MR * m_iter;
consider_edge_cases:
@@ -1643,6 +1646,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x48m_rd)
// Save c_j index for restoring later.
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
// Save c_i index for restoring later.
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
dim_t jj, ii;
for ( jj = 0; jj < 48; jj += 4 ) // LOOP_6x48J
{
@@ -2945,14 +2951,14 @@ POST_OPS_6x48F_DISABLE:
} // END LOOP_6x4I
post_ops_attr.post_op_c_j += 4;
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_i = post_op_c_i_save;
} // END LOOP_6x48J
// Reset the value of post_op_c_j to point to the beginning.
post_ops_attr.post_op_c_j = post_op_c_j_save;
// Update the post_op_c_i value to account for the number of rows.
post_ops_attr.post_op_c_i = MR * m_iter;
post_ops_attr.post_op_c_i = post_op_c_i_save + MR * m_iter;
consider_edge_cases:
@@ -3062,6 +3068,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x32m_rd)
// Save c_j index for restoring later.
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
// Save c_i index for restoring later.
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
dim_t jj, ii;
for ( jj = 0; jj < 32; jj += 4 ) // LOOP_6x32J
{
@@ -4361,14 +4370,14 @@ POST_OPS_6x32F_DISABLE:
} // END LOOP_6x4I
post_ops_attr.post_op_c_j += 4;
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_i = post_op_c_i_save;
} // END LOOP_6x32J
// Reset the value of post_op_c_j to point to the beginning.
post_ops_attr.post_op_c_j = post_op_c_j_save;
// Update the post_op_c_i value to account for the number of rows.
post_ops_attr.post_op_c_i = MR * m_iter;
post_ops_attr.post_op_c_i = post_op_c_i_save + MR * m_iter;
consider_edge_cases: