Bug fix in gemv_n kernel of f32 api.

Description:
1. For the column-major case when m=1, there was an accuracy mismatch with
   post-ops (bias, matrix_add, matrix_mul).
2. Added a check for the column-major case and replaced _mm512_loadu_ps with
   _mm512_maskz_loadu_ps.

AMD-Internal: [CPUPL-6585]

Change-Id: I8d98e2cb0b9dd445c9868f4c8af3abbc6c2dfc95
This commit is contained in:
Deepak Negi
2025-03-12 04:27:56 +05:30
parent 9f263d2445
commit fb4617d7c3

View File

@@ -216,7 +216,8 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
zmm8 = _mm512_insertf32x4(zmm8, xmm1, 1);
zmm8 = _mm512_insertf32x4(zmm8, xmm2, 2);
zmm8 = _mm512_insertf32x4(zmm8, xmm3, 3);
}else
}
else
{
//Handle fringe cases when mr0 < MR
const float *a_use_fringe = a_use;
@@ -389,7 +390,8 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
if (rs_c == 1)
{
zmm0 = _mm512_maskz_loadu_ps(k2, _cbuf);
}else
}
else
{
//load C into zmm0
float ctemp[16];
@@ -409,13 +411,23 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
POST_OPS_BIAS_6x64F:
{
if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
zmm9 = _mm512_set1_ps(*( ( float * )post_ops_list_temp->op_args1 ) );
}
else
{
// If original output was columns major, then by the time
// kernel sees it, the matrix would be accessed as if it were
// transposed. Due to this the bias array will be accessed by
// the ic index, and each bias element corresponds to an
// entire row of the transposed output array, instead of an
// entire column.
zmm9 = _mm512_set1_ps(*((float *)post_ops_list_temp->op_args1));
zmm9 = _mm512_maskz_loadu_ps( k2,
( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i );
}
zmm8 = _mm512_add_ps(zmm9, zmm8);
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
@@ -500,13 +512,15 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
// instead of an entire column.
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_i + 0 );
selector1 = _mm512_maskz_loadu_ps( k2,
( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_i );
}
if( *( dim_t*)post_ops_list_temp->op_args3 > 1 )
{
zero_point0 = _mm512_loadu_ps( (float *)post_ops_list_temp->op_args1 +
post_op_attr->post_op_c_i + 0 );
zero_point0 = _mm512_maskz_loadu_ps( k2,
( float * )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i );
}
F32_SCL_MULRND(zmm8, selector1, zero_point0);
}
@@ -544,7 +558,8 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
if( ldm == 1 )
{
selector1 = _mm512_maskz_loadu_ps(k2, (matptr + post_ops_attr.post_op_c_i));
selector1 = _mm512_maskz_loadu_ps(k2, ( matptr +
post_ops_attr.post_op_c_i ) );
selector1 = _mm512_mul_ps( selector1, scl_fctr1 );
zmm8 = _mm512_add_ps( selector1, zmm8 );
@@ -595,7 +610,8 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
if( ldm == 1 )
{
selector1 = _mm512_maskz_loadu_ps(k2, (matptr + post_ops_attr.post_op_c_i ) );
selector1 = _mm512_maskz_loadu_ps(k2, ( matptr +
post_ops_attr.post_op_c_i ) );
selector1 = _mm512_mul_ps( selector1, scl_fctr1 );
zmm8 = _mm512_mul_ps( selector1, zmm8 );
}