diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c index fe7b1d7cc..483f09e2c 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c @@ -117,7 +117,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) { goto err_hndl; } - + #ifdef LPGEMM_BF16_JIT if( get_jit_kernels_generated() == FALSE ) { @@ -176,6 +176,14 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) goto err_hndl; } + //For AVX-2 system support of BF16, handle the unsupported cases and return early + arch_t arch_id = bli_arch_query_id(); + if( (arch_id != BLIS_ARCH_ZEN4 ) && ( ( is_column_major == TRUE ) || + ( bli_is_trans(blis_transa ) ) || ( bli_is_trans(blis_transb ) ) ) ) + { + bli_print_msg(" Transpose of A/B matrix or column major is not supported in AVX2.", __FILE__, __LINE__ ); + goto err_hndl; + } // From 5-loop function point of view // B matrix needs to be packed in a certain format in order to be loaded // and used in bf16 instrution. As such the mtag_b always needs to be either @@ -226,7 +234,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 ); - if (( bli_arch_query_id() == BLIS_ARCH_ZEN4 ) && ( is_tiny_input_bf16of32( m, n, k, lcntx_g ) == TRUE ) && + if ( ( arch_id == BLIS_ARCH_ZEN4 ) && ( is_tiny_input_bf16of32( m, n, k, lcntx_g ) == TRUE ) && ( is_single_thread( &rntm_g ) == TRUE) && ( is_row_major == TRUE ) ) { diff --git a/addon/aocl_gemm/kernels/lpgemm_kernels.h b/addon/aocl_gemm/kernels/lpgemm_kernels.h index 20d0b5894..24f145ad3 100644 --- a/addon/aocl_gemm/kernels/lpgemm_kernels.h +++ b/addon/aocl_gemm/kernels/lpgemm_kernels.h @@ -41,11 +41,10 @@ // Disable BF16 kernel in cases where compilers support other avx 512 // features except BF16 ISA. #if ( defined( BLIS_GCC ) && ( ( __GNUC__ < 11 ) || \ - ( ( __GNUC__ == 11 ) && ( __GNUC_MINOR__ < 2 ) ) ) ) -//Commenting the JIT definition, to enable the BF16 -> F32 path -//#define LPGEMM_BF16_JIT -//#define BPREFETCH_JIT -//#define DUMP_JIT_CODE + ( ( __GNUC__ == 11 ) && ( __GNUC_MINOR__ < 2 ) ) ) && defined(BLIS_KERNELS_ZEN4) ) +#define LPGEMM_BF16_JIT +#define BPREFETCH_JIT +#define DUMP_JIT_CODE #endif typedef void (*lpgemm_m_fringe_f32_ker_ft)