mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Fixed compiler errors and warning for gcc < 11.2
Description: 1. When compiler gcc version less than 11.2 few BF16 instructions are not supported by the compiler even though the processors arch's zen4 and zen5 supports. 2. These instructions are guarded now with a macro. Change-Id: Ib07d41ff73d8fe14937af411843286c0e80c4131
This commit is contained in:
committed by
Nallani Bhaskar
parent
d61c54dc26
commit
17634d7ae8
@@ -169,6 +169,13 @@ AOCL_UTIL_ELTWISE_OPS(bfloat16,bfloat16,bf16obf16)
|
||||
b, ldb
|
||||
);
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
bli_print_msg("cannot perform the operation with gcc < 11.2",
|
||||
__FILE__, __LINE__ );
|
||||
return;
|
||||
#endif
|
||||
|
||||
|
||||
// Even though b matrix is typecasted to float*, actual load/store
|
||||
// and matrix traversal will happen as bfloat16* type. This typecast
|
||||
// is only to ensure code is reused.
|
||||
@@ -310,6 +317,13 @@ AOCL_UTIL_ELTWISE_OPS(float,bfloat16,f32obf16)
|
||||
b, ldb
|
||||
);
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
bli_print_msg("cannot perform the operation with gcc < 11.2",
|
||||
__FILE__, __LINE__ );
|
||||
return;
|
||||
#endif
|
||||
|
||||
|
||||
aocl_eltwise_ops_f32of32_base
|
||||
(
|
||||
order, transa, transb,
|
||||
|
||||
@@ -69,6 +69,13 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
bli_print_msg("cannot perform s8s8s32obf16 gemm with gcc < 11.2",
|
||||
__FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
#endif
|
||||
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
@@ -107,7 +114,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
|
||||
// The strides are set assuming a row major kernel.
|
||||
inc_t rs_a = lda;
|
||||
inc_t cs_a = 1;
|
||||
@@ -138,16 +145,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
|
||||
// Reorder is not supported for A matrix
|
||||
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in "
|
||||
bli_print_msg(" Reordering of A matrix is not supported in "
|
||||
" row major case.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
else if ((is_column_major == TRUE) &&
|
||||
else if ((is_column_major == TRUE) &&
|
||||
((mtag_b == REORDERED) || (mtag_a == REORDERED)))
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is "
|
||||
bli_print_msg(" Reordering of column major matrices is "
|
||||
" not supported.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
@@ -67,6 +67,11 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16)
|
||||
"cannot perform u8s8s32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
bli_print_msg("cannot perform u8s8s32obf16 gemm with gcc < 11.2",
|
||||
__FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
#endif
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
@@ -44,7 +44,7 @@
|
||||
( ( __GNUC__ == 11 ) && ( __GNUC_MINOR__ < 2 ) ) ) && defined(BLIS_KERNELS_ZEN4) )
|
||||
#define LPGEMM_BF16_JIT
|
||||
#define BPREFETCH_JIT
|
||||
#define DUMP_JIT_CODE
|
||||
//#define DUMP_JIT_CODE
|
||||
#endif
|
||||
|
||||
typedef void (*lpgemm_m_fringe_f32_ker_ft)
|
||||
|
||||
@@ -231,7 +231,7 @@ static inline void fill_array_ ## ctype ( void* arr, dim_t size ) \
|
||||
{ \
|
||||
if( size < 0 ) return; \
|
||||
ctype* temp_arr = ( ctype* ) arr; \
|
||||
_Pragma( "omp parallel " ) \
|
||||
_Pragma( "omp parallel for" ) \
|
||||
for ( dim_t i = 0; i < size; ++i ) \
|
||||
{ \
|
||||
temp_arr[i] = ( ctype )( ( i % 11 ) - 5 ); \
|
||||
|
||||
@@ -1697,57 +1697,52 @@ POST_OPS_5x64_OPS_DISABLE:
|
||||
// final write for a given block within C.
|
||||
if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
// Actually the b matrix is of type bfloat16. However
|
||||
// in order to reuse this kernel for f32, the output
|
||||
// matrix type in kernel function signature is set to
|
||||
// f32 irrespective of original output matrix type.
|
||||
bfloat16* b_q = ( bfloat16* )b;
|
||||
dim_t ir = 0;
|
||||
|
||||
// Store the results in downscaled type (bf16 instead of float).
|
||||
// c[0, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm8,k0,0,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm8,k0,0,0);
|
||||
// c[0, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm9,k1,0,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm9,k1,0,16);
|
||||
// c[0, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm10,k2,0,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm10,k2,0,32);
|
||||
// c[0, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm11,k3,0,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm11,k3,0,48);
|
||||
|
||||
// c[1, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm12,k0,1,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm12,k0,1,0);
|
||||
// c[1, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm13,k1,1,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm13,k1,1,16);
|
||||
// c[1, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm14,k2,1,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm14,k2,1,32);
|
||||
// c[1, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm15,k3,1,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm15,k3,1,48);
|
||||
|
||||
// c[2, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm16,k0,2,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm16,k0,2,0);
|
||||
// c[2, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm17,k1,2,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm17,k1,2,16);
|
||||
// c[2, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm18,k2,2,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm18,k2,2,32);
|
||||
// c[2, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm19,k3,2,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm19,k3,2,48);
|
||||
|
||||
// c[3, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm20,k0,3,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm20,k0,3,0);
|
||||
// c[3, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm21,k1,3,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm21,k1,3,16);
|
||||
// c[3, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm22,k2,3,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm22,k2,3,32);
|
||||
// c[3, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm23,k3,3,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm23,k3,3,48);
|
||||
|
||||
// c[4, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm24,k0,4,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm24,k0,4,0);
|
||||
// c[4, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm25,k1,4,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm25,k1,4,16);
|
||||
// c[4, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm26,k2,4,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm26,k2,4,32);
|
||||
// c[4, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm27,k3,4,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm27,k3,4,48);
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == S32 )
|
||||
{
|
||||
@@ -3411,48 +3406,43 @@ POST_OPS_4x64_OPS_DISABLE:
|
||||
// final write for a given block within C.
|
||||
if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
// Actually the b matrix is of type bfloat16. However
|
||||
// in order to reuse this kernel for f32, the output
|
||||
// matrix type in kernel function signature is set to
|
||||
// f32 irrespective of original output matrix type.
|
||||
bfloat16* b_q = ( bfloat16* )b;
|
||||
dim_t ir = 0;
|
||||
|
||||
// Store the results in downscaled type (bf16 instead of float).
|
||||
// c[0, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm8,k0,0,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm8,k0,0,0);
|
||||
// c[0, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm9,k1,0,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm9,k1,0,16);
|
||||
// c[0, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm10,k2,0,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm10,k2,0,32);
|
||||
// c[0, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm11,k3,0,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm11,k3,0,48);
|
||||
|
||||
// c[1, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm12,k0,1,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm12,k0,1,0);
|
||||
// c[1, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm13,k1,1,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm13,k1,1,16);
|
||||
// c[1, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm14,k2,1,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm14,k2,1,32);
|
||||
// c[1, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm15,k3,1,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm15,k3,1,48);
|
||||
|
||||
// c[2, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm16,k0,2,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm16,k0,2,0);
|
||||
// c[2, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm17,k1,2,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm17,k1,2,16);
|
||||
// c[2, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm18,k2,2,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm18,k2,2,32);
|
||||
// c[2, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm19,k3,2,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm19,k3,2,48);
|
||||
|
||||
// c[3, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm20,k0,3,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm20,k0,3,0);
|
||||
// c[3, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm21,k1,3,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm21,k1,3,16);
|
||||
// c[3, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm22,k2,3,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm22,k2,3,32);
|
||||
// c[3, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm23,k3,3,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm23,k3,3,48);
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == S32 )
|
||||
{
|
||||
@@ -4836,39 +4826,34 @@ POST_OPS_3x64_OPS_DISABLE:
|
||||
// final write for a given block within C.
|
||||
if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
// Actually the b matrix is of type bfloat16. However
|
||||
// in order to reuse this kernel for f32, the output
|
||||
// matrix type in kernel function signature is set to
|
||||
// f32 irrespective of original output matrix type.
|
||||
bfloat16* b_q = ( bfloat16* )b;
|
||||
dim_t ir = 0;
|
||||
|
||||
// Store the results in downscaled type (bf16 instead of float).
|
||||
// c[0, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm8,k0,0,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm8,k0,0,0);
|
||||
// c[0, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm9,k1,0,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm9,k1,0,16);
|
||||
// c[0, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm10,k2,0,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm10,k2,0,32);
|
||||
// c[0, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm11,k3,0,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm11,k3,0,48);
|
||||
|
||||
// c[1, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm12,k0,1,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm12,k0,1,0);
|
||||
// c[1, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm13,k1,1,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm13,k1,1,16);
|
||||
// c[1, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm14,k2,1,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm14,k2,1,32);
|
||||
// c[1, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm15,k3,1,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm15,k3,1,48);
|
||||
|
||||
// c[2, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm16,k0,2,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm16,k0,2,0);
|
||||
// c[2, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm17,k1,2,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm17,k1,2,16);
|
||||
// c[2, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm18,k2,2,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm18,k2,2,32);
|
||||
// c[2, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm19,k3,2,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm19,k3,2,48);
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == S32 )
|
||||
{
|
||||
@@ -5972,30 +5957,25 @@ POST_OPS_2x64_OPS_DISABLE:
|
||||
// final write for a given block within C.
|
||||
if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
// Actually the b matrix is of type bfloat16. However
|
||||
// in order to reuse this kernel for f32, the output
|
||||
// matrix type in kernel function signature is set to
|
||||
// f32 irrespective of original output matrix type.
|
||||
bfloat16* b_q = ( bfloat16* )b;
|
||||
dim_t ir = 0;
|
||||
|
||||
// Store the results in downscaled type (bf16 instead of float).
|
||||
// c[0, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm8,k0,0,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm8,k0,0,0);
|
||||
// c[0, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm9,k1,0,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm9,k1,0,16);
|
||||
// c[0, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm10,k2,0,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm10,k2,0,32);
|
||||
// c[0, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm11,k3,0,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm11,k3,0,48);
|
||||
|
||||
// c[1, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm12,k0,1,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm12,k0,1,0);
|
||||
// c[1, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm13,k1,1,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm13,k1,1,16);
|
||||
// c[1, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm14,k2,1,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm14,k2,1,32);
|
||||
// c[1, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm15,k3,1,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm15,k3,1,48);
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == S32 )
|
||||
{
|
||||
@@ -6819,21 +6799,16 @@ POST_OPS_1x64_OPS_DISABLE:
|
||||
// final write for a given block within C.
|
||||
if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
// Actually the b matrix is of type bfloat16. However
|
||||
// in order to reuse this kernel for f32, the output
|
||||
// matrix type in kernel function signature is set to
|
||||
// f32 irrespective of original output matrix type.
|
||||
bfloat16* b_q = ( bfloat16* )b;
|
||||
dim_t ir = 0;
|
||||
|
||||
// Store the results in downscaled type (bf16 instead of float).
|
||||
// c[0, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm8,k0,0,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm8,k0,0,0);
|
||||
// c[0, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm9,k1,0,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm9,k1,0,16);
|
||||
// c[0, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm10,k2,0,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm10,k2,0,32);
|
||||
// c[0, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm11,k3,0,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(0,jr,zmm11,k3,0,48);
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == S32 )
|
||||
{
|
||||
|
||||
@@ -99,7 +99,7 @@ LPGEMM_ELTWISE_OPS_KERNEL(float,float,f32of32_6x64)
|
||||
__m512 zmm2 = _mm512_setzero_ps();
|
||||
__m512 zmm3 = _mm512_setzero_ps();
|
||||
__m512 zmm4 = _mm512_setzero_ps();
|
||||
|
||||
|
||||
uint64_t orig_post_op_c_j = post_ops_attr.post_op_c_j;
|
||||
for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR )
|
||||
{
|
||||
@@ -1973,66 +1973,60 @@ POST_OPS_6x64_OPS_DISABLE:
|
||||
// final write for a given block within C.
|
||||
if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
// Actually the b matrix is of type bfloat16. However
|
||||
// in order to reuse this kernel for f32, the output
|
||||
// matrix type in kernel function signature is set to
|
||||
// f32 irrespective of original output matrix type.
|
||||
bfloat16* b_q = ( bfloat16* )b;
|
||||
|
||||
// Store the results in downscaled type (bf16 instead of float).
|
||||
// c[0, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm8,k0,0,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm8,k0,0,0);
|
||||
// c[0, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm9,k1,0,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm9,k1,0,16);
|
||||
// c[0, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm10,k2,0,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm10,k2,0,32);
|
||||
// c[0, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm11,k3,0,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm11,k3,0,48);
|
||||
|
||||
// c[1, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm12,k0,1,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm12,k0,1,0);
|
||||
// c[1, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm13,k1,1,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm13,k1,1,16);
|
||||
// c[1, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm14,k2,1,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm14,k2,1,32);
|
||||
// c[1, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm15,k3,1,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm15,k3,1,48);
|
||||
|
||||
// c[2, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm16,k0,2,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm16,k0,2,0);
|
||||
// c[2, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm17,k1,2,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm17,k1,2,16);
|
||||
// c[2, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm18,k2,2,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm18,k2,2,32);
|
||||
// c[2, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm19,k3,2,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm19,k3,2,48);
|
||||
|
||||
// c[3, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm20,k0,3,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm20,k0,3,0);
|
||||
// c[3, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm21,k1,3,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm21,k1,3,16);
|
||||
// c[3, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm22,k2,3,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm22,k2,3,32);
|
||||
// c[3, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm23,k3,3,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm23,k3,3,48);
|
||||
|
||||
// c[4, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm24,k0,4,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm24,k0,4,0);
|
||||
// c[4, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm25,k1,4,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm25,k1,4,16);
|
||||
// c[4, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm26,k2,4,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm26,k2,4,32);
|
||||
// c[4, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm27,k3,4,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm27,k3,4,48);
|
||||
|
||||
// c[5, 0-15]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm28,k0,5,0);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm28,k0,5,0);
|
||||
// c[5, 16-31]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm29,k1,5,16);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm29,k1,5,16);
|
||||
// c[5, 32-47]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm30,k2,5,32);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm30,k2,5,32);
|
||||
// c[5, 48-63]
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(zmm31,k3,5,48);
|
||||
CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,zmm31,k3,5,48);
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == S32 )
|
||||
{
|
||||
|
||||
@@ -515,14 +515,17 @@
|
||||
reg = _mm512_mul_ps( reg, selector ); \
|
||||
reg = _mm512_add_ps( reg, zero_point ); \
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
#define CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,reg,mask,m_ind,n_ind)
|
||||
#else
|
||||
// Downscale store bf16 macro
|
||||
#define CVT_STORE_F32_BF16_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
|
||||
#define CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,reg,mask,m_ind,n_ind) \
|
||||
_mm256_mask_storeu_epi16 \
|
||||
( \
|
||||
b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
|
||||
((bfloat16*)b) + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
|
||||
mask, (__m256i) _mm512_cvtneps_pbh( reg ) \
|
||||
) \
|
||||
|
||||
)
|
||||
#endif
|
||||
|
||||
// Downscale store s8 macro
|
||||
#define CVT_STORE_F32_S8_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
|
||||
|
||||
@@ -1167,8 +1167,7 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
{
|
||||
bfloat16 ctemp[16];
|
||||
|
||||
_mm256_mask_storeu_epi16( ctemp, k2,
|
||||
(__m256i)_mm512_cvtneps_pbh( acc_8 ) );
|
||||
CVT_STORE_F32_BF16_MASK_AVX2(acc_8, k2, ctemp);
|
||||
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
{
|
||||
|
||||
@@ -433,16 +433,10 @@
|
||||
#define CVT_STORE_F32_U8(reg,m_ind,n_ind) \
|
||||
CVT_STORE_F32_U8_MASK(mask_all1,reg,m_ind,n_ind) \
|
||||
|
||||
// Downscale store bf16 macro
|
||||
#define CVT_STORE_S32_BF16(reg,m_ind,n_ind) \
|
||||
_mm256_mask_storeu_epi16 \
|
||||
( \
|
||||
( bfloat16* )post_ops_attr.buf_downscale + \
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ), \
|
||||
mask_all1, (__m256i) _mm512_cvtneps_pbh( _mm512_cvtepi32_ps ( reg ) ) \
|
||||
) \
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
#define CVT_STORE_F32_BF16_MASK(mask,reg,m_ind,n_ind)
|
||||
#define CVT_STORE_F32_BF16_MASK_AVX2(reg,mask, ptr)
|
||||
#else
|
||||
// Downscale store bf16 macro
|
||||
#define CVT_STORE_F32_BF16_MASK(mask,reg,m_ind,n_ind) \
|
||||
_mm256_mask_storeu_epi16 \
|
||||
@@ -453,6 +447,12 @@
|
||||
mask, (__m256i) _mm512_cvtneps_pbh( ( reg ) ) \
|
||||
); \
|
||||
|
||||
#define CVT_STORE_F32_BF16_MASK_AVX2(reg,mask, ptr) \
|
||||
_mm256_mask_storeu_epi16( ptr, mask, \
|
||||
(__m256i)_mm512_cvtneps_pbh( reg ) );
|
||||
#endif
|
||||
|
||||
|
||||
#define CVT_STORE_F32_BF16(reg,m_ind,n_ind) \
|
||||
CVT_STORE_F32_BF16_MASK(mask_all1,reg,m_ind,n_ind); \
|
||||
|
||||
@@ -1078,17 +1078,7 @@
|
||||
mask, reg \
|
||||
); \
|
||||
|
||||
// Downscale store bf16 macro
|
||||
#define CVT_STORE_S32_BF16_MASK(reg,mask,m_ind,n_ind) \
|
||||
_mm256_mask_storeu_epi16 \
|
||||
( \
|
||||
( bfloat16* )post_ops_attr.buf_downscale + \
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ), \
|
||||
mask, (__m256i) _mm512_cvtneps_pbh( _mm512_cvtepi32_ps ( reg ) ) \
|
||||
) \
|
||||
|
||||
// Downscale store f32 macro
|
||||
// Downscale store f32 macro
|
||||
#define CVT_STORE_S32_F32_MASK(reg,mask,m_ind,n_ind) \
|
||||
_mm512_mask_storeu_ps \
|
||||
( \
|
||||
|
||||
@@ -1133,8 +1133,7 @@ LPGEMV_N_EQ1_KERN(uint8_t, int8_t, int32_t, u8s8s32os32)
|
||||
{
|
||||
bfloat16 ctemp[16];
|
||||
|
||||
_mm256_mask_storeu_epi16( ctemp, k2,
|
||||
(__m256i)_mm512_cvtneps_pbh( acc_8 ) );
|
||||
CVT_STORE_F32_BF16_MASK_AVX2( acc_8, k2, ctemp );
|
||||
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user