Bug fixes for the GEMV AVX2 BF16-to-F32 path

- Added the correct strides to be used while unreordering/converting the B matrix in m=1 cases.
- Modified zero-point vector loads to use the proper load instructions.
- Modified the bf16 store in the AVX2 GEMV M kernel.
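
All three fixes lean on the same bit-level trick the macros in this patch use: a bf16 value is simply the upper 16 bits of an f32, so widening eight bf16 lanes to 32 bits and shifting left by 16 reconstructs the f32 values. A minimal standalone sketch of that pattern (editorial illustration with assumed names, not code from the patch):

#include <immintrin.h>
#include <stdint.h>

typedef uint16_t bfloat16;

/* Reconstruct eight f32 values from eight contiguous bf16 values. */
static inline __m256 bf16x8_to_f32( const bfloat16* p )
{
    __m128i raw  = _mm_loadu_si128( ( const __m128i* )p ); /* 8 x bf16 */
    __m256i wide = _mm256_cvtepi16_epi32( raw );           /* widen to 32-bit lanes */
    /* Shift left by 16 so the bf16 bits land in the f32 high half. */
    return _mm256_castsi256_ps( _mm256_sllv_epi32( wide, _mm256_set1_epi32( 16 ) ) );
}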

AMD Internal - [SWLCSG-3602]
Author: V, Varsha
Date: 2025-07-10 16:23:46 +05:30
Committed by: GitHub
Parent: ab4bb2f1e8
Commit: 837d3974d4
5 changed files with 113 additions and 115 deletions

View File

@@ -1004,6 +1004,15 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
dim_t jc_cur_loop_rem = 0;
dim_t n_sub_updated = 0;
dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR );
mem_b_size_req = sizeof( float ) * nc0_updated * k_updated;
lpgemm_alloc_mem_panel
(
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
&mem_b, rntm
);
if (mtag_b == REORDERED)
{
get_B_panel_reordered_start_offset_width
@@ -1021,17 +1030,6 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
dim_t kc0_updated = kc0;
kc0_updated += ( kc0_updated & 0x1 );
dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR );
mem_b_size_req = sizeof( float ) * nc0_updated * k_updated;
n_sub_updated = nc0_updated;
lpgemm_alloc_mem_panel
(
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
&mem_b, rntm
);
if( mtag_b == REORDERED )
{
float *b_unreorder =
@@ -1043,8 +1041,8 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
unpackb_nr64_bf16_f32
(
( b + ( jc_cur_loop * k_updated ) +
( jc_cur_loop_rem * kc0_updated) ),
( b_unreorder ),
( jc_cur_loop_rem * kc0_updated) + ( n_sub_updated * pc )),
( b_unreorder + (nc0 * pc)),
kc0, nc0,
rs_b_use, cs_b_use, FALSE
);
@@ -1059,7 +1057,7 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
cvt_bf16_f32
(
( cvt_b_buffer_bf16_f32 ),
( cvt_b_buffer_bf16_f32 + (nc0 * pc) ),
( b + ( rs_b * pc ) + ( cs_b * jc ) ), rs_b, cs_b,
kc0, nc0,
rs_b_use, cs_b_use

View File

@@ -233,9 +233,42 @@ multiply with Beta, and add to alpha*A*B*/
); \
ymm2 = _mm256_fmadd_ps(ymm0, beta, ymm2); \
#define BF16_F32_C_BNZ_GEMV_MASK(m_ind,n_ind,ymm0,beta,ymm2,mask) \
load_mask = _mm256_cvtepi32_epi16(mask); \
BF16_F32_C_BNZ_8_MASK(m_ind,n_ind,ymm0,beta,ymm2,load_mask) \
#define BF16_F32_C_BNZ_GEMV_MASK(n_ind,ymm0,beta,ymm2,n_elems) \
{\
int16_t data_feeder[8] = {0}; \
bfloat16 *post_op_ptr = ( bfloat16* )( post_ops_attr.buf_downscale) + \
( post_ops_attr.post_op_c_j + ( n_ind * 8 ) ); \
for(dim_t i = 0; i < n_elems; i++) data_feeder[i] = *(post_op_ptr + i); \
ymm0 = (__m256)_mm256_sllv_epi32 \
( \
_mm256_cvtepi16_epi32 \
( \
_mm_loadu_si128 \
( \
( __m128i const* )( data_feeder ) \
) \
), _mm256_set1_epi32( 16 ) \
); \
ymm2 = _mm256_fmadd_ps(ymm0, beta, ymm2); \
}\
#define BF16_F32_BIAS_AVX2_GEMV_MASK(n_ind,ymm0,n_elems) \
{\
int16_t data_feeder[8] = {0}; \
bfloat16 *post_op_ptr = ( bfloat16* )( post_ops_attr.buf_downscale) + \
( post_ops_attr.post_op_c_j + ( n_ind * 8 ) ); \
for(dim_t i = 0; i < n_elems; i++) data_feeder[i] = *(post_op_ptr + i); \
ymm0 = (__m256)_mm256_sllv_epi32 \
( \
_mm256_cvtepi16_epi32 \
( \
_mm_loadu_si128 \
( \
( __m128i const* )( data_feeder ) \
) \
), _mm256_set1_epi32( 16 ) \
); \
}\
/*Load C from buf_downscale and convert to F32,
multiply with Beta, and add to alpha*A*B*/
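
The masked GEMV variants added above avoid out-of-bounds reads on partial rows by staging the valid elements through a zero-initialized data_feeder buffer and then issuing a full vector load against that buffer. A self-contained restatement of the pattern (illustrative names, not the patch itself):

#include <immintrin.h>
#include <stdint.h>

typedef uint16_t bfloat16;

/* Widen n_elems (<= 8) bf16 values to f32 without reading past the
   end of the source row; tail lanes stay zero. */
static inline __m256 bf16_tail_to_f32( const bfloat16* src, int n_elems )
{
    int16_t data_feeder[8] = { 0 };
    for ( int i = 0; i < n_elems; i++ ) data_feeder[i] = ( int16_t )src[i];
    __m256i wide = _mm256_cvtepi16_epi32(
                       _mm_loadu_si128( ( const __m128i* )data_feeder ) );
    return _mm256_castsi256_ps( _mm256_sllv_epi32( wide, _mm256_set1_epi32( 16 ) ) );
}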
@@ -681,7 +714,7 @@ multiply with Beta, and add to alpha*A*B*/
( \
_mm256_cvtepi16_epi32 \
( \
_mm_load_si128 \
_mm_loadu_si128 \
( \
( __m128i const* )( \
( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
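
The load-to-loadu swap above is the alignment fix: the source is a bfloat16* advanced by a runtime offset, so the address is only guaranteed 2-byte alignment, and _mm_load_si128 faults on anything not 16-byte aligned. A sketch of the safe form (hypothetical wrapper, not from the patch):

#include <immintrin.h>
#include <stdint.h>

typedef uint16_t bfloat16;

static inline __m128i load_8_bf16( const bfloat16* base, int64_t offset )
{
    /* base + offset can sit on any 2-byte boundary, so only the
       unaligned load is safe here. */
    return _mm_loadu_si128( ( const __m128i* )( base + offset ) );
}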
@@ -761,6 +794,20 @@ multiply with Beta, and add to alpha*A*B*/
); \
}
#define BF16_F32_BIAS_LOAD_AVX2_GEMV(scr,n_ind) \
scr = (__m256)( _mm256_sllv_epi32 \
( \
_mm256_cvtepi16_epi32 \
( \
_mm_loadu_si128( \
( __m128i const* )( \
( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
post_ops_attr.post_op_c_i + ( n_ind * 8 ) ) \
) \
), _mm256_set1_epi32( 16 ) \
) \
); \
#define BF16_F32_BIAS_BCAST_LT4BF16_AVX2(scr,m_ind) \
{ \
scr = (__m128)_mm_sllv_epi32 \
@@ -776,6 +823,19 @@ multiply with Beta, and add to alpha*A*B*/
); \
}
#define BF16_F32_BIAS_BCAST_AVX2_GEMV(scr) \
scr = (__m256)( _mm256_sllv_epi32 \
( \
_mm256_cvtepi16_epi32 \
( \
_mm_set1_epi16 \
( \
*( ( bfloat16* )post_ops_list_temp->op_args1 ) \
) \
), _mm256_set1_epi32( 16 ) \
) \
);
#define STORE_F32_BF16_N_ONE_YMM( temp, m_ele ) \
{ \
dest = ( bfloat16* )post_ops_attr.buf_downscale + \
@@ -871,17 +931,7 @@ multiply with Beta, and add to alpha*A*B*/
/*Downscale Zeropoint BF16->F32 Helpers*/
#define BF16_F32_ZP_SCALAR_BCAST_AVX2(scr) \
scr = (__m256)( _mm256_sllv_epi32 \
( \
_mm256_cvtepi16_epi32 \
( \
_mm_set1_epi16 \
( \
*( ( bfloat16* )post_ops_list_temp->op_args1 ) \
) \
), _mm256_set1_epi32( 16 ) \
) \
);
BF16_F32_BIAS_BCAST_AVX2_GEMV( scr )
#define BF16_F32_ZP_VECTOR_BCAST_AVX2(scr, m_ind) \
BF16_F32_BIAS_BCAST_AVX2(scr,m_ind);
@@ -892,20 +942,8 @@ multiply with Beta, and add to alpha*A*B*/
#define BF16_F32_ZP_VECTOR_LOAD_AVX2_MASK(scr,n_ind,mask) \
BF16_F32_BIAS_LOAD_AVX2_MASK(scr,n_ind,mask)
#define BF16_F32_BIAS_LOAD_AVX2_MASK_GEMV(scr,n_ind,mask) \
scr = (__m256)( _mm256_sllv_epi32 \
( \
_mm256_cvtepi16_epi32 \
( \
_mm_maskload_epi32 \
( \
( int const* )( \
( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
post_ops_attr.post_op_c_i + ( n_ind * 8 ) ) \
, mask ) \
), _mm256_set1_epi32( 16 ) \
) \
); \
#define BF16_F32_ZP_VECTOR_LOAD_AVX2_GEMV(scr,n_ind) \
BF16_F32_BIAS_LOAD_AVX2_GEMV(scr,n_ind) \
#define BF16_F32_ZP_SCALAR_BCAST_SSE(scr) \
scr = (__m128)_mm_sllv_epi32 \

View File

@@ -111,15 +111,22 @@ void lpgemv_m_one_f32f32f32of32_avx2_LT16
ZERO_ACC_YMM_4_REG(ymm12, ymm13, ymm14, ymm15);
dim_t n_left = n0 % 8;
// n1 and n2 hold the n_elems values for the two 8-lane groups.
dim_t n1 = 0, n2 = 0;
if (nr0 < 8)
{
k1 = masks[n_left];
n1 = n_left;
k2 = masks[0];
n2 = 0;
}
else
{
k1 = masks[8];
n1 = 8;
k2 = masks[n_left];
n2 = n_left;
}
_mm_prefetch((c_use + 0 * rs_c), _MM_HINT_T0);
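
The (k1, n1)/(k2, n2) selection above splits the at-most-16 remaining columns into two 8-lane groups, pairing each mask with its element count. A compact restatement of the rule (hypothetical helper; the masks table is omitted):

#include <stdint.h>

/* nr0: columns handled this iteration; n_left = n0 % 8. When nr0 < 8
   the whole tail fits in group 1 and group 2 is empty; otherwise
   group 1 is full and the n_left tail falls into group 2. */
static inline void pick_tail_lanes( int64_t nr0, int64_t n_left,
                                    int64_t* n1, int64_t* n2 )
{
    if ( nr0 < 8 ) { *n1 = n_left; *n2 = 0;      }
    else           { *n1 = 8;      *n2 = n_left; }
}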
@@ -239,8 +246,8 @@ void lpgemv_m_one_f32f32f32of32_avx2_LT16
ymm3 = _mm256_set1_ps( beta );
if( post_ops_attr.buf_downscale != NULL )
{
BF16_F32_C_BNZ_8(0,0,ymm0, ymm3,ymm8)
BF16_F32_C_BNZ_8(0,1,ymm1, ymm3,ymm12)
BF16_F32_C_BNZ_GEMV_MASK(0, ymm0, ymm3, ymm8, n1)
BF16_F32_C_BNZ_GEMV_MASK(1, ymm1, ymm3, ymm12, n2)
}
else
{
@@ -264,8 +271,10 @@ void lpgemv_m_one_f32f32f32of32_avx2_LT16
{
if( post_ops_list_temp->stor_type == BF16 )
{
BF16_F32_BIAS_LOAD_AVX2( ymm0, 0 );
BF16_F32_BIAS_LOAD_AVX2( ymm1, 1 );
BF16_F32_BIAS_AVX2_GEMV_MASK(0, ymm0, n1 )
BF16_F32_BIAS_AVX2_GEMV_MASK(1, ymm1, n2 )
// BF16_F32_BIAS_LOAD_AVX2( ymm0, 0 );
// BF16_F32_BIAS_LOAD_AVX2( ymm1, 1 );
}
else
{

View File

@@ -369,11 +369,8 @@ POST_OPS_BIAS_1x16F:
{
if( post_ops_list_temp->stor_type == BF16 )
{
ymm0 = (__m256)( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
_mm_set1_epi16(
*( ( bfloat16* )post_ops_list_temp->op_args1 )
) ), _mm256_set1_epi32( 16 ) )
);
BF16_F32_BIAS_BCAST_AVX2_GEMV( ymm0)
}
else
{
@@ -390,13 +387,7 @@ POST_OPS_BIAS_1x16F:
// entire column.
if( post_ops_list_temp->stor_type == BF16 )
{
__m128i bias_mask = _mm_loadu_si128((__m128i*)mask[mr0]);
ymm0 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
_mm_maskload_epi32(
( int const* )( ( ( bfloat16* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i )
, bias_mask ) ), _mm256_set1_epi32( 16 ) )
);
BF16_F32_BIAS_LOAD_AVX2_GEMV( ymm0, 0 );
}
else
{
@@ -500,13 +491,7 @@ POST_OPS_DOWNSCALE_1x16F:
{
if( is_bf16 == TRUE )
{
__m128i zp_mask = _mm_loadu_si128((__m128i*)mask[mr0]);
zero_point0 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
_mm_maskload_epi32(
( int const* )( ( ( bfloat16* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i )
, zp_mask ) ), _mm256_set1_epi32( 16 ) )
);
BF16_F32_ZP_VECTOR_LOAD_AVX2_GEMV( ymm0, 0 )
}
else
{

View File

@@ -422,17 +422,17 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32_avx512_256 )
if ( post_ops_attr.buf_downscale != NULL )
{
ymm0 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
_mm_loadu_si128(
( __m128i const* )( ( ( bfloat16* )post_ops_attr.buf_downscale ) +
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + 0 ) )
+ post_ops_attr.post_op_c_j + (0 * 8) ) ) ), _mm256_set1_epi32( 16 ) )
);
_mm_loadu_si128( ( __m128i const* )
( ( ( bfloat16* )post_ops_attr.buf_downscale ) +
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i ) )
+ (0 * 8) ) ) ), _mm256_set1_epi32( 16 ) )
);
ymm1 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
_mm_loadu_si128(
( __m128i const* )( ( ( bfloat16* )post_ops_attr.buf_downscale ) +
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + 0 ) )
+ post_ops_attr.post_op_c_j + (1 * 8) ) ) ), _mm256_set1_epi32( 16 ) )
);
_mm_loadu_si128( ( __m128i const* )
( ( ( bfloat16* )post_ops_attr.buf_downscale ) +
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i ) )
+ (1 * 8) ) ) ), _mm256_set1_epi32( 16 ) )
);
}
else
{
@@ -492,16 +492,8 @@ POST_OPS_BIAS_1x32F:
{
if( post_ops_list_temp->stor_type == BF16 )
{
ymm0 = (__m256)( _mm256_sllv_epi32( _mm256_cvtepi16_epi32 \
( _mm_set1_epi16 \
( *( ( bfloat16* )post_ops_list_temp->op_args1 )
) ), _mm256_set1_epi32( 16 ) )
);
ymm1 = (__m256)( _mm256_sllv_epi32( _mm256_cvtepi16_epi32
( _mm_set1_epi16
( *( ( bfloat16* )post_ops_list_temp->op_args1 )
) ), _mm256_set1_epi32( 16 ) )
);
BF16_F32_BIAS_BCAST_AVX2_GEMV( ymm0)
BF16_F32_BIAS_BCAST_AVX2_GEMV( ymm1)
}
else
{
@@ -519,20 +511,8 @@ POST_OPS_BIAS_1x32F:
// entire column.
if( post_ops_list_temp->stor_type == BF16 )
{
__m128i bias_mask1 = _mm256_cvtepi32_epi16(store_mask1);
__m128i bias_mask2 = _mm256_cvtepi32_epi16(store_mask2);
ymm0 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
_mm_maskload_epi32(
( int const* )( ( ( bfloat16* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i )
, bias_mask1 ) ), _mm256_set1_epi32( 16 ) )
);
ymm1 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
_mm_maskload_epi32(
( int const* )( ( ( bfloat16* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 8)
, bias_mask2 ) ), _mm256_set1_epi32( 16 ) )
);
BF16_F32_BIAS_LOAD_AVX2_GEMV( ymm0, 0 )
BF16_F32_BIAS_LOAD_AVX2_GEMV( ymm1, 1 )
}
else
{
@@ -632,15 +612,6 @@ POST_OPS_DOWNSCALE_1x32F:
zero_point1 = _mm256_set1_ps( *(float *)post_ops_list_temp->op_args1 );
}
}
else
{
// If original output was columns major, then by the time
// kernel sees it, the matrix would be accessed as if it were
// transposed. Due to this the scale as well as zp array will
// be accessed by the ic index, and each scale/zp element
// corresponds to an entire row of the transposed output array,
// instead of an entire column.
}
if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
@@ -665,12 +636,10 @@ POST_OPS_DOWNSCALE_1x32F:
}
if( *( dim_t*)post_ops_list_temp->op_args3 > 1 )
{
__m128i zp_mask1 = _mm256_cvtepi32_epi16(store_mask1);
__m128i zp_mask2 = _mm256_cvtepi32_epi16(store_mask2);
if ( is_bf16 == TRUE )
{
BF16_F32_BIAS_LOAD_AVX2_MASK_GEMV(zero_point0,0,zp_mask1)
BF16_F32_BIAS_LOAD_AVX2_MASK_GEMV(zero_point1,1,zp_mask2)
BF16_F32_ZP_VECTOR_LOAD_AVX2_GEMV(zero_point0, 0)
BF16_F32_ZP_VECTOR_LOAD_AVX2_GEMV(zero_point1, 1)
}
else
{
@@ -975,7 +944,6 @@ POST_OPS_1x32F_DISABLE:
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
if( rs_c == 1 )
{
_mm256_maskstore_ps((float*)temp, store_mask1, ymm30);
@@ -986,7 +954,7 @@ POST_OPS_1x32F_DISABLE:
else
{
_mm256_storeu_ps((float*)temp, ymm30);
_mm256_storeu_ps((float*)(temp+8), ymm31);
_mm256_storeu_ps((float*)(temp + 8), ymm31);
STORE_F32_BF16_N_ONE_YMM( temp, mr0 )
}
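
The tlsb/rounded/temp locals in the disable block above suggest a round-to-nearest-even f32-to-bf16 conversion behind STORE_F32_BF16_N_ONE_YMM. A plausible scalar equivalent, offered as an assumption rather than the macro's actual body (NaN handling omitted):

#include <stdint.h>
#include <string.h>

typedef uint16_t bfloat16;

static bfloat16 f32_to_bf16_rne( float f )
{
    uint32_t bits;
    memcpy( &bits, &f, sizeof( bits ) );      /* f32 bit pattern           */
    uint32_t tlsb    = ( bits >> 16 ) & 1u;   /* LSB of the truncated bf16 */
    uint32_t rounded = bits + 0x7FFFu + tlsb; /* ties round to even        */
    return ( bfloat16 )( rounded >> 16 );     /* keep the upper 16 bits    */
}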