Tiny GEMM path for BF16 LPGEMM API.

-Currently the BF16 API uses the 5-loop algorithm inside the OMP loop
to compute the results, irrespective of the input sizes. However, it
was observed that for very tiny sizes (n <= 128, m <= 36), the OMP
loop and the NC, MC, KC loops were turning out to be overhead.
-In order to address this, a new path without OMP loop and just the
NR loop over the micro-kernel is introduced for tiny inputs. This is
only applied when the num threads set for GEMM is 1.
-Only row major inputs are allowed to proceed with tiny GEMM.

AMD-Internal: [SWLCSG-3380, SWLCSG-3258]

Change-Id: I9dfa6b130f3c597ca7fcf5f1bc1231faf39de031
This commit is contained in:
Mithun Mohan
2025-02-07 09:02:15 +00:00
committed by MithunMohan KadavilMadanaMohanan
parent c47f0f499f
commit b9f6286731
10 changed files with 442 additions and 33 deletions

View File

@@ -43,6 +43,31 @@
#include "lpgemm_utils.h"
#include "lpgemm_logger.h"
// Returns TRUE when the (m x n x k) problem is small enough to take the
// tiny GEMM path (single NR loop over the micro-kernel, no OMP/NC/MC/KC
// blocking) for the bf16bf16f32obf16 API.
static inline bool is_tiny_input_bf16obf16
     (
       dim_t m,
       dim_t n,
       dim_t k,
       lpgemm_cntx_t* lcntx
     )
{
	const dim_t NC = lcntx->blksz.NC;
	const dim_t MC = lcntx->blksz.MC;
	const dim_t KC = lcntx->blksz.KC;

	// The problem must fit inside a single MC x NC x KC block; the
	// MC/NC/KC boundaries are checked explicitly for safety.
	const bool fits_block = ( k < 256 ) && ( m <= MC ) &&
	                        ( n < NC ) && ( k < KC );

	// Two tiny shape classes are accepted:
	// ( m <= 36, n <= 64 ) or ( m <= 12, n <= 128 ).
	const bool tiny_shape = ( ( m <= 36 ) && ( n <= 64 ) ) ||
	                        ( ( m <= 12 ) && ( n <= 128 ) );

	return ( fits_block && tiny_shape ) ? TRUE : FALSE;
}
AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
{
LPGEMM_START_LOGGER();
@@ -199,6 +224,23 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
if ( ( is_tiny_input_bf16obf16( m, n, k, lcntx_g ) == TRUE ) &&
( is_single_thread( &rntm_g ) == TRUE) &&
( is_row_major == TRUE ) )
{
lpgemm_rowvar_tiny_bf16bf16f32of32
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
( float* )c, rs_c, cs_c,
alpha, beta,
lcntx_g,
post_op_list, BF16
);
return;
}
#ifdef BLIS_ENABLE_OPENMP
// Swapping inputs to induce row major computation for column major inputs.
if ( is_column_major == TRUE )

View File

@@ -43,6 +43,31 @@
#include "lpgemm_utils.h"
#include "lpgemm_logger.h"
// Returns TRUE when the (m x n x k) problem is small enough to take the
// tiny GEMM path (single NR loop over the micro-kernel, no OMP/NC/MC/KC
// blocking) for the bf16bf16f32of32 API.
static inline bool is_tiny_input_bf16of32
     (
       dim_t m,
       dim_t n,
       dim_t k,
       lpgemm_cntx_t* lcntx
     )
{
	// Need to explicitly check for MC, NC boundaries for safety; any
	// problem that does not fit in one MC x NC x KC block is not tiny.
	if ( ( k >= 256 ) ||
	     ( m > lcntx->blksz.MC ) ||
	     ( n >= lcntx->blksz.NC ) ||
	     ( k >= lcntx->blksz.KC ) )
	{
		return FALSE;
	}

	// Accepted tiny shape classes.
	if ( ( m <= 36 ) && ( n <= 64 ) )
	{
		return TRUE;
	}
	if ( ( m <= 12 ) && ( n <= 128 ) )
	{
		return TRUE;
	}

	return FALSE;
}
AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
{
LPGEMM_START_LOGGER();
@@ -200,6 +225,23 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
if ( ( is_tiny_input_bf16of32( m, n, k, lcntx_g ) == TRUE ) &&
( is_single_thread( &rntm_g ) == TRUE) &&
( is_row_major == TRUE ) )
{
lpgemm_rowvar_tiny_bf16bf16f32of32
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
lcntx_g,
post_op_list, F32
);
return;
}
#ifdef BLIS_ENABLE_OPENMP
// Swapping inputs to induce row major computation for column major inputs.
if ( is_column_major == TRUE )

View File

@@ -67,26 +67,6 @@ static inline bool is_tiny_input_f32
return is_tiny;
}
// Returns TRUE when GEMM will effectively run single threaded: either
// the requested thread count is 1, or the jc x ic loop factorization
// multiplies out to a single way.
static inline bool is_single_thread( rntm_t* rntm_g )
{
	dim_t n_threads = bli_rntm_num_threads( rntm_g );

	// Ways can be unset (non-positive); normalize those to 1.
	dim_t jc_ways = bli_rntm_jc_ways( rntm_g );
	dim_t ic_ways = bli_rntm_ic_ways( rntm_g );
	if ( jc_ways <= 0 ) { jc_ways = 1; }
	if ( ic_ways <= 0 ) { ic_ways = 1; }

	return ( ( n_threads == 1 ) ||
	         ( ( ic_ways * jc_ways ) == 1 ) ) ? TRUE : FALSE;
}
AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
{
LPGEMM_START_LOGGER();
@@ -233,7 +213,8 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 );
if ( ( is_tiny_input_f32( m, n, k, lcntx_g ) == TRUE ) &&
( is_single_thread( &rntm_g ) == TRUE) )
( is_single_thread( &rntm_g ) == TRUE) &&
( is_row_major == TRUE ) )
{
lpgemm_rowvar_tiny_f32f32f32of32
(

View File

@@ -83,7 +83,7 @@ LPGEMV(bfloat16, bfloat16, float, bf16bf16f32of32)
float *c_use = NULL;
bfloat16* pack_a_buffer_bf16;
bfloat16* pack_a_buffer_bf16 = NULL;
lpgemm_post_op_attr post_ops_attr;
post_ops_attr.c_stor_type = c_downscale;
@@ -96,7 +96,7 @@ LPGEMV(bfloat16, bfloat16, float, bf16bf16f32of32)
mem_t mem_a = BLIS_MEM_INITIALIZER;
mem_t mem_b = BLIS_MEM_INITIALIZER;
bfloat16* pack_b_buffer_bf16;
bfloat16* pack_b_buffer_bf16 = NULL;
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
thrinfo_t thread_jc;
@@ -379,8 +379,8 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32)
dim_t rs_c_downscale = rs_c;
// Pack buffer for B.
bfloat16* pack_b_buffer_bf16;
bfloat16* pack_a_buffer_bf16;
bfloat16* pack_b_buffer_bf16 = NULL;
bfloat16* pack_a_buffer_bf16 = NULL;
mem_t mem_b = BLIS_MEM_INITIALIZER;
mem_t mem_a = BLIS_MEM_INITIALIZER;
siz_t mem_b_size_req = 0;

View File

@@ -0,0 +1,323 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_pack_bf16.h"
#include "lpgemm_kernels.h"
#include "lpgemm_utils.h"
#include "lpgemm_thrinfo_utils.h"
#include "lpgemm_config.h"
// Kernel function prototypes
typedef void (*lpgemm_rowvar_bf16)
(
const dim_t,
const dim_t,
const dim_t,
const bfloat16*,
const dim_t,
const dim_t,
const dim_t,
const bfloat16*,
const dim_t,
const dim_t,
float*,
const dim_t,
const dim_t,
const float,
const float,
lpgemm_post_op*,
lpgemm_post_op_attr
);
#ifdef BLIS_KERNELS_ZEN4
// Tiny GEMV path (n == 1) for bf16 inputs producing f32 (or downscaled)
// output. Packs A and/or B on demand, then dispatches the specialized
// lpgemv_n_one kernel. If n != 1 this function is a no-op — callers are
// expected to route only true GEMV shapes here.
LPGEMV_TINY(bfloat16, bfloat16, float, bf16bf16f32of32)
{
	// Strides are updated based on matrix packing/reordering.
	bfloat16* a_use = ( bfloat16* )a;
	inc_t rs_a_use = rs_a;
	inc_t cs_a_use = cs_a;

	bfloat16* b_use = ( bfloat16* )b;
	inc_t rs_b_use = rs_b;
	inc_t cs_b_use = cs_b;

	lpgemm_post_op_attr post_ops_attr;
	post_ops_attr.c_stor_type = c_downscale;
	// When the output storage type is below f32, c doubles as the
	// downscale destination buffer.
	if ( c_downscale < F32 )
	{
		post_ops_attr.buf_downscale = c;
	}
	else
	{
		post_ops_attr.buf_downscale = NULL;
	}

	if ( n == 1 )
	{
		bfloat16* pack_a_buffer_bf16 = NULL;
		bfloat16* pack_b_buffer_bf16 = NULL;
		err_t err = BLIS_SUCCESS;

		// Increased MR from 6 to 16 to make use of 32 ZMM registers.
		dim_t MR = 16;

		// Pack B (a k-length vector here) into contiguous storage when
		// it is strided ( rs_b > 1 ).
		if ( ( mtag_b == PACK ) && ( rs_b != 1 ) )
		{
			siz_t mem_b_size_req = sizeof( bfloat16 ) * k;
			pack_b_buffer_bf16 =
				( bfloat16* )bli_malloc_user( mem_b_size_req, &err );
			// Bail out on allocation failure instead of dereferencing
			// a NULL buffer (CERT MEM32-C).
			if ( ( pack_b_buffer_bf16 == NULL ) || ( err != BLIS_SUCCESS ) )
			{
				return;
			}
			for ( dim_t k0 = 0; k0 < k; k0++ )
			{
				pack_b_buffer_bf16[k0] = b[ k0 * rs_b ];
			}
			b_use = pack_b_buffer_bf16;
			rs_b_use = 1;
			cs_b_use = 1;
		}

		post_ops_attr.post_op_c_i = 0;
		post_ops_attr.post_op_c_j = 0;
		post_ops_attr.rs_c_downscale = rs_c;

		if ( mtag_a == PACK )
		{
			siz_t mem_a_size_req = sizeof( bfloat16 ) * m * k;
			pack_a_buffer_bf16 =
				( bfloat16* )bli_malloc_user( mem_a_size_req, &err );
			if ( ( pack_a_buffer_bf16 == NULL ) || ( err != BLIS_SUCCESS ) )
			{
				// Release the B pack buffer (if any) before bailing.
				if ( pack_b_buffer_bf16 != NULL )
				{
					bli_free_user( pack_b_buffer_bf16 );
				}
				return;
			}
			( ( pack_bf16 )lcntx->packa_fun_ptr )
			(
			  pack_a_buffer_bf16,
			  a,
			  rs_a, cs_a,
			  m, k,
			  &rs_a_use, &cs_a_use
			);
			a_use = pack_a_buffer_bf16;
		}

		// Call lpgemv_n_one kernel.
		lpgemv_n_one_bf16bf16f32of32
		(
		  m, k,
		  a_use, rs_a_use, cs_a_use, mtag_a,
		  b_use, rs_b_use, cs_b_use, mtag_b,
		  c, rs_c, cs_c,
		  alpha, beta,
		  MR, k,
		  post_op_list,
		  &post_ops_attr
		);

		// Release pack buffers.
		if ( pack_a_buffer_bf16 != NULL )
		{
			bli_free_user( pack_a_buffer_bf16 );
		}
		if ( pack_b_buffer_bf16 != NULL )
		{
			bli_free_user( pack_b_buffer_bf16 );
		}
	}
}
#endif
// Tiny GEMM path for bf16 inputs: a single NR loop over the
// micro-kernel, with no OMP / NC / MC / KC blocking. B should always be
// packed or reordered on entry; unpacked B is not supported here.
LPGEMM_TINY(bfloat16,bfloat16,float,bf16bf16f32of32)
{
#if (defined(BLIS_KERNELS_ZEN4) && (!defined(LPGEMM_BF16_JIT)))
	// Handle using LPGEMV when m or/and n equal to 1.
	// The avx512 check will be removed when avx2 kernels added in future.
	if ( n == 1 )
	{
		lpgemv_rowvar_tiny_bf16bf16f32of32( m, n, k,
		                                    a, rs_a, cs_a, mtag_a,
		                                    b, rs_b, cs_b, mtag_b,
		                                    c, rs_c, cs_c,
		                                    alpha,
		                                    beta,
		                                    lcntx,
		                                    post_op_list,
		                                    c_downscale);
		return;
	}
#endif
	dim_t NR = lcntx->blksz.NR;

	const int16_t* a_use = NULL;
	dim_t cs_a_use = cs_a;
	dim_t rs_a_use = rs_a;
	dim_t a_block_stride = 0;

	const int16_t* b_use = NULL;
	dim_t rs_b_use = rs_b;
	dim_t cs_b_use = cs_b;

	dim_t rs_c_use = rs_c;
	dim_t rs_c_downscale = rs_c;

	bfloat16* pack_a_buffer_bf16 = NULL;
	bfloat16* pack_b_buffer_bf16 = NULL;
	err_t err = BLIS_SUCCESS;
	siz_t mem_b_size_req = 0;
	siz_t mem_a_size_req = 0;
	dim_t packb_min_NR = 16;

	lpgemm_post_op_attr post_ops_attr;
	post_ops_attr.c_stor_type = c_downscale;
	// When the output storage type is below f32, c doubles as the
	// downscale destination buffer.
	if ( c_downscale < F32 )
	{
		post_ops_attr.buf_downscale = c;
	}
	else
	{
		post_ops_attr.buf_downscale = NULL;
	}

	// Single pass over the full k dimension, so both first-k and
	// last-k post-op conditions hold.
	bool is_first_k = TRUE;
	post_ops_attr.is_first_k = is_first_k;
	bool is_last_k = TRUE;
	post_ops_attr.is_last_k = is_last_k;

	// k needs to be a multiple of 2 so that it can be used with dpbf16_ps
	// instruction. Padding is added in cases this condition is not
	// satisfied, and therefore the k offsets used for packed/reordered
	// buffers needs to be updated.
	dim_t k0_updated = k;
	k0_updated += (k0_updated & 0x1);

	if ( mtag_b == PACK )
	{
		// nc0 needs to be a multiple of 16 since this gives maximum
		// vectorization. Packing B always results in buffers with width
		// which is a multiple of 16. Subsequently the nc0 offsets used
		// for packed/reordered buffers needs to be updated.
		dim_t nc0_updated = make_multiple_of_n( n, packb_min_NR );
		mem_b_size_req = sizeof( bfloat16 ) * nc0_updated * k0_updated;
		pack_b_buffer_bf16 =
			( bfloat16* )bli_malloc_user( mem_b_size_req, &err );
		// Bail out on allocation failure instead of dereferencing a
		// NULL buffer (CERT MEM32-C).
		if ( ( pack_b_buffer_bf16 == NULL ) || ( err != BLIS_SUCCESS ) )
		{
			return;
		}
		( ( pack_bf16 )lcntx->packb_fun_ptr )
		(
		  pack_b_buffer_bf16,
		  b,
		  rs_b, cs_b,
		  n, k,
		  &rs_b_use, &cs_b_use
		);
		b_use = pack_b_buffer_bf16;
	}
	else if ( mtag_b == REORDERED )
	{
		b_use = b;
		lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
	}
	// NOTE(review): no else branch — an UNPACKED B leaves b_use NULL.
	// Callers must guarantee B is packed/reordered; confirm upstream.

	if ( mtag_a == UNPACKED )
	{
		a_use = a;
		// bf16 kernel reads 2 elements, totalling 4 bytes in a
		// single broadcast for use in bf16 instruction.
		// Non bf16 based kernel requires update to this code.
		cs_a_use = 2;
		a_block_stride = rs_a;
		rs_a_use = rs_a;
	}
	else if ( mtag_a == PACK )
	{
		mem_a_size_req = sizeof( bfloat16 ) * m * k;
		pack_a_buffer_bf16 =
			( bfloat16* )bli_malloc_user( mem_a_size_req, &err );
		if ( ( pack_a_buffer_bf16 == NULL ) || ( err != BLIS_SUCCESS ) )
		{
			// Release the B pack buffer (if any) before bailing.
			if ( pack_b_buffer_bf16 != NULL )
			{
				bli_free_user( pack_b_buffer_bf16 );
			}
			return;
		}
		( ( pack_bf16 )lcntx->packa_fun_ptr )
		(
		  pack_a_buffer_bf16,
		  a,
		  rs_a, cs_a,
		  m, k,
		  &rs_a_use, &cs_a_use
		);
		a_use = pack_a_buffer_bf16;
		a_block_stride = rs_a_use;
	}

	// NR loop: one kernel call per NR-wide column panel of C.
	for ( dim_t jr = 0; jr < n; jr += NR )
	{
		dim_t nr0 = bli_min( ( n - jr ), NR );

		// Post ops meta attributes.
		post_ops_attr.post_op_c_i = 0;
		post_ops_attr.post_op_c_j = jr;
		post_ops_attr.rs_c_downscale = rs_c_downscale;

		// Reorder/Packed B, Reorder/Packed/Unpacked A call.
		( ( lpgemm_rowvar_bf16 )lcntx->kern_fun_ptr )
		(
		  m, nr0, k,
		  a_use, rs_a_use, cs_a_use, a_block_stride,
		  ( b_use + ( jr * k0_updated ) ), rs_b_use, cs_b_use,
		  ( c + jr ), rs_c_use, 1,
		  alpha, beta,
		  post_op_list, post_ops_attr
		);
	}

	// Release pack buffers.
	if ( pack_a_buffer_bf16 != NULL )
	{
		bli_free_user( pack_a_buffer_bf16 );
	}
	if ( pack_b_buffer_bf16 != NULL )
	{
		bli_free_user( pack_b_buffer_bf16 );
	}
}

View File

@@ -73,9 +73,6 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
inc_t rs_b_use = rs_b;
inc_t cs_b_use = cs_b;
// Strides are updated based on matrix packing/reordering.
float *c_use = ( float* )c;
lpgemm_post_op_attr post_ops_attr;
post_ops_attr.c_stor_type = c_downscale;
if (c_downscale < F32) post_ops_attr.buf_downscale = c;
@@ -131,7 +128,7 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
m, k,
a_use, rs_a_use, cs_a_use, mtag_a,
b_use, rs_b_use, cs_b_use, mtag_b,
c_use, rs_c, cs_c,
c, rs_c, cs_c,
alpha, beta,
MR, k,
post_op_list,

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -223,4 +223,24 @@ BLIS_INLINE void adjust_B_panel_reordered_jc( dim_t* jc, dim_t panel_start )
( *jc ) = panel_start;
}
// Returns TRUE when LPGEMM will effectively run single threaded: either
// the requested thread count is 1, or the jc x ic loop partitioning
// multiplies out to a single way.
static inline bool is_single_thread( rntm_t* rntm_g )
{
	bool is_st = FALSE;
	dim_t n_threads = bli_rntm_num_threads( rntm_g );
	dim_t jc_ways = bli_rntm_jc_ways( rntm_g );
	dim_t ic_ways = bli_rntm_ic_ways( rntm_g );
	// Ways can be unset (non-positive); treat those as 1.
	ic_ways = ( ic_ways > 0 ) ? ic_ways : 1;
	jc_ways = ( jc_ways > 0 ) ? jc_ways : 1;
	if ( ( n_threads == 1 ) ||
	     ( ( ic_ways * jc_ways ) == 1 ) )
	{
		is_st = TRUE;
	}
	return is_st;
}
#endif //LPGEMM_UTILS_H

View File

@@ -64,6 +64,7 @@ void lpgemm_rowvar_tiny_ ## LP_SFX \
) \
LPGEMM_TINY(float,float,float,f32f32f32of32);
LPGEMM_TINY(bfloat16,bfloat16,float,bf16bf16f32of32);
#define LPGEMM_5LOOP(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
@@ -152,6 +153,7 @@ void lpgemv_rowvar_tiny_ ## LP_SFX \
) \
LPGEMV_TINY(float, float, float, f32f32f32of32);
LPGEMV_TINY(bfloat16,bfloat16,float,bf16bf16f32of32);
#define LPGEMV(A_type, B_type, C_type, LP_SFX) \
void lpgemv_rowvar_ ## LP_SFX \

View File

@@ -1,7 +1,9 @@
r n n n r 121 1 1601 1601 1 1 f32f32f32of32:bias=na,matrix_mul=na
r n n n r 13 1 16 16 1 1 f32f32f32of32:bias=na,matrix_mul=na
r n n n r 36 64 16 16 64 64 f32f32f32of32:none
r n n n r 1 48 16 16 48 48 f32f32f32of32:bias=na,matrix_mul=na
r n n n r 121 1 1601 1601 1 1 bf16bf16f32obf16:bias=bf16,matrix_mul=bf16
r n n n r 12 108 16 16 108 108 f32f32f32of32:bias=na,matrix_mul=na
r n n n r 12 108 16 16 108 108 bf16bf16f32of32:bias=na,matrix_mul=na
r n n n r 12 108 16 16 108 108 bf16bf16f32obf16:bias=bf16,matrix_mul=bf16
r n n n r 36 54 16 16 64 64 f32f32f32of32:none
r t n n r 1 128 64 1 128 128 *:none
c n t n n 32 128 2 32 128 32 bf16bf16f32of32:bias=na,swish
r n n n r 6 1 4 4 16 16 bf16s4f32of32:pre_op_scale=scalar,pre_op_scale_type=bf16,group_size=2