Fix for F32 to BF16 Conversion and AVX512 ISA Support Checks

- Fixed a register assignment bug in lpgemv_m_kernel_f32_avx512 where zmm3
  was incorrectly used instead of zmm4 in the BF16_F32_BETA_OP_NLT16F_MASK macro.

- Replaced hardware-specific BF16 conversion intrinsics with manual
  rounding and bit manipulation using F32 instructions, for compatibility
  with hardware that lacks native BF16 support.

- Added AVX512_BF16 ISA support checks for s8s8s32obf16 and u8s8s32obf16
  GEMM operations to ensure processor compatibility before execution.

AMD-Internal: [CPUPL-7410]
This commit is contained in:
Sharma, Arnav
2025-09-19 18:49:33 +05:30
committed by GitHub
parent ec5cf7d174
commit ee3d250b7a
4 changed files with 74 additions and 46 deletions

View File

@@ -65,7 +65,15 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform s8s8s32 gemm.", __FILE__, __LINE__ );
"cannot perform s8s8s32obf16 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
// Check for avx512_bf16 ISA support necessary for BF16.
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform s8s8s32obf16 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}

View File

@@ -67,6 +67,15 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16)
"cannot perform u8s8s32 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
// Check for avx512_bf16 ISA support necessary for BF16.
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform u8s8s32obf16 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
#ifdef LPGEMM_BF16_JIT
bli_print_msg("cannot perform u8s8s32obf16 gemm with gcc < 11.2",
__FILE__, __LINE__ );

View File

@@ -258,10 +258,10 @@ LPGEMV_M_EQ1_KERN( float, float, float, f32f32f32of32 )
if ( ( post_ops_attr.buf_downscale != NULL ) )
{
BF16_F32_BETA_OP_NLT16F_MASK(k1, zmm8, 0, 0, zmm0,zmm3);
BF16_F32_BETA_OP_NLT16F_MASK(k2, zmm12, 0, 1, zmm1,zmm3);
BF16_F32_BETA_OP_NLT16F_MASK(k3, zmm16, 0, 2, zmm2,zmm3);
BF16_F32_BETA_OP_NLT16F_MASK(k4, zmm20, 0, 3, zmm3,zmm3);
BF16_F32_BETA_OP_NLT16F_MASK(k1, zmm8, 0, 0, zmm0, zmm3);
BF16_F32_BETA_OP_NLT16F_MASK(k2, zmm12, 0, 1, zmm1, zmm3);
BF16_F32_BETA_OP_NLT16F_MASK(k3, zmm16, 0, 2, zmm2, zmm3);
BF16_F32_BETA_OP_NLT16F_MASK(k4, zmm20, 0, 3, zmm4, zmm3);
}
else
{
@@ -835,33 +835,33 @@ LPGEMV_M_EQ1_KERN( float, float, float, f32f32f32of32 )
{
if ( post_ops_attr.buf_downscale != NULL )
{
_mm256_mask_storeu_epi16
(
( bfloat16* )post_ops_attr.buf_downscale +
post_ops_attr.post_op_c_j + ( 0 * 16 ),
k1, (__m256i) _mm512_cvtneps_pbh( zmm8 )
);
uint32_t tlsb, rounded, temp[16] = {0};
int i, chunk;
bfloat16* dest;
_mm256_mask_storeu_epi16
(
( bfloat16* )post_ops_attr.buf_downscale +
post_ops_attr.post_op_c_j + ( 1 * 16 ),
k2, (__m256i) _mm512_cvtneps_pbh( zmm12 )
);
dim_t full_iters = nr0 / 16;
dim_t partial_iters = nr0 % 16;
_mm256_mask_storeu_epi16
(
( bfloat16* )post_ops_attr.buf_downscale +
post_ops_attr.post_op_c_j + ( 2 * 16 ),
k3, (__m256i) _mm512_cvtneps_pbh( zmm16 )
);
// masks and zmm_regs respective to each chunk.
__mmask16 masks[4] = {k1, k2, k3, k4};
__m512 zmm_regs[4] = {zmm8, zmm12, zmm16, zmm20};
for (chunk = 0; chunk < 4; ++chunk) {
dim_t chunk_size = (chunk < full_iters) ? 16 :
(chunk == full_iters) ? partial_iters : 0;
if (chunk_size == 0) break;
_mm256_mask_storeu_epi16
(
( bfloat16* )post_ops_attr.buf_downscale +
post_ops_attr.post_op_c_j + ( 3 * 16 ),
k4, (__m256i) _mm512_cvtneps_pbh( zmm20 )
);
_mm512_mask_storeu_ps((float*)temp, masks[chunk], zmm_regs[chunk]);
dest = (bfloat16*)post_ops_attr.buf_downscale +
post_ops_attr.post_op_c_j + (chunk * 16);
for (i = 0; i < chunk_size; ++i) {
tlsb = (temp[i] & (uint32_t)0x00010000) > 16;
rounded = temp[i] + (uint32_t)0x00007FFF + tlsb;
memcpy((dest+i), ((char*)(&rounded))+2, sizeof(bfloat16));
}
}
}
else
{

View File

@@ -786,12 +786,19 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_first_k == TRUE ) )
{
_mm256_mask_storeu_epi16
(
( bfloat16* )post_ops_attr.buf_downscale +
post_ops_attr.post_op_c_i,
k2, (__m256i) _mm512_cvtneps_pbh( zmm8 )
);
// Convert F32 to BF16 and store directly into memory using memcpy.
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
_mm512_mask_store_ps((float*)temp, k2, zmm8);
dest = ( bfloat16* )post_ops_attr.buf_downscale +
post_ops_attr.post_op_c_i;
for (i = 0; i < mr0; ++i) {
tlsb = ( temp[i] & ( uint32_t )0x00010000 ) > 16;
rounded = temp[i] + ( uint32_t )0x00007FFF + tlsb;
memcpy( (dest+i), ((char *)(&rounded))+2, sizeof(bfloat16));
}
}
else
{
@@ -800,23 +807,27 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
}
else
{
// Store ZMM8 into ctemp buffer and store back
// element by element into output buffer at strides
if ( post_ops_attr.buf_downscale != NULL )
{
bfloat16 ctemp[16];
_mm256_mask_storeu_epi16( ctemp, k2, ( __m256i )
_mm512_cvtneps_pbh( zmm8 ) );
for (dim_t i = 0; i < mr0; i++)
{
*( ( bfloat16* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
// Convert F32 to BF16 and store directly into memory using memcpy.
uint32_t tlsb, rounded, temp[16] = {0};
int i;
_mm512_mask_storeu_ps((float*)temp, k2, zmm8);
for (i = 0; i < mr0; ++i) {
tlsb = ( temp[i] & ( uint32_t )0x00010000 ) > 16;
rounded = temp[i] + ( uint32_t )0x00007FFF + tlsb;
memcpy( ( ( bfloat16* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ),
((char *)(&rounded)) + 2,
sizeof(bfloat16));
}
}
else
{
// Store ZMM8 into ctemp buffer and store back
// element by element into output buffer at strides
float ctemp[16];
_mm512_mask_storeu_ps(ctemp, k2, zmm8);
for (dim_t i = 0; i < mr0; i++)