Added column-major support for BF16 tiny path

- Added a column-major path for the BF16 tiny path.
- Tuned the tiny-path thresholds so that a few more inputs qualify for
  the tiny path.

AMD-Internal: [SWLCSG-3380]
Change-Id: I9a5578c9f0d689881fc5a67ab778e6a917c4fce1
This commit is contained in:
varshav
2025-03-12 11:56:46 +00:00
committed by Nallani Bhaskar
parent fb4617d7c3
commit acee9c7d4e

View File

@@ -44,25 +44,23 @@
#include "lpgemm_logger.h"
static inline bool is_tiny_input_bf16obf16
(
(
dim_t m,
dim_t n,
dim_t k,
lpgemm_cntx_t* lcntx
)
)
{
bool is_tiny = FALSE;
const dim_t NC = lcntx->blksz.NC;
const dim_t MC = lcntx->blksz.MC;
const dim_t KC = lcntx->blksz.KC;
const dim_t MR = lcntx->blksz.MR;
const dim_t NR = lcntx->blksz.NR;
const dim_t NC = lcntx->blksz.NC;
const dim_t MC = lcntx->blksz.MC;
const dim_t KC = lcntx->blksz.KC;
const dim_t MR = lcntx->blksz.MR;
const dim_t NR = lcntx->blksz.NR;
dim_t mnk = m * n * k;
const dim_t mnk_magic_num = 36 * 128 * 128;
const dim_t mnk_magic_num = 36 * 128 * 256;
const dim_t m_thresh = 6 * MR;
const dim_t n_thresh = 2 * NR;
const dim_t n_thresh = 6 * NR;
const dim_t k_thresh = 1024;
// Need to explicitly check for MC, NC boundaries for safety.
@@ -70,10 +68,10 @@ static inline bool is_tiny_input_bf16obf16
( ( m <= m_thresh ) && ( n <= n_thresh ) && ( k <= k_thresh ) &&
( mnk < mnk_magic_num ) ) )
{
is_tiny = TRUE;
return TRUE;
}
return is_tiny;
return FALSE;
}
AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
@@ -234,21 +232,38 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
#if (defined(BLIS_KERNELS_ZEN4) && (!defined(LPGEMM_BF16_JIT)))
arch_t arch_id = bli_arch_query_id();
if( ( ( arch_id == BLIS_ARCH_ZEN4 ) || ( arch_id == BLIS_ARCH_ZEN5 ) ) &&
( is_tiny_input_bf16obf16( m, n, k, lcntx_g ) == TRUE ) &&
( is_single_thread( &rntm_g ) == TRUE) &&
( is_row_major == TRUE ) )
( is_single_thread( &rntm_g ) == TRUE) )
{
lpgemm_rowvar_tiny_bf16bf16f32of32
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
( float* )c, rs_c, cs_c,
alpha, beta,
lcntx_g,
post_op_list, BF16
);
return;
if( ( is_row_major == TRUE ) &&
( is_tiny_input_bf16obf16( m, n, k, lcntx_g ) == TRUE ) )
{
lpgemm_rowvar_tiny_bf16bf16f32of32
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
( float* )c, rs_c, cs_c,
alpha, beta,
lcntx_g,
post_op_list, BF16
);
return;
}
else if( ( is_column_major == TRUE ) &&
( is_tiny_input_bf16obf16( n, m, k, lcntx_g ) == TRUE ) )
{
lpgemm_rowvar_tiny_bf16bf16f32of32
(
n, m, k,
b, rs_b, cs_b, mtag_b,
a, rs_a, cs_a, mtag_a,
( float* )c, rs_c, cs_c,
alpha, beta,
lcntx_g,
post_op_list, BF16
);
return;
}
}
#endif
#ifdef BLIS_ENABLE_OPENMP