mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Bugfix: Disable A Packing for FP32 RD kernels and Post-Ops Fix
- For the single-threaded configuration of BLIS, packing of the A and B matrices is enabled by default. However, packing of A is only supported by the RV kernels, where elements of matrix A are broadcast. The RD kernels load elements of A instead of broadcasting them, so packing A results in failures. Hence, packing of matrix A is disabled for the RD kernels. - Fixed an issue where the c_i index pointer was incorrectly being reset when exceeding the MC block, resulting in failures for certain Post-Ops. - Fixed the FP32 reorder case where, for the n == 1 and rs_b == 1 condition, it was incorrectly using sizeof(BLIS_FLOAT) instead of sizeof(float). AMD-Internal: [SWLCSG-3497] Change-Id: I6d18afa996c253d79f666ea9789270bb59b629dd
This commit is contained in:
@@ -179,7 +179,7 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
|
||||
{
|
||||
if(rs_b == 1)
|
||||
{
|
||||
memcpy(reorder_buf_addr, input_buf_addr, (k * sizeof(BLIS_FLOAT)));
|
||||
memcpy(reorder_buf_addr, input_buf_addr, (k * sizeof(float)));
|
||||
}else
|
||||
{
|
||||
for(dim_t k0 = 0; k0 < k; k0++)
|
||||
|
||||
@@ -423,10 +423,11 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
|
||||
// Avoid packing of B in transb cases where rd kernels performs
|
||||
// better than rv + pack. rv kernel calls rd when rs_b==1.
|
||||
if( ( rs_b == 1 ) && ( mtag_b == PACK ) && ( mtag_a == UNPACKED ) )
|
||||
if( ( n < 48 ) && ( rs_b == 1 ) && ( mtag_b == PACK ) &&
|
||||
( mtag_a == UNPACKED ) )
|
||||
{
|
||||
if ( n < 48 ) mtag_b = UNPACKED;
|
||||
else if ( m < 25 ) mtag_b = UNPACKED;
|
||||
mtag_b = UNPACKED;
|
||||
should_pack_A = FALSE;
|
||||
}
|
||||
|
||||
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
|
||||
|
||||
@@ -173,6 +173,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m_rd)
|
||||
// Save c_j index for restoring later.
|
||||
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
|
||||
|
||||
// Save c_i index for restoring later.
|
||||
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
|
||||
|
||||
dim_t jj, ii;
|
||||
for ( jj = 0; jj < 16; jj += 4 ) // LOOP_6x16J
|
||||
{
|
||||
@@ -884,14 +887,14 @@ POST_OPS_6x16F_DISABLE:
|
||||
} // END LOOP_3x4I
|
||||
|
||||
post_ops_attr.post_op_c_j += 4;
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save;
|
||||
} // END LOOP_6x16J
|
||||
|
||||
// Reset the value of post_op_c_j to point to the beginning.
|
||||
post_ops_attr.post_op_c_j = post_op_c_j_save;
|
||||
|
||||
// Update the post_op_c_i value to account for the number of rows.
|
||||
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
consider_edge_cases:
|
||||
@@ -979,6 +982,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x8m_rd)
|
||||
// Save c_j index for restoring later.
|
||||
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
|
||||
|
||||
// Save c_i index for restoring later.
|
||||
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
|
||||
|
||||
dim_t jj, ii;
|
||||
for ( jj = 0; jj < 8; jj += 4 ) // LOOP_6x8J
|
||||
{
|
||||
@@ -1689,14 +1695,14 @@ POST_OPS_6x8F_DISABLE:
|
||||
} // END LOOP_3x4I
|
||||
|
||||
post_ops_attr.post_op_c_j += 4;
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save;
|
||||
} // END LOOP_6x8J
|
||||
|
||||
// Reset the value of post_op_c_j to point to the beginning.
|
||||
post_ops_attr.post_op_c_j = post_op_c_j_save;
|
||||
|
||||
// Update the post_op_c_i value to account for the number of rows.
|
||||
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
@@ -1783,6 +1789,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x4m_rd)
|
||||
// Save c_j index for restoring later.
|
||||
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
|
||||
|
||||
// Save c_i index for restoring later.
|
||||
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
|
||||
|
||||
dim_t jj, ii;
|
||||
for ( jj = 0; jj < 4; jj += 4 ) // LOOP_6x4J
|
||||
{
|
||||
@@ -2493,14 +2502,14 @@ POST_OPS_6x4F_DISABLE:
|
||||
} // END LOOP_3x4I
|
||||
|
||||
post_ops_attr.post_op_c_j += 4;
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save;
|
||||
} // END LOOP_6x4J
|
||||
|
||||
// Reset the value of post_op_c_j to point to the beginning.
|
||||
post_ops_attr.post_op_c_j = post_op_c_j_save;
|
||||
|
||||
// Update the post_op_c_i value to account for the number of rows.
|
||||
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
@@ -2589,6 +2598,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x2m_rd)
|
||||
// Save c_j index for restoring later.
|
||||
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
|
||||
|
||||
// Save c_i index for restoring later.
|
||||
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
|
||||
|
||||
dim_t jj, ii;
|
||||
for ( jj = 0; jj < 4; jj += 4 ) // LOOP_6x2J
|
||||
{
|
||||
@@ -3262,14 +3274,14 @@ POST_OPS_6x2F_DISABLE:
|
||||
} // END LOOP_3x2I
|
||||
|
||||
post_ops_attr.post_op_c_j += 4;
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save;
|
||||
} // END LOOP_6x2J
|
||||
|
||||
// Reset the value of post_op_c_j to point to the beginning.
|
||||
post_ops_attr.post_op_c_j = post_op_c_j_save;
|
||||
|
||||
// Update the post_op_c_i value to account for the number of rows.
|
||||
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
@@ -3357,6 +3369,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x1m_rd)
|
||||
// Save c_j index for restoring later.
|
||||
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
|
||||
|
||||
// Save c_i index for restoring later.
|
||||
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
|
||||
|
||||
dim_t jj, ii;
|
||||
for ( jj = 0; jj < 4; jj += 4 ) // LOOP_6x1J
|
||||
{
|
||||
@@ -4011,14 +4026,14 @@ POST_OPS_6x1F_DISABLE:
|
||||
} // END LOOP_3x1I
|
||||
|
||||
post_ops_attr.post_op_c_j += 4;
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save;
|
||||
} // END LOOP_6x1J
|
||||
|
||||
// Reset the value of post_op_c_j to point to the beginning.
|
||||
post_ops_attr.post_op_c_j = post_op_c_j_save;
|
||||
|
||||
// Update the post_op_c_i value to account for the number of rows.
|
||||
post_ops_attr.post_op_c_i = 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save + 3 * m_iter; // Since each iteration processes 3 rows.
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
|
||||
@@ -227,6 +227,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m_rd)
|
||||
// Save c_j index for restoring later.
|
||||
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
|
||||
|
||||
// Save c_i index for restoring later.
|
||||
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
|
||||
|
||||
dim_t jj, ii;
|
||||
for ( jj = 0; jj < 64; jj += 4 ) // LOOP_6x64J
|
||||
{
|
||||
@@ -1526,14 +1529,14 @@ POST_OPS_6x64F_DISABLE:
|
||||
} // END LOOP_6x4I
|
||||
|
||||
post_ops_attr.post_op_c_j += 4;
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save;
|
||||
} // END LOOP_6x64J
|
||||
|
||||
// Reset the value of post_op_c_j to point to the beginning.
|
||||
post_ops_attr.post_op_c_j = post_op_c_j_save;
|
||||
|
||||
// Update the post_op_c_i value to account for the number of rows.
|
||||
post_ops_attr.post_op_c_i = MR * m_iter;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save + MR * m_iter;
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
@@ -1643,6 +1646,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x48m_rd)
|
||||
// Save c_j index for restoring later.
|
||||
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
|
||||
|
||||
// Save c_i index for restoring later.
|
||||
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
|
||||
|
||||
dim_t jj, ii;
|
||||
for ( jj = 0; jj < 48; jj += 4 ) // LOOP_6x48J
|
||||
{
|
||||
@@ -2945,14 +2951,14 @@ POST_OPS_6x48F_DISABLE:
|
||||
} // END LOOP_6x4I
|
||||
|
||||
post_ops_attr.post_op_c_j += 4;
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save;
|
||||
} // END LOOP_6x48J
|
||||
|
||||
// Reset the value of post_op_c_j to point to the beginning.
|
||||
post_ops_attr.post_op_c_j = post_op_c_j_save;
|
||||
|
||||
// Update the post_op_c_i value to account for the number of rows.
|
||||
post_ops_attr.post_op_c_i = MR * m_iter;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save + MR * m_iter;
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
@@ -3062,6 +3068,9 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x32m_rd)
|
||||
// Save c_j index for restoring later.
|
||||
uint64_t post_op_c_j_save = post_ops_attr.post_op_c_j;
|
||||
|
||||
// Save c_i index for restoring later.
|
||||
uint64_t post_op_c_i_save = post_ops_attr.post_op_c_i;
|
||||
|
||||
dim_t jj, ii;
|
||||
for ( jj = 0; jj < 32; jj += 4 ) // LOOP_6x32J
|
||||
{
|
||||
@@ -4361,14 +4370,14 @@ POST_OPS_6x32F_DISABLE:
|
||||
} // END LOOP_6x4I
|
||||
|
||||
post_ops_attr.post_op_c_j += 4;
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save;
|
||||
} // END LOOP_6x32J
|
||||
|
||||
// Reset the value of post_op_c_j to point to the beginning.
|
||||
post_ops_attr.post_op_c_j = post_op_c_j_save;
|
||||
|
||||
// Update the post_op_c_i value to account for the number of rows.
|
||||
post_ops_attr.post_op_c_i = MR * m_iter;
|
||||
post_ops_attr.post_op_c_i = post_op_c_i_save + MR * m_iter;
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user