mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Fix for F32 to BF16 Conversion and AVX512 ISA Support Checks
- Fixed a register-assignment bug in lpgemv_m_kernel_f32_avx512 where zmm3 was incorrectly used instead of zmm4 in the BF16_F32_BETA_OP_NLT16F_MASK macro.
- Replaced hardware-specific BF16 conversion intrinsics with manual rounding, bit manipulation, and F32 instructions for compatibility on hardware without native BF16 support.
- Added AVX512_BF16 ISA support checks for the s8s8s32obf16 and u8s8s32obf16 GEMM operations to ensure processor compatibility before execution.

AMD-Internal: [CPUPL-7410]
This commit is contained in:
@@ -65,7 +65,15 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform s8s8s32 gemm.", __FILE__, __LINE__ );
|
||||
"cannot perform s8s8s32obf16 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Check for avx512_bf16 ISA support necessary for BF16.
|
||||
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform s8s8s32obf16 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
|
||||
@@ -67,6 +67,15 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16)
|
||||
"cannot perform u8s8s32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Check for avx512_bf16 ISA support necessary for BF16.
|
||||
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform u8s8s32obf16 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
bli_print_msg("cannot perform u8s8s32obf16 gemm with gcc < 11.2",
|
||||
__FILE__, __LINE__ );
|
||||
|
||||
@@ -258,10 +258,10 @@ LPGEMV_M_EQ1_KERN( float, float, float, f32f32f32of32 )
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) )
|
||||
{
|
||||
BF16_F32_BETA_OP_NLT16F_MASK(k1, zmm8, 0, 0, zmm0,zmm3);
|
||||
BF16_F32_BETA_OP_NLT16F_MASK(k2, zmm12, 0, 1, zmm1,zmm3);
|
||||
BF16_F32_BETA_OP_NLT16F_MASK(k3, zmm16, 0, 2, zmm2,zmm3);
|
||||
BF16_F32_BETA_OP_NLT16F_MASK(k4, zmm20, 0, 3, zmm3,zmm3);
|
||||
BF16_F32_BETA_OP_NLT16F_MASK(k1, zmm8, 0, 0, zmm0, zmm3);
|
||||
BF16_F32_BETA_OP_NLT16F_MASK(k2, zmm12, 0, 1, zmm1, zmm3);
|
||||
BF16_F32_BETA_OP_NLT16F_MASK(k3, zmm16, 0, 2, zmm2, zmm3);
|
||||
BF16_F32_BETA_OP_NLT16F_MASK(k4, zmm20, 0, 3, zmm4, zmm3);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -835,33 +835,33 @@ LPGEMV_M_EQ1_KERN( float, float, float, f32f32f32of32 )
|
||||
{
|
||||
if ( post_ops_attr.buf_downscale != NULL )
|
||||
{
|
||||
_mm256_mask_storeu_epi16
|
||||
(
|
||||
( bfloat16* )post_ops_attr.buf_downscale +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ),
|
||||
k1, (__m256i) _mm512_cvtneps_pbh( zmm8 )
|
||||
);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i, chunk;
|
||||
bfloat16* dest;
|
||||
|
||||
_mm256_mask_storeu_epi16
|
||||
(
|
||||
( bfloat16* )post_ops_attr.buf_downscale +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ),
|
||||
k2, (__m256i) _mm512_cvtneps_pbh( zmm12 )
|
||||
);
|
||||
dim_t full_iters = nr0 / 16;
|
||||
dim_t partial_iters = nr0 % 16;
|
||||
|
||||
_mm256_mask_storeu_epi16
|
||||
(
|
||||
( bfloat16* )post_ops_attr.buf_downscale +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ),
|
||||
k3, (__m256i) _mm512_cvtneps_pbh( zmm16 )
|
||||
);
|
||||
// masks and zmm_regs respective to each chunk.
|
||||
__mmask16 masks[4] = {k1, k2, k3, k4};
|
||||
__m512 zmm_regs[4] = {zmm8, zmm12, zmm16, zmm20};
|
||||
|
||||
for (chunk = 0; chunk < 4; ++chunk) {
|
||||
dim_t chunk_size = (chunk < full_iters) ? 16 :
|
||||
(chunk == full_iters) ? partial_iters : 0;
|
||||
|
||||
if (chunk_size == 0) break;
|
||||
|
||||
_mm256_mask_storeu_epi16
|
||||
(
|
||||
( bfloat16* )post_ops_attr.buf_downscale +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ),
|
||||
k4, (__m256i) _mm512_cvtneps_pbh( zmm20 )
|
||||
);
|
||||
_mm512_mask_storeu_ps((float*)temp, masks[chunk], zmm_regs[chunk]);
|
||||
dest = (bfloat16*)post_ops_attr.buf_downscale +
|
||||
post_ops_attr.post_op_c_j + (chunk * 16);
|
||||
|
||||
for (i = 0; i < chunk_size; ++i) {
|
||||
tlsb = (temp[i] & (uint32_t)0x00010000) > 16;
|
||||
rounded = temp[i] + (uint32_t)0x00007FFF + tlsb;
|
||||
memcpy((dest+i), ((char*)(&rounded))+2, sizeof(bfloat16));
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -786,12 +786,19 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_first_k == TRUE ) )
|
||||
{
|
||||
_mm256_mask_storeu_epi16
|
||||
(
|
||||
( bfloat16* )post_ops_attr.buf_downscale +
|
||||
post_ops_attr.post_op_c_i,
|
||||
k2, (__m256i) _mm512_cvtneps_pbh( zmm8 )
|
||||
);
|
||||
// Convert F32 to BF16 and store directly into memory using memcpy.
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
_mm512_mask_store_ps((float*)temp, k2, zmm8);
|
||||
dest = ( bfloat16* )post_ops_attr.buf_downscale +
|
||||
post_ops_attr.post_op_c_i;
|
||||
for (i = 0; i < mr0; ++i) {
|
||||
tlsb = ( temp[i] & ( uint32_t )0x00010000 ) > 16;
|
||||
rounded = temp[i] + ( uint32_t )0x00007FFF + tlsb;
|
||||
memcpy( (dest+i), ((char *)(&rounded))+2, sizeof(bfloat16));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -800,23 +807,27 @@ LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32 )
|
||||
}
|
||||
else
|
||||
{
|
||||
// Store ZMM8 into ctemp buffer and store back
|
||||
// element by element into output buffer at strides
|
||||
|
||||
if ( post_ops_attr.buf_downscale != NULL )
|
||||
{
|
||||
bfloat16 ctemp[16];
|
||||
_mm256_mask_storeu_epi16( ctemp, k2, ( __m256i )
|
||||
_mm512_cvtneps_pbh( zmm8 ) );
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
{
|
||||
*( ( bfloat16* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
// Convert F32 to BF16 and store directly into memory using memcpy.
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
|
||||
_mm512_mask_storeu_ps((float*)temp, k2, zmm8);
|
||||
for (i = 0; i < mr0; ++i) {
|
||||
tlsb = ( temp[i] & ( uint32_t )0x00010000 ) > 16;
|
||||
rounded = temp[i] + ( uint32_t )0x00007FFF + tlsb;
|
||||
memcpy( ( ( bfloat16* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ),
|
||||
((char *)(&rounded)) + 2,
|
||||
sizeof(bfloat16));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Store ZMM8 into ctemp buffer and store back
|
||||
// element by element into output buffer at strides
|
||||
float ctemp[16];
|
||||
_mm512_mask_storeu_ps(ctemp, k2, zmm8);
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
|
||||
Reference in New Issue
Block a user