mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Bug Fixes for GEMV AVX2 BF16 to F32 path
- Added the correct strides to be used while unreordering/converting the B matrix in m=1 cases.
- Modified zero-point vector loads to use the proper instructions.
- Modified the bf16 store in the AVX2 GEMV M kernel.
AMD Internal - [SWLCSG - 3602]
This commit is contained in:
@@ -1004,6 +1004,15 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
dim_t jc_cur_loop_rem = 0;
|
||||
dim_t n_sub_updated = 0;
|
||||
|
||||
dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR );
|
||||
mem_b_size_req = sizeof( float ) * nc0_updated * k_updated;
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
|
||||
&mem_b, rntm
|
||||
);
|
||||
|
||||
if (mtag_b == REORDERED)
|
||||
{
|
||||
get_B_panel_reordered_start_offset_width
|
||||
@@ -1021,17 +1030,6 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
dim_t kc0_updated = kc0;
|
||||
kc0_updated += ( kc0_updated & 0x1 );
|
||||
|
||||
dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR );
|
||||
mem_b_size_req = sizeof( float ) * nc0_updated * k_updated;
|
||||
|
||||
n_sub_updated = nc0_updated;
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
|
||||
&mem_b, rntm
|
||||
);
|
||||
|
||||
if( mtag_b == REORDERED )
|
||||
{
|
||||
float *b_unreorder =
|
||||
@@ -1043,8 +1041,8 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
unpackb_nr64_bf16_f32
|
||||
(
|
||||
( b + ( jc_cur_loop * k_updated ) +
|
||||
( jc_cur_loop_rem * kc0_updated) ),
|
||||
( b_unreorder ),
|
||||
( jc_cur_loop_rem * kc0_updated) + ( n_sub_updated * pc )),
|
||||
( b_unreorder + (nc0 * pc)),
|
||||
kc0, nc0 ,
|
||||
rs_b_use, cs_b_use, FALSE
|
||||
);
|
||||
@@ -1059,7 +1057,7 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
|
||||
cvt_bf16_f32
|
||||
(
|
||||
( cvt_b_buffer_bf16_f32 ),
|
||||
( cvt_b_buffer_bf16_f32 + (nc0 * pc) ),
|
||||
( b + ( rs_b * pc ) + ( cs_b * jc ) ), rs_b, cs_b,
|
||||
kc0, nc0,
|
||||
rs_b_use, cs_b_use
|
||||
|
||||
@@ -233,9 +233,42 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
); \
|
||||
ymm2 = _mm256_fmadd_ps(ymm0, beta, ymm2); \
|
||||
|
||||
/* Six-argument form of BF16_F32_C_BNZ_GEMV_MASK: narrows `mask` into
   `load_mask` with _mm256_cvtepi32_epi16 and forwards all arguments to
   BF16_F32_C_BNZ_8_MASK. `load_mask` is only assigned here, so it must
   already be declared at the expansion site.
   NOTE(review): a five-argument macro with the same name follows directly
   after this one; this looks like the pre-change side of the hunk - confirm
   the two are not meant to coexist.
   NOTE(review): _mm256_cvtepi32_epi16 is documented as an AVX-512 (F + VL)
   intrinsic - confirm it is acceptable on an AVX2-only code path. */
#define BF16_F32_C_BNZ_GEMV_MASK(m_ind,n_ind,ymm0,beta,ymm2,mask) \
|
||||
load_mask = _mm256_cvtepi32_epi16(mask); \
|
||||
BF16_F32_C_BNZ_8_MASK(m_ind,n_ind,ymm0,beta,ymm2,load_mask) \
|
||||
/*
 * BF16_F32_C_BNZ_GEMV_MASK(n_ind, ymm0, beta, ymm2, n_elems)
 *
 * Beta scaling for a partial (at most 8 element) slice of a bf16 C row:
 *     ymm2 = ymm0 * beta + ymm2
 * where ymm0 receives the slice widened to f32.
 *
 *   n_ind   - index of the 8-element group; the slice starts at
 *             buf_downscale + post_op_c_j + n_ind * 8 (in bf16 elements).
 *   ymm0    - __m256 lvalue, overwritten with the widened C values.
 *   beta    - __m256 multiplier.
 *   ymm2    - __m256 accumulator, updated in place.
 *   n_elems - number of bf16 elements to read. Must not exceed 8, the size
 *             of the staging array; it is evaluated exactly once.
 *
 * Only n_elems elements are read from memory. They are copied into a
 * zero-initialised 8-slot staging array, so lanes at or beyond n_elems end
 * up as 0.0f and no bytes past the valid tail of C are touched.
 *
 * bf16 -> f32 widening: each 16-bit value is extended to 32 bits and shifted
 * left by 16, which places the bf16 bit pattern in the upper half of the
 * f32 lane and clears the lower half.
 *
 * Requires `post_ops_attr` at the expansion site. Expands to a brace block
 * (not do/while), because existing call sites invoke it without a trailing
 * semicolon. Local names carry a trailing underscore to avoid capturing
 * caller arguments.
 */
#define BF16_F32_C_BNZ_GEMV_MASK(n_ind,ymm0,beta,ymm2,n_elems) \
{ \
	int16_t bnz_stage_[8] = {0}; \
	const bfloat16 *bnz_src_ = \
		( ( const bfloat16* )post_ops_attr.buf_downscale ) + \
		( post_ops_attr.post_op_c_j + ( ( n_ind ) * 8 ) ); \
	const dim_t bnz_cnt_ = ( n_elems ); \
	for ( dim_t bnz_i_ = 0; bnz_i_ < bnz_cnt_; bnz_i_++ ) \
	{ \
		bnz_stage_[bnz_i_] = bnz_src_[bnz_i_]; \
	} \
	ymm0 = ( __m256 )_mm256_sllv_epi32 \
	( \
	  _mm256_cvtepi16_epi32 \
	  ( \
	    _mm_loadu_si128( ( __m128i const* )( bnz_stage_ ) ) \
	  ), _mm256_set1_epi32( 16 ) \
	); \
	ymm2 = _mm256_fmadd_ps( ymm0, beta, ymm2 ); \
}
|
||||
|
||||
/*
 * BF16_F32_BIAS_AVX2_GEMV_MASK(n_ind, ymm0, n_elems)
 *
 * Loads a partial (at most 8 element) slice of a bf16 bias vector and
 * widens it to f32 in ymm0.
 *
 *   n_ind   - index of the 8-element group; the slice starts at
 *             op_args1 + post_op_c_j + n_ind * 8 (in bf16 elements).
 *   ymm0    - __m256 lvalue, overwritten with the widened bias values.
 *   n_elems - number of bf16 elements to read. Must not exceed 8, the size
 *             of the staging array; it is evaluated exactly once.
 *
 * The bias is read from post_ops_list_temp->op_args1, the same buffer the
 * other bf16 bias loaders in this header use. It must NOT be read from
 * post_ops_attr.buf_downscale: that is the bf16 C/output buffer, which is
 * unrelated to the bias and is tested against NULL elsewhere in the kernel.
 *
 * Only n_elems elements are read from memory. They are copied into a
 * zero-initialised 8-slot staging array, so lanes at or beyond n_elems end
 * up as 0.0f and no bytes past the valid tail of the bias are touched.
 *
 * Requires `post_ops_list_temp` and `post_ops_attr` at the expansion site.
 * Expands to a brace block (not do/while), because existing call sites
 * invoke it without a trailing semicolon. Local names carry a trailing
 * underscore to avoid capturing caller arguments.
 */
#define BF16_F32_BIAS_AVX2_GEMV_MASK(n_ind,ymm0,n_elems) \
{ \
	int16_t bias_stage_[8] = {0}; \
	const bfloat16 *bias_src_ = \
		( ( const bfloat16* )post_ops_list_temp->op_args1 ) + \
		( post_ops_attr.post_op_c_j + ( ( n_ind ) * 8 ) ); \
	const dim_t bias_cnt_ = ( n_elems ); \
	for ( dim_t bias_i_ = 0; bias_i_ < bias_cnt_; bias_i_++ ) \
	{ \
		bias_stage_[bias_i_] = bias_src_[bias_i_]; \
	} \
	ymm0 = ( __m256 )_mm256_sllv_epi32 \
	( \
	  _mm256_cvtepi16_epi32 \
	  ( \
	    _mm_loadu_si128( ( __m128i const* )( bias_stage_ ) ) \
	  ), _mm256_set1_epi32( 16 ) \
	); \
}
|
||||
|
||||
/*Load C from buf_downscale and convert to F32,
|
||||
multiply with Beta, and add to alpha*A*B*/
|
||||
@@ -681,7 +714,7 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
( \
|
||||
_mm256_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm_load_si128 \
|
||||
_mm_loadu_si128 \
|
||||
( \
|
||||
( __m128i const* )( \
|
||||
( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
|
||||
@@ -761,6 +794,20 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
); \
|
||||
}
|
||||
|
||||
/*
 * BF16_F32_BIAS_LOAD_AVX2_GEMV(scr, n_ind)
 *
 * Loads 8 consecutive bf16 values starting at
 *     op_args1 + post_op_c_i + n_ind * 8   (in bf16 elements)
 * and widens them to f32 in `scr` (extend to 32 bits, shift left by 16).
 *
 *   scr   - __m256 lvalue, overwritten.
 *   n_ind - index of the 8-element group.
 *
 * The load is unaligned (_mm_loadu_si128) and unmasked: 16 bytes are always
 * read, whatever the number of valid elements.
 * NOTE(review): confirm the buffer behind op_args1 is readable for a full
 * 8-element group at every call site that handles a tail.
 *
 * Requires `post_ops_list_temp` and `post_ops_attr` at the expansion site.
 * Expands to a complete statement including its semicolon, so it works both
 * with and without a trailing ';' at the call site.
 */
#define BF16_F32_BIAS_LOAD_AVX2_GEMV(scr,n_ind) \
	scr = ( __m256 )( _mm256_sllv_epi32 \
	( \
	  _mm256_cvtepi16_epi32 \
	  ( \
	    _mm_loadu_si128 \
	    ( \
	      ( __m128i const* )( \
	        ( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
	        post_ops_attr.post_op_c_i + ( ( n_ind ) * 8 ) ) \
	    ) \
	  ), _mm256_set1_epi32( 16 ) \
	) );
|
||||
|
||||
#define BF16_F32_BIAS_BCAST_LT4BF16_AVX2(scr,m_ind) \
|
||||
{ \
|
||||
scr = (__m128)_mm_sllv_epi32 \
|
||||
@@ -776,6 +823,19 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
); \
|
||||
}
|
||||
|
||||
/*
 * BF16_F32_BIAS_BCAST_AVX2_GEMV(scr)
 *
 * Reads the single bf16 value that post_ops_list_temp->op_args1 points to,
 * replicates it, and widens it to f32 (extend to 32 bits, shift left by 16)
 * so that all eight lanes of `scr` hold the same value.
 *
 * Requires `post_ops_list_temp` at the expansion site. Expands to one
 * complete assignment statement, semicolon included.
 */
#define BF16_F32_BIAS_BCAST_AVX2_GEMV(scr) \
	scr = ( __m256 )( _mm256_sllv_epi32( \
	          _mm256_cvtepi16_epi32( \
	              _mm_set1_epi16( *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ), \
	          _mm256_set1_epi32( 16 ) ) );
|
||||
|
||||
#define STORE_F32_BF16_N_ONE_YMM( temp, m_ele ) \
|
||||
{ \
|
||||
dest = ( bfloat16* )post_ops_attr.buf_downscale + \
|
||||
@@ -871,17 +931,7 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
|
||||
/*Downscale Zeropoint BF16->F32 Helpers*/
|
||||
/* Zero-point scalar broadcast: the bf16 value at post_ops_list_temp->op_args1
   is replicated and widened to f32 in every lane of `scr`.
   This span shows two bodies back to back: an open-coded intrinsic chain
   (set1_epi16 -> cvtepi16_epi32 -> shift left by 16) and a one-line forward
   to BF16_F32_BIAS_BCAST_AVX2_GEMV( scr ), which performs the same chain.
   NOTE(review): the hunk header (-871,17 +931,7) suggests the open-coded
   chain is the removed side and the forward is its replacement - confirm. */
#define BF16_F32_ZP_SCALAR_BCAST_AVX2(scr) \
|
||||
scr = (__m256)( _mm256_sllv_epi32 \
|
||||
( \
|
||||
_mm256_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm_set1_epi16 \
|
||||
( \
|
||||
*( ( bfloat16* )post_ops_list_temp->op_args1 ) \
|
||||
) \
|
||||
), _mm256_set1_epi32( 16 ) \
|
||||
) \
|
||||
);
|
||||
BF16_F32_BIAS_BCAST_AVX2_GEMV( scr )
|
||||
|
||||
/* Alias for the zero-point path: expands to
   BF16_F32_BIAS_BCAST_AVX2(scr,m_ind) followed by a semicolon, so the
   zero-point code reuses the bias broadcast helper. */
#define BF16_F32_ZP_VECTOR_BCAST_AVX2(scr, m_ind) \
|
||||
BF16_F32_BIAS_BCAST_AVX2(scr,m_ind);
|
||||
@@ -892,20 +942,8 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
/* Alias for the zero-point path: expands to
   BF16_F32_BIAS_LOAD_AVX2_MASK(scr,n_ind,mask) with no trailing semicolon,
   so the zero-point code reuses the masked bias load helper. */
#define BF16_F32_ZP_VECTOR_LOAD_AVX2_MASK(scr,n_ind,mask) \
|
||||
BF16_F32_BIAS_LOAD_AVX2_MASK(scr,n_ind,mask)
|
||||
|
||||
/* Masked bf16 vector load widened to f32: reads from
   op_args1 + post_op_c_i + n_ind * 8 (in bf16 elements) with
   _mm_maskload_epi32 under `mask`, then extends each 16-bit value to 32 bits
   and shifts left by 16 into `scr`.
   NOTE(review): _mm_maskload_epi32 masks at 32-bit granularity, i.e. per
   pair of bf16 values, so `mask` cannot select an odd number of bf16
   elements - confirm against callers.
   NOTE(review): the hunk header (-892,20 +942,8) suggests this macro is the
   removed side of the change - confirm. */
#define BF16_F32_BIAS_LOAD_AVX2_MASK_GEMV(scr,n_ind,mask) \
|
||||
scr = (__m256)( _mm256_sllv_epi32 \
|
||||
( \
|
||||
_mm256_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm_maskload_epi32 \
|
||||
( \
|
||||
( int const* )( \
|
||||
( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_i + ( n_ind * 8 ) ) \
|
||||
, mask ) \
|
||||
), _mm256_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
/* Alias for the GEMV zero-point path: expands to
   BF16_F32_BIAS_LOAD_AVX2_GEMV(scr,n_ind), so the zero-point vector is
   loaded with the same unmasked helper as the bias vector. */
#define BF16_F32_ZP_VECTOR_LOAD_AVX2_GEMV(scr,n_ind) \
|
||||
BF16_F32_BIAS_LOAD_AVX2_GEMV(scr,n_ind) \
|
||||
|
||||
#define BF16_F32_ZP_SCALAR_BCAST_SSE(scr) \
|
||||
scr = (__m128)_mm_sllv_epi32 \
|
||||
|
||||
@@ -111,15 +111,22 @@ void lpgemv_m_one_f32f32f32of32_avx2_LT16
|
||||
ZERO_ACC_YMM_4_REG(ymm12, ymm13, ymm14, ymm15);
|
||||
|
||||
dim_t n_left = n0 % 8;
|
||||
// n1, n2 holds the n_elems values.
|
||||
dim_t n1 = 0, n2 = 0;
|
||||
|
||||
if (nr0 < 8)
|
||||
{
|
||||
k1 = masks[n_left];
|
||||
n1 = n_left;
|
||||
k2 =masks[0];
|
||||
n2 = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
k1 = masks[8];
|
||||
n1 = 8;
|
||||
k2 = masks[n_left];
|
||||
n2 = n_left;
|
||||
}
|
||||
|
||||
_mm_prefetch((c_use + 0 * rs_c), _MM_HINT_T0);
|
||||
@@ -239,8 +246,8 @@ void lpgemv_m_one_f32f32f32of32_avx2_LT16
|
||||
ymm3 = _mm256_set1_ps( beta );
|
||||
if( post_ops_attr.buf_downscale != NULL )
|
||||
{
|
||||
BF16_F32_C_BNZ_8(0,0,ymm0, ymm3,ymm8)
|
||||
BF16_F32_C_BNZ_8(0,1,ymm1, ymm3,ymm12)
|
||||
BF16_F32_C_BNZ_GEMV_MASK(0, ymm0, ymm3, ymm8, n1)
|
||||
BF16_F32_C_BNZ_GEMV_MASK(1, ymm1, ymm3, ymm12, n2)
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -264,8 +271,10 @@ void lpgemv_m_one_f32f32f32of32_avx2_LT16
|
||||
{
|
||||
if( post_ops_list_temp->stor_type == BF16 )
|
||||
{
|
||||
BF16_F32_BIAS_LOAD_AVX2( ymm0, 0 );
|
||||
BF16_F32_BIAS_LOAD_AVX2( ymm1, 1 );
|
||||
BF16_F32_BIAS_AVX2_GEMV_MASK(0, ymm0, n1 )
|
||||
BF16_F32_BIAS_AVX2_GEMV_MASK(1, ymm1, n2 )
|
||||
// BF16_F32_BIAS_LOAD_AVX2( ymm0, 0 );
|
||||
// BF16_F32_BIAS_LOAD_AVX2( ymm1, 1 );
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -369,11 +369,8 @@ POST_OPS_BIAS_1x16F:
|
||||
{
|
||||
if( post_ops_list_temp->stor_type == BF16 )
|
||||
{
|
||||
ymm0 = (__m256)( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
|
||||
_mm_set1_epi16(
|
||||
*( ( bfloat16* )post_ops_list_temp->op_args1 )
|
||||
) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
BF16_F32_BIAS_BCAST_AVX2_GEMV( ymm0)
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -390,13 +387,7 @@ POST_OPS_BIAS_1x16F:
|
||||
// entire column.
|
||||
if( post_ops_list_temp->stor_type == BF16 )
|
||||
{
|
||||
__m128i bias_mask = _mm_loadu_si128((__m128i*)mask[mr0]);
|
||||
ymm0 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
|
||||
_mm_maskload_epi32(
|
||||
( int const* )( ( ( bfloat16* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i )
|
||||
, bias_mask ) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
BF16_F32_BIAS_LOAD_AVX2_GEMV( ymm0, 0 );
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -500,13 +491,7 @@ POST_OPS_DOWNSCALE_1x16F:
|
||||
{
|
||||
if( is_bf16 == TRUE )
|
||||
{
|
||||
__m128i zp_mask = _mm_loadu_si128((__m128i*)mask[mr0]);
|
||||
zero_point0 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
|
||||
_mm_maskload_epi32(
|
||||
( int const* )( ( ( bfloat16* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i )
|
||||
, zp_mask ) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
BF16_F32_ZP_VECTOR_LOAD_AVX2_GEMV( ymm0, 0 )
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -422,17 +422,17 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32_avx512_256 )
|
||||
if ( post_ops_attr.buf_downscale != NULL )
|
||||
{
|
||||
ymm0 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
|
||||
_mm_loadu_si128(
|
||||
( __m128i const* )( ( ( bfloat16* )post_ops_attr.buf_downscale ) +
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + 0 ) )
|
||||
+ post_ops_attr.post_op_c_j + (0 * 8) ) ) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
_mm_loadu_si128( ( __m128i const* )
|
||||
( ( ( bfloat16* )post_ops_attr.buf_downscale ) +
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i ) )
|
||||
+ (0 * 8) ) ) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
ymm1 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
|
||||
_mm_loadu_si128(
|
||||
( __m128i const* )( ( ( bfloat16* )post_ops_attr.buf_downscale ) +
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + 0 ) )
|
||||
+ post_ops_attr.post_op_c_j + (1 * 8) ) ) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
_mm_loadu_si128( ( __m128i const* )
|
||||
( ( ( bfloat16* )post_ops_attr.buf_downscale ) +
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i ) )
|
||||
+ (1 * 8) ) ) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -492,16 +492,8 @@ POST_OPS_BIAS_1x32F:
|
||||
{
|
||||
if( post_ops_list_temp->stor_type == BF16 )
|
||||
{
|
||||
ymm0 = (__m256)( _mm256_sllv_epi32( _mm256_cvtepi16_epi32 \
|
||||
( _mm_set1_epi16 \
|
||||
( *( ( bfloat16* )post_ops_list_temp->op_args1 )
|
||||
) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
ymm1 = (__m256)( _mm256_sllv_epi32( _mm256_cvtepi16_epi32
|
||||
( _mm_set1_epi16
|
||||
( *( ( bfloat16* )post_ops_list_temp->op_args1 )
|
||||
) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
BF16_F32_BIAS_BCAST_AVX2_GEMV( ymm0)
|
||||
BF16_F32_BIAS_BCAST_AVX2_GEMV( ymm1)
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -519,20 +511,8 @@ POST_OPS_BIAS_1x32F:
|
||||
// entire column.
|
||||
if( post_ops_list_temp->stor_type == BF16 )
|
||||
{
|
||||
__m128i bias_mask1 = _mm256_cvtepi32_epi16(store_mask1);
|
||||
__m128i bias_mask2 = _mm256_cvtepi32_epi16(store_mask2);
|
||||
ymm0 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
|
||||
_mm_maskload_epi32(
|
||||
( int const* )( ( ( bfloat16* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i )
|
||||
, bias_mask1 ) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
ymm1 = ( __m256 )( _mm256_sllv_epi32( _mm256_cvtepi16_epi32(
|
||||
_mm_maskload_epi32(
|
||||
( int const* )( ( ( bfloat16* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 8)
|
||||
, bias_mask2 ) ), _mm256_set1_epi32( 16 ) )
|
||||
);
|
||||
BF16_F32_BIAS_LOAD_AVX2_GEMV( ymm0, 0 )
|
||||
BF16_F32_BIAS_LOAD_AVX2_GEMV( ymm1, 1 )
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -632,15 +612,6 @@ POST_OPS_DOWNSCALE_1x32F:
|
||||
zero_point1 = _mm256_set1_ps( *(float *)post_ops_list_temp->op_args1 );
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// If original output was columns major, then by the time
|
||||
// kernel sees it, the matrix would be accessed as if it were
|
||||
// transposed. Due to this the scale as well as zp array will
|
||||
// be accessed by the ic index, and each scale/zp element
|
||||
// corresponds to an entire row of the transposed output array,
|
||||
// instead of an entire column.
|
||||
}
|
||||
if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
|
||||
{
|
||||
@@ -665,12 +636,10 @@ POST_OPS_DOWNSCALE_1x32F:
|
||||
}
|
||||
if( *( dim_t*)post_ops_list_temp->op_args3 > 1 )
|
||||
{
|
||||
__m128i zp_mask1 = _mm256_cvtepi32_epi16(store_mask1);
|
||||
__m128i zp_mask2 = _mm256_cvtepi32_epi16(store_mask2);
|
||||
if ( is_bf16 == TRUE )
|
||||
{
|
||||
BF16_F32_BIAS_LOAD_AVX2_MASK_GEMV(zero_point0,0,zp_mask1)
|
||||
BF16_F32_BIAS_LOAD_AVX2_MASK_GEMV(zero_point1,1,zp_mask2)
|
||||
BF16_F32_ZP_VECTOR_LOAD_AVX2_GEMV(zero_point0, 0)
|
||||
BF16_F32_ZP_VECTOR_LOAD_AVX2_GEMV(zero_point1, 1)
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -975,7 +944,6 @@ POST_OPS_1x32F_DISABLE:
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
if( rs_c == 1 )
|
||||
{
|
||||
_mm256_maskstore_ps((float*)temp, store_mask1, ymm30);
|
||||
@@ -986,7 +954,7 @@ POST_OPS_1x32F_DISABLE:
|
||||
else
|
||||
{
|
||||
_mm256_storeu_ps((float*)temp, ymm30);
|
||||
_mm256_storeu_ps((float*)(temp+8), ymm31);
|
||||
_mm256_storeu_ps((float*)(temp + 8), ymm31);
|
||||
|
||||
STORE_F32_BF16_N_ONE_YMM( temp, mr0 )
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user