mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Tiny GEMM path for BF16 LPGEMM API.
-Currently the BF16 API uses the 5 loop algorithm inside the OMP loop to compute the results, irrespective if the input sizes. However it was observed that for very tiny sizes (n <= 128, m <= 36), this OMP loop and NC,MC,KC loops were turning out to be overheads. -In order to address this, a new path without OMP loop and just the NR loop over the micro-kernel is introduced for tiny inputs. This is only applied when the num threads set for GEMM is 1. -Only row major inputs are allowed to proceed with tiny GEMM. AMD-Internal: [SWLCSG-3380, SWLCSG-3258] Change-Id: I9dfa6b130f3c597ca7fcf5f1bc1231faf39de031
This commit is contained in:
committed by
MithunMohan KadavilMadanaMohanan
parent
c47f0f499f
commit
b9f6286731
@@ -43,6 +43,31 @@
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
static inline bool is_tiny_input_bf16obf16
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
lpgemm_cntx_t* lcntx
|
||||
)
|
||||
{
|
||||
bool is_tiny = FALSE;
|
||||
|
||||
const dim_t NC = lcntx->blksz.NC;
|
||||
const dim_t MC = lcntx->blksz.MC;
|
||||
const dim_t KC = lcntx->blksz.KC;
|
||||
|
||||
// Need to explicitly check for MC, NC boundaries for safety.
|
||||
if ( ( k < 256 ) && ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
|
||||
( ( ( m <= 36 ) && ( n <= 64 ) ) ||
|
||||
( ( m <= 12 ) && ( n <= 128 ) ) ) )
|
||||
{
|
||||
is_tiny = TRUE;
|
||||
}
|
||||
|
||||
return is_tiny;
|
||||
}
|
||||
|
||||
AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
@@ -199,6 +224,23 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
|
||||
|
||||
if ( ( is_tiny_input_bf16obf16( m, n, k, lcntx_g ) == TRUE ) &&
|
||||
( is_single_thread( &rntm_g ) == TRUE) &&
|
||||
( is_row_major == TRUE ) )
|
||||
{
|
||||
lpgemm_rowvar_tiny_bf16bf16f32of32
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
( float* )c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
lcntx_g,
|
||||
post_op_list, BF16
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if ( is_column_major == TRUE )
|
||||
|
||||
@@ -43,6 +43,31 @@
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
static inline bool is_tiny_input_bf16of32
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
lpgemm_cntx_t* lcntx
|
||||
)
|
||||
{
|
||||
bool is_tiny = FALSE;
|
||||
|
||||
const dim_t NC = lcntx->blksz.NC;
|
||||
const dim_t MC = lcntx->blksz.MC;
|
||||
const dim_t KC = lcntx->blksz.KC;
|
||||
|
||||
// Need to explicitly check for MC, NC boundaries for safety.
|
||||
if ( ( k < 256 ) && ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
|
||||
( ( ( m <= 36 ) && ( n <= 64 ) ) ||
|
||||
( ( m <= 12 ) && ( n <= 128 ) ) ) )
|
||||
{
|
||||
is_tiny = TRUE;
|
||||
}
|
||||
|
||||
return is_tiny;
|
||||
}
|
||||
|
||||
AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
@@ -200,6 +225,23 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
|
||||
|
||||
if ( ( is_tiny_input_bf16of32( m, n, k, lcntx_g ) == TRUE ) &&
|
||||
( is_single_thread( &rntm_g ) == TRUE) &&
|
||||
( is_row_major == TRUE ) )
|
||||
{
|
||||
lpgemm_rowvar_tiny_bf16bf16f32of32
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if ( is_column_major == TRUE )
|
||||
|
||||
@@ -67,26 +67,6 @@ static inline bool is_tiny_input_f32
|
||||
return is_tiny;
|
||||
}
|
||||
|
||||
static inline bool is_single_thread( rntm_t* rntm_g )
|
||||
{
|
||||
bool is_st = FALSE;
|
||||
|
||||
dim_t n_threads = bli_rntm_num_threads( rntm_g );
|
||||
dim_t jc_ways = bli_rntm_jc_ways( rntm_g );
|
||||
dim_t ic_ways = bli_rntm_ic_ways( rntm_g );
|
||||
|
||||
ic_ways = ( ic_ways > 0 ) ? ic_ways : 1;
|
||||
jc_ways = ( jc_ways > 0 ) ? jc_ways : 1;
|
||||
|
||||
if ( ( n_threads == 1 ) ||
|
||||
( ( ic_ways * jc_ways ) == 1 ) )
|
||||
{
|
||||
is_st = TRUE;
|
||||
}
|
||||
|
||||
return is_st;
|
||||
}
|
||||
|
||||
AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
@@ -233,7 +213,8 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 );
|
||||
|
||||
if ( ( is_tiny_input_f32( m, n, k, lcntx_g ) == TRUE ) &&
|
||||
( is_single_thread( &rntm_g ) == TRUE) )
|
||||
( is_single_thread( &rntm_g ) == TRUE) &&
|
||||
( is_row_major == TRUE ) )
|
||||
{
|
||||
lpgemm_rowvar_tiny_f32f32f32of32
|
||||
(
|
||||
|
||||
@@ -83,7 +83,7 @@ LPGEMV(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
|
||||
|
||||
float *c_use = NULL;
|
||||
bfloat16* pack_a_buffer_bf16;
|
||||
bfloat16* pack_a_buffer_bf16 = NULL;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr;
|
||||
post_ops_attr.c_stor_type = c_downscale;
|
||||
@@ -96,7 +96,7 @@ LPGEMV(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
||||
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
||||
|
||||
bfloat16* pack_b_buffer_bf16;
|
||||
bfloat16* pack_b_buffer_bf16 = NULL;
|
||||
|
||||
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
||||
thrinfo_t thread_jc;
|
||||
@@ -379,8 +379,8 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
dim_t rs_c_downscale = rs_c;
|
||||
|
||||
// Pack buffer for B.
|
||||
bfloat16* pack_b_buffer_bf16;
|
||||
bfloat16* pack_a_buffer_bf16;
|
||||
bfloat16* pack_b_buffer_bf16 = NULL;
|
||||
bfloat16* pack_a_buffer_bf16 = NULL;
|
||||
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
||||
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
||||
siz_t mem_b_size_req = 0;
|
||||
|
||||
323
addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16_tiny.c
Normal file
323
addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16_tiny.c
Normal file
@@ -0,0 +1,323 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_pack_bf16.h"
|
||||
#include "lpgemm_kernels.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_thrinfo_utils.h"
|
||||
#include "lpgemm_config.h"
|
||||
|
||||
// Kernel function prototypes
|
||||
typedef void (*lpgemm_rowvar_bf16)
|
||||
(
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const bfloat16*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const bfloat16*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
float*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const float,
|
||||
const float,
|
||||
lpgemm_post_op*,
|
||||
lpgemm_post_op_attr
|
||||
);
|
||||
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
LPGEMV_TINY(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
{
|
||||
// Strides are updated based on matrix packing/reordering.
|
||||
bfloat16* a_use = ( bfloat16* )a;
|
||||
inc_t rs_a_use = rs_a;
|
||||
inc_t cs_a_use = cs_a;
|
||||
|
||||
bfloat16* b_use = ( bfloat16* )b;
|
||||
inc_t rs_b_use = rs_b;
|
||||
inc_t cs_b_use = cs_b;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr;
|
||||
post_ops_attr.c_stor_type = c_downscale;
|
||||
if (c_downscale < F32)
|
||||
{
|
||||
post_ops_attr.buf_downscale = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
post_ops_attr.buf_downscale = NULL;
|
||||
}
|
||||
|
||||
if( n == 1 )
|
||||
{
|
||||
bfloat16* pack_a_buffer_bf16 = NULL;
|
||||
bfloat16* pack_b_buffer_bf16 = NULL;
|
||||
err_t err = BLIS_SUCCESS;
|
||||
|
||||
// Increased MR from 6 to 16 to make use of 32 ZMM registers
|
||||
dim_t MR = 16;
|
||||
|
||||
// pack B matrix if rs_b > 1
|
||||
if( ( mtag_b == PACK ) && ( rs_b != 1 ) )
|
||||
{
|
||||
siz_t mem_b_size_req = sizeof( bfloat16 ) * k;
|
||||
pack_b_buffer_bf16 =
|
||||
( bfloat16* )bli_malloc_user( mem_b_size_req, &err );
|
||||
|
||||
for( dim_t k0 = 0; k0 < k; k0++ )
|
||||
{
|
||||
pack_b_buffer_bf16[k0] = b[ k0*rs_b ];
|
||||
}
|
||||
|
||||
b_use = pack_b_buffer_bf16;
|
||||
rs_b_use = 1;
|
||||
cs_b_use = 1;
|
||||
}
|
||||
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_j = 0;
|
||||
post_ops_attr.rs_c_downscale = rs_c;
|
||||
|
||||
if( mtag_a == PACK )
|
||||
{
|
||||
siz_t mem_a_size_req = sizeof( bfloat16 ) * m * k;
|
||||
pack_a_buffer_bf16 =
|
||||
( bfloat16* )bli_malloc_user( mem_a_size_req, &err );
|
||||
|
||||
( ( pack_bf16 ) lcntx->packa_fun_ptr )
|
||||
(
|
||||
pack_a_buffer_bf16,
|
||||
a,
|
||||
rs_a, cs_a,
|
||||
m, k,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
a_use = pack_a_buffer_bf16;
|
||||
}
|
||||
// Call lpgemv_n_one kernel
|
||||
lpgemv_n_one_bf16bf16f32of32
|
||||
(
|
||||
m, k,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a,
|
||||
b_use, rs_b_use, cs_b_use, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
MR, k,
|
||||
post_op_list,
|
||||
&post_ops_attr
|
||||
);
|
||||
|
||||
// Release pack buffers.
|
||||
if ( pack_a_buffer_bf16 != NULL )
|
||||
{
|
||||
bli_free_user( pack_a_buffer_bf16 );
|
||||
}
|
||||
if ( pack_b_buffer_bf16 != NULL )
|
||||
{
|
||||
bli_free_user( pack_b_buffer_bf16 );
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
// B should always be packed.
|
||||
LPGEMM_TINY(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
{
|
||||
|
||||
#if (defined(BLIS_KERNELS_ZEN4) && (!defined(LPGEMM_BF16_JIT)))
|
||||
// Handle using LPGEMV when m or/and n equal to 1
|
||||
// The avx512 check will be removed when avx2 kernels added in future
|
||||
if ( n == 1 )
|
||||
{
|
||||
lpgemv_rowvar_tiny_bf16bf16f32of32( m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha,
|
||||
beta,
|
||||
lcntx,
|
||||
post_op_list,
|
||||
c_downscale);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
dim_t NR = lcntx->blksz.NR;
|
||||
|
||||
const int16_t* a_use = NULL;
|
||||
dim_t cs_a_use = cs_a;
|
||||
dim_t rs_a_use = rs_a;
|
||||
dim_t a_block_stride = 0;
|
||||
|
||||
const int16_t* b_use = NULL;
|
||||
dim_t rs_b_use = rs_b;
|
||||
dim_t cs_b_use = cs_b;
|
||||
|
||||
dim_t rs_c_use = rs_c;
|
||||
dim_t rs_c_downscale = rs_c;
|
||||
|
||||
bfloat16* pack_a_buffer_bf16 = NULL;
|
||||
bfloat16* pack_b_buffer_bf16 = NULL;
|
||||
err_t err = BLIS_SUCCESS;
|
||||
siz_t mem_b_size_req = 0;
|
||||
siz_t mem_a_size_req = 0;
|
||||
dim_t packb_min_NR = 16;
|
||||
|
||||
// kc needs to be a multiple of 2 so that it can be used with dpbf16_ps
|
||||
// instruction. Padding is added in cases this condition is not
|
||||
// satisfied, and therefore the k offset used for packed/reordered
|
||||
// buffer needs to be updated.
|
||||
dim_t k_updated = k;
|
||||
k_updated += (k_updated & 0x1);
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr;
|
||||
post_ops_attr.c_stor_type = c_downscale;
|
||||
if ( c_downscale < F32 )
|
||||
{
|
||||
post_ops_attr.buf_downscale = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
post_ops_attr.buf_downscale = NULL;
|
||||
}
|
||||
|
||||
bool is_first_k = TRUE;
|
||||
post_ops_attr.is_first_k = is_first_k;
|
||||
bool is_last_k = TRUE;
|
||||
post_ops_attr.is_last_k = is_last_k;
|
||||
|
||||
// k needs to be a multiple of 2 so that it can be used with dpbf16_ps
|
||||
// instruction. Padding is added in cases this condition is not
|
||||
// satisfied, and therefore the k offsets used for packed/reordered
|
||||
// buffers needs to be updated.
|
||||
dim_t k0_updated = k;
|
||||
k0_updated += (k0_updated & 0x1);
|
||||
|
||||
if ( mtag_b == PACK )
|
||||
{
|
||||
// nc0 needs to be a multiple of 16 since this gives maximum
|
||||
// vectorization. Packing B always results in buffers with width
|
||||
// which is a multiple of 16. Subsequently the nc0 offsets used
|
||||
// for packed/reordered buffers needs to be updated.
|
||||
dim_t nc0_updated = make_multiple_of_n( n, packb_min_NR );
|
||||
mem_b_size_req = sizeof( bfloat16 ) * nc0_updated * k0_updated;
|
||||
pack_b_buffer_bf16 =
|
||||
( bfloat16* )bli_malloc_user( mem_b_size_req, &err );
|
||||
|
||||
( ( pack_bf16 )lcntx->packb_fun_ptr )
|
||||
(
|
||||
pack_b_buffer_bf16,
|
||||
b,
|
||||
rs_b, cs_b,
|
||||
n, k,
|
||||
&rs_b_use, &cs_b_use
|
||||
);
|
||||
|
||||
b_use = pack_b_buffer_bf16;
|
||||
}
|
||||
else if ( mtag_b == REORDERED )
|
||||
{
|
||||
b_use = b;
|
||||
|
||||
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
|
||||
}
|
||||
|
||||
if ( mtag_a == UNPACKED )
|
||||
{
|
||||
a_use = a;
|
||||
|
||||
// bf16 kernel reads 2 elements, totalling 4 bytes in a
|
||||
// single broadcast for use in bf16 instruction.
|
||||
// Non bf16 based kernel requires update to this code.
|
||||
cs_a_use = 2;
|
||||
a_block_stride = rs_a;
|
||||
rs_a_use = rs_a;
|
||||
}
|
||||
else if ( mtag_a == PACK )
|
||||
{
|
||||
mem_a_size_req = sizeof( bfloat16 ) * m * k;
|
||||
pack_a_buffer_bf16 =
|
||||
( bfloat16* )bli_malloc_user( mem_a_size_req, &err );
|
||||
|
||||
( ( pack_bf16 )lcntx->packa_fun_ptr )
|
||||
(
|
||||
pack_a_buffer_bf16,
|
||||
a,
|
||||
rs_a, cs_a,
|
||||
m, k,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
|
||||
a_use = pack_a_buffer_bf16;
|
||||
a_block_stride = rs_a_use;
|
||||
}
|
||||
|
||||
for ( dim_t jr = 0; jr < n; jr += NR )
|
||||
{
|
||||
dim_t nr0 = bli_min( ( n - jr ), NR );
|
||||
|
||||
// Post ops meta attributes.
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_j = jr;
|
||||
post_ops_attr.rs_c_downscale = rs_c_downscale;
|
||||
|
||||
// Reorder/Packed B, Reorder/Packed/Unpacked A call.
|
||||
( ( lpgemm_rowvar_bf16 )lcntx->kern_fun_ptr )
|
||||
(
|
||||
m, nr0, k,
|
||||
a_use, rs_a_use, cs_a_use, a_block_stride,
|
||||
( b_use + ( jr * k0_updated ) ), rs_b_use, cs_b_use,
|
||||
( c + jr ), rs_c_use, 1,
|
||||
alpha, beta,
|
||||
post_op_list, post_ops_attr
|
||||
);
|
||||
}
|
||||
|
||||
// Release pack buffers.
|
||||
if ( pack_a_buffer_bf16 != NULL )
|
||||
{
|
||||
bli_free_user( pack_a_buffer_bf16 );
|
||||
}
|
||||
if ( pack_b_buffer_bf16 != NULL )
|
||||
{
|
||||
bli_free_user( pack_b_buffer_bf16 );
|
||||
}
|
||||
}
|
||||
@@ -73,9 +73,6 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
|
||||
inc_t rs_b_use = rs_b;
|
||||
inc_t cs_b_use = cs_b;
|
||||
|
||||
// Strides are updated based on matrix packing/reordering.
|
||||
float *c_use = ( float* )c;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr;
|
||||
post_ops_attr.c_stor_type = c_downscale;
|
||||
if (c_downscale < F32) post_ops_attr.buf_downscale = c;
|
||||
@@ -131,7 +128,7 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
|
||||
m, k,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a,
|
||||
b_use, rs_b_use, cs_b_use, mtag_b,
|
||||
c_use, rs_c, cs_c,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
MR, k,
|
||||
post_op_list,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -223,4 +223,24 @@ BLIS_INLINE void adjust_B_panel_reordered_jc( dim_t* jc, dim_t panel_start )
|
||||
( *jc ) = panel_start;
|
||||
}
|
||||
|
||||
static inline bool is_single_thread( rntm_t* rntm_g )
|
||||
{
|
||||
bool is_st = FALSE;
|
||||
|
||||
dim_t n_threads = bli_rntm_num_threads( rntm_g );
|
||||
dim_t jc_ways = bli_rntm_jc_ways( rntm_g );
|
||||
dim_t ic_ways = bli_rntm_ic_ways( rntm_g );
|
||||
|
||||
ic_ways = ( ic_ways > 0 ) ? ic_ways : 1;
|
||||
jc_ways = ( jc_ways > 0 ) ? jc_ways : 1;
|
||||
|
||||
if ( ( n_threads == 1 ) ||
|
||||
( ( ic_ways * jc_ways ) == 1 ) )
|
||||
{
|
||||
is_st = TRUE;
|
||||
}
|
||||
|
||||
return is_st;
|
||||
}
|
||||
|
||||
#endif //LPGEMM_UTILS_H
|
||||
@@ -64,6 +64,7 @@ void lpgemm_rowvar_tiny_ ## LP_SFX \
|
||||
) \
|
||||
|
||||
LPGEMM_TINY(float,float,float,f32f32f32of32);
|
||||
LPGEMM_TINY(bfloat16,bfloat16,float,bf16bf16f32of32);
|
||||
|
||||
#define LPGEMM_5LOOP(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
@@ -152,6 +153,7 @@ void lpgemv_rowvar_tiny_ ## LP_SFX \
|
||||
) \
|
||||
|
||||
LPGEMV_TINY(float, float, float, f32f32f32of32);
|
||||
LPGEMV_TINY(bfloat16,bfloat16,float,bf16bf16f32of32);
|
||||
|
||||
#define LPGEMV(A_type, B_type, C_type, LP_SFX) \
|
||||
void lpgemv_rowvar_ ## LP_SFX \
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
r n n n r 121 1 1601 1601 1 1 f32f32f32of32:bias=na,matrix_mul=na
|
||||
r n n n r 13 1 16 16 1 1 f32f32f32of32:bias=na,matrix_mul=na
|
||||
r n n n r 36 64 16 16 64 64 f32f32f32of32:none
|
||||
r n n n r 1 48 16 16 48 48 f32f32f32of32:bias=na,matrix_mul=na
|
||||
r n n n r 121 1 1601 1601 1 1 bf16bf16f32obf16:bias=bf16,matrix_mul=bf16
|
||||
r n n n r 12 108 16 16 108 108 f32f32f32of32:bias=na,matrix_mul=na
|
||||
r n n n r 12 108 16 16 108 108 bf16bf16f32of32:bias=na,matrix_mul=na
|
||||
r n n n r 12 108 16 16 108 108 bf16bf16f32obf16:bias=bf16,matrix_mul=bf16
|
||||
r n n n r 36 54 16 16 64 64 f32f32f32of32:none
|
||||
r t n n r 1 128 64 1 128 128 *:none
|
||||
c n t n n 32 128 2 32 128 32 bf16bf16f32of32:bias=na,swish
|
||||
r n n n r 6 1 4 4 16 16 bf16s4f32of32:pre_op_scale=scalar,pre_op_scale_type=bf16,group_size=2
|
||||
|
||||
Reference in New Issue
Block a user