mirror of
https://github.com/amd/blis.git
synced 2026-05-03 05:51:13 +00:00
Introduced support for GEMV operations with group-level symmetric quantization for the S8S8S32OS32 API. Framework Changes: - Added macro definitions and function prototypes for GEMV with symmetric quantization in lpgemm_5loop_interface_apis.h and lpgemm_kernels.h: - LPGEMV_M_EQ1_KERN2 for the lpgemv_m_one_s8s8s32os32_sym_quant kernel, and - LPGEMV_N_EQ1_KERN2 for the lpgemv_n_one_s8s8s32os32_sym_quant kernel. - Implemented the main GEMV framework for symmetric quantization in lpgemm_s8s8s32_sym_quant.c. Kernel Changes: - lpgemv_m_one_s8s8s32os32_sym_quant handles the case where M = 1 and is implemented in lpgemv_m_kernel_s8_grp_amd512vnni.c. - lpgemv_n_one_s8s8s32os32_sym_quant handles the case where N = 1 and is implemented in lpgemv_n_kernel_s8_grp_amd512vnni.c. - Updated the buffer reordering logic for group quantization for N = 1 cases in aocl_gemm_s8s8s32os32_utils.c. Notes: - Ensure that group_size is a factor of K (and also of KC when K > KC). - The B matrix must be provided in reordered format (mtag_b == REORDERED). AMD-Internal: [SWLCSG-3604]
982 lines
41 KiB
C
982 lines
41 KiB
C
/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name(s) of the copyright holder(s) nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#ifndef BLIS_LPGEMM_KERN_H
#define BLIS_LPGEMM_KERN_H

#include "lpgemm_post_ops.h"
#include "aocl_bf16_type.h"

// BF16 JIT configuration.
// When building the Zen4 kernels with GCC older than 11.2, the compiler
// supports the other AVX-512 features but not the BF16 ISA. In exactly
// that configuration LPGEMM_BF16_JIT and BPREFETCH_JIT are defined.
// NOTE(review): presumably these route the BF16 path to JIT-generated
// kernels instead of the compiled ones -- confirm where they are tested.
#if ( defined( BLIS_GCC ) && \
      ( ( __GNUC__ < 11 ) || \
        ( ( __GNUC__ == 11 ) && ( __GNUC_MINOR__ < 2 ) ) ) && \
      defined( BLIS_KERNELS_ZEN4 ) )
#define LPGEMM_BF16_JIT
#define BPREFETCH_JIT
// Left disabled; uncomment to define DUMP_JIT_CODE.
//#define DUMP_JIT_CODE
#endif

typedef void (*lpgemm_m_fringe_f32_ker_ft)
|
|
(
|
|
const dim_t k0,
|
|
const float* a,
|
|
const dim_t rs_a,
|
|
const dim_t cs_a,
|
|
const float* b,
|
|
const dim_t rs_b,
|
|
const dim_t cs_b,
|
|
float* c,
|
|
const dim_t rs_c,
|
|
const float alpha,
|
|
const float beta,
|
|
lpgemm_post_op* post_ops_list,
|
|
lpgemm_post_op_attr post_ops_attr
|
|
);
|
|
|
|
typedef void (*lpgemm_n_fringe_f32_ker_ft)
|
|
(
|
|
const dim_t m0,
|
|
const dim_t k0,
|
|
const float* a,
|
|
const dim_t rs_a,
|
|
const dim_t cs_a,
|
|
const dim_t ps_a,
|
|
const float* b,
|
|
const dim_t rs_b,
|
|
const dim_t cs_b,
|
|
float* c,
|
|
const dim_t rs_c,
|
|
const float alpha,
|
|
const float beta,
|
|
lpgemm_post_op* post_ops_list,
|
|
lpgemm_post_op_attr post_ops_attr
|
|
);
|
|
|
|
typedef void (*lpgemm_mn_fringe_f32_mask_ker_ft)
|
|
(
|
|
const dim_t k0,
|
|
const float* a,
|
|
const dim_t rs_a,
|
|
const dim_t cs_a,
|
|
const float* b,
|
|
const dim_t rs_b,
|
|
const dim_t cs_b,
|
|
float* c,
|
|
const dim_t rs_c,
|
|
const float alpha,
|
|
const float beta,
|
|
const dim_t n0_rem,
|
|
lpgemm_post_op* post_ops_list,
|
|
lpgemm_post_op_attr post_ops_attr
|
|
);
|
|
|
|
#define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t n0, \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const dim_t ps_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const dim_t cs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64);
|
|
LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m_np);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m_rd);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x8m_rd);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x4m_rd);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x2m_rd);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x1m_rd);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_256_6x64m);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m_np);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m_rd);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x48m_rd);
|
|
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x32m_rd);
|
|
LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64);
|
|
|
|
|
|
#define LPGEMM_MAIN_KERN1(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t n0, \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const dim_t ps_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const dim_t cs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr, \
|
|
lpgemm_pre_op_attr pre_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_MAIN_KERN1(bfloat16,int8_t,float,bf16s4f32of32_6x64m);
|
|
|
|
#define LPGEMM_MAIN_KERN2(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t n0, \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const dim_t ps_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
float* c, \
|
|
const dim_t rs_c, \
|
|
const dim_t cs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_grp_post_op_attr grp_post_ops_attr, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_MAIN_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_6x64m_sym_quant);
|
|
|
|
#define LPGEMM_M_RD_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const dim_t cs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x64_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x48_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x48_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x48_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x48_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x32_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x16_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x16_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x8_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x8_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x4_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x4_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x2_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x1_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x2_rd);
|
|
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x1_rd);
|
|
|
|
#define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64);
|
|
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64);
|
|
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64);
|
|
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64);
|
|
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64);
|
|
|
|
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x64);
|
|
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x64);
|
|
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x64);
|
|
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x64);
|
|
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x64);
|
|
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x64);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x48);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x48);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x48);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x48);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x32);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x16);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x8);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x8);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x8);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x8);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x8);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x4);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x4);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x4);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x4);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x4);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x2);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x2);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x2);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x2);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x2);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x1);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x1);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x1);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x1);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x1);
|
|
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x64_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x48_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x48_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x48_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x48_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x32_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x16_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x8_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x8_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x8_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x8_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x8_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x4_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x4_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x4_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x4_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x4_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x2_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x2_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x2_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x2_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x2_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x1_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x1_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x1_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x1_np);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x1_np);
|
|
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_5x32);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_4x32);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_3x32);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_2x32);
|
|
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_1x32);
|
|
|
|
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x64);
|
|
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x64);
|
|
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64);
|
|
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64);
|
|
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64);
|
|
|
|
|
|
#define LPGEMM_M_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr, \
|
|
lpgemm_pre_op_attr pre_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_M_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_5x64);
|
|
LPGEMM_M_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_4x64);
|
|
LPGEMM_M_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_3x64);
|
|
LPGEMM_M_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_2x64);
|
|
LPGEMM_M_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_1x64);
|
|
|
|
#define LPGEMM_M_FRINGE_KERN2(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
float* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_grp_post_op_attr grp_post_ops_attr, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_M_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_5x64_sym_quant);
|
|
LPGEMM_M_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_4x64_sym_quant);
|
|
LPGEMM_M_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_3x64_sym_quant);
|
|
LPGEMM_M_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_2x64_sym_quant);
|
|
LPGEMM_M_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_1x64_sym_quant);
|
|
|
|
#define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const dim_t ps_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16);
|
|
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12x16);
|
|
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32);
|
|
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_9x32);
|
|
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48);
|
|
|
|
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x16);
|
|
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x32);
|
|
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48);
|
|
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x48m);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x32m);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x16m);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x8m);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x4m);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x2m);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x1m);
|
|
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x48m_np);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x32m_np);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x16m_np);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x8m_np);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x4m_np);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x2m_np);
|
|
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x1m_np);
|
|
|
|
|
|
LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x16);
|
|
LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x32);
|
|
LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x48);
|
|
|
|
|
|
#define LPGEMM_N_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const dim_t ps_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr, \
|
|
lpgemm_pre_op_attr pre_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_N_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_6x16m);
|
|
LPGEMM_N_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_6x32m);
|
|
LPGEMM_N_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_6x48m);
|
|
|
|
#define LPGEMM_N_FRINGE_KERN2(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const dim_t ps_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
float* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_grp_post_op_attr grp_post_ops_attr, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_N_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_6x48_sym_quant);
|
|
LPGEMM_N_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_6x32_sym_quant);
|
|
LPGEMM_N_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_6x16_sym_quant);
|
|
|
|
#define LPGEMM_N_LT_NR0_FRINGE_KERN2(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const dim_t ps_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
float* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
const dim_t n0_rem, \
|
|
lpgemm_grp_post_op_attr grp_post_ops_attr, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16_sym_quant);
|
|
|
|
#define LPGEMM_N_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const dim_t ps_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
const dim_t n0_rem, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16);
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16);
|
|
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16);
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6xlt16m);
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_6xlt8m);
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6xlt16m_np);
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_6xlt8m_np);
|
|
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16);
|
|
|
|
#define LPGEMM_N_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const dim_t ps_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
const dim_t n0_rem, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr, \
|
|
lpgemm_pre_op_attr pre_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_N_LT_NR0_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_6xlt16m);
|
|
|
|
|
|
#define LPGEMM_MN_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48);
|
|
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48);
|
|
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x16);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x16);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x16);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x16);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x16);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x32);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x32);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x32);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x32);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x32);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x48);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x48);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x48);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x48);
|
|
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x48);
|
|
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x16);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x16);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x16);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x16);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x16);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x32);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x32);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x32);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x32);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x32);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x48);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x48);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x48);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x48);
|
|
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x48);
|
|
|
|
#define LPGEMM_MN_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
C_type* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr, \
|
|
lpgemm_pre_op_attr pre_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_5x16);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_4x16);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_3x16);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_2x16);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_1x16);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_5x32);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_4x32);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_3x32);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_2x32);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_1x32);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_5x48);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_4x48);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_3x48);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_2x48);
|
|
LPGEMM_MN_FRINGE_KERN1(bfloat16,int8_t,float,bf16s4f32of32_1x48);
|
|
|
|
#define LPGEMM_MN_FRINGE_KERN2(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemm_rowvar_ ## LP_SFX \
|
|
( \
|
|
const dim_t k0, \
|
|
const A_type* a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const B_type* b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
float* c, \
|
|
const dim_t rs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
lpgemm_grp_post_op_attr grp_post_ops_attr, \
|
|
lpgemm_post_op* post_ops_list, \
|
|
lpgemm_post_op_attr post_ops_attr \
|
|
) \
|
|
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_5x48_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_4x48_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_3x48_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_2x48_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_1x48_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_5x32_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_4x32_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_3x32_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_2x32_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_1x32_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_5x16_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_4x16_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_3x16_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_2x16_sym_quant);
|
|
LPGEMM_MN_FRINGE_KERN2(int8_t, int8_t, int32_t, s8s8s32os32_1x16_sym_quant);
|
|
|
|
|
|
/*
 * LPGEMM_MN_LT_NR0_FRINGE_KERN( A_type, B_type, C_type, LP_SFX )
 *
 * Expands to the prototype
 *     void lpgemm_rowvar_<LP_SFX>( ... )
 * The invocation supplies the terminating ';'.
 *
 *   A_type, B_type - element types behind the read-only a and b pointers.
 *   C_type         - element type of c, and the type of alpha and beta.
 *   LP_SFX         - suffix pasted onto lpgemm_rowvar_ to form the name.
 *
 * The signature carries an extra n0_rem argument between beta and the
 * post-op list.
 * NOTE(review): n0_rem is presumably the leftover column count below the
 * register-block width - confirm against the kernels.
 * No continuation follows the closing parenthesis, so the definition ends
 * there.
 */
#define LPGEMM_MN_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX( \
    const dim_t k0, \
    const A_type* a, const dim_t rs_a, const dim_t cs_a, \
    const B_type* b, const dim_t rs_b, const dim_t cs_b, \
    C_type* c, const dim_t rs_c, \
    const C_type alpha, const C_type beta, \
    const dim_t n0_rem, \
    lpgemm_post_op* post_ops_list, \
    lpgemm_post_op_attr post_ops_attr \
)
|
|
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16);
|
|
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1xlt16);
|
|
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16);
|
|
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2xlt16);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1xlt16);
|
|
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5xlt16_np);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4xlt16_np);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3xlt16_np);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2xlt16_np);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1xlt16_np);
|
|
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_5xlt8);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_4xlt8);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_3xlt8);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_2xlt8);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_1xlt8);
|
|
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_5xlt8_np);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_4xlt8_np);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_3xlt8_np);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_2xlt8_np);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_1xlt8_np);
|
|
|
|
|
|
|
|
/*
 * LPGEMM_MN_LT_NR0_FRINGE_KERN1( A_type, B_type, C_type, LP_SFX )
 *
 * Expands to the prototype
 *     void lpgemm_rowvar_<LP_SFX>( ... )
 * The invocation supplies the terminating ';'.
 *
 *   A_type, B_type - element types behind the read-only a and b pointers.
 *   C_type         - element type of c, and the type of alpha and beta.
 *   LP_SFX         - suffix pasted onto lpgemm_rowvar_ to form the name.
 *
 * The parameter list takes n0_rem after beta and ends with a by-value
 * lpgemm_pre_op_attr after the post-op arguments.
 * No continuation follows the closing parenthesis, so the definition ends
 * there.
 */
#define LPGEMM_MN_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX( \
    const dim_t k0, \
    const A_type* a, const dim_t rs_a, const dim_t cs_a, \
    const B_type* b, const dim_t rs_b, const dim_t cs_b, \
    C_type* c, const dim_t rs_c, \
    const C_type alpha, const C_type beta, \
    const dim_t n0_rem, \
    lpgemm_post_op* post_ops_list, \
    lpgemm_post_op_attr post_ops_attr, \
    lpgemm_pre_op_attr pre_ops_attr \
)
|
|
|
|
/*
 * bf16s4f32of32 "xlt16" fringe prototypes (5..1 in the suffix), generated
 * by LPGEMM_MN_LT_NR0_FRINGE_KERN1. B is passed as int8_t here.
 * NOTE(review): "s4" suggests packed 4-bit B data carried in int8_t
 * storage - confirm against the kernel sources.
 */
LPGEMM_MN_LT_NR0_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_5xlt16 );
LPGEMM_MN_LT_NR0_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4xlt16 );
LPGEMM_MN_LT_NR0_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_3xlt16 );
LPGEMM_MN_LT_NR0_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_2xlt16 );
LPGEMM_MN_LT_NR0_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_1xlt16 );
|
|
|
|
/*
 * LPGEMM_MN_LT_NR0_FRINGE_KERN2( A_type, B_type, C_type, LP_SFX )
 *
 * Expands to the prototype
 *     void lpgemm_rowvar_<LP_SFX>( ... )
 * The invocation supplies the terminating ';'.
 *
 *   A_type, B_type - element types behind the read-only a and b pointers.
 *   C_type         - type of alpha and beta only; c is declared as float*
 *                    regardless of C_type.
 *   LP_SFX         - suffix pasted onto lpgemm_rowvar_ to form the name.
 *
 * The parameter list takes n0_rem after beta, followed by a by-value
 * lpgemm_grp_post_op_attr ahead of the post-op list.
 * No continuation follows the closing parenthesis, so the definition ends
 * there.
 */
#define LPGEMM_MN_LT_NR0_FRINGE_KERN2(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX( \
    const dim_t k0, \
    const A_type* a, const dim_t rs_a, const dim_t cs_a, \
    const B_type* b, const dim_t rs_b, const dim_t cs_b, \
    float* c, const dim_t rs_c, \
    const C_type alpha, const C_type beta, \
    const dim_t n0_rem, \
    lpgemm_grp_post_op_attr grp_post_ops_attr, \
    lpgemm_post_op* post_ops_list, \
    lpgemm_post_op_attr post_ops_attr \
)
|
|
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_5xlt16_sym_quant);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_4xlt16_sym_quant);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16_sym_quant);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16_sym_quant);
|
|
LPGEMM_MN_LT_NR0_FRINGE_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16_sym_quant);
|
|
|
|
/*
 * LPGEMV_M_EQ1_KERN( A_type, B_type, C_type, LP_SFX )
 *
 * Expands to the prototype
 *     void lpgemv_m_one_<LP_SFX>( ... )
 * The invocation supplies the terminating ';'.
 *
 *   A_type, B_type - element types behind the read-only a and b pointers.
 *   C_type         - element type of c, and the type of alpha and beta.
 *   LP_SFX         - suffix pasted onto lpgemv_m_one_ to form the name.
 *
 * Each of a and b is accompanied by an AOCL_MEMORY_TAG (mtag_a / mtag_b).
 * rs_b and NR are the only scalar parameters not declared const.
 * post_op_attr is taken by pointer.
 * No continuation follows the closing parenthesis, so the definition ends
 * there.
 */
#define LPGEMV_M_EQ1_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemv_m_one_ ## LP_SFX( \
    const dim_t n0, const dim_t k, \
    const A_type *a, const dim_t rs_a, const dim_t cs_a, \
    const AOCL_MEMORY_TAG mtag_a, \
    const B_type *b, dim_t rs_b, const dim_t cs_b, \
    const AOCL_MEMORY_TAG mtag_b, \
    C_type *c, const dim_t rs_c, const dim_t cs_c, \
    const C_type alpha, const C_type beta, \
    dim_t NR, const dim_t KC, \
    const dim_t n_sub_updated, const dim_t jc_cur_loop_rem, \
    lpgemm_post_op *post_op, \
    lpgemm_post_op_attr *post_op_attr \
)
|
|
|
|
LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32);
|
|
LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32_avx2);
|
|
LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32_avx512_256);
|
|
LPGEMV_M_EQ1_KERN(bfloat16,bfloat16,float,bf16bf16f32of32);
|
|
LPGEMV_M_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
|
|
LPGEMV_M_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
|
|
|
|
|
|
#define LPGEMV_M_EQ1_KERN2(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemv_m_one_ ## LP_SFX \
|
|
( \
|
|
const dim_t n0, \
|
|
const dim_t k, \
|
|
const A_type *a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const AOCL_MEMORY_TAG mtag_a, \
|
|
const B_type *b, \
|
|
dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
const AOCL_MEMORY_TAG mtag_b, \
|
|
float *c, \
|
|
const dim_t rs_c, \
|
|
const dim_t cs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
dim_t NR, \
|
|
const dim_t KC, \
|
|
const dim_t n_sub_updated, \
|
|
const dim_t jc_cur_loop_rem, \
|
|
lpgemm_grp_post_op_attr grp_post_ops_attr, \
|
|
lpgemm_post_op *post_op, \
|
|
lpgemm_post_op_attr *post_op_attr \
|
|
) \
|
|
|
|
LPGEMV_M_EQ1_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_sym_quant);
|
|
|
|
/*
 * LPGEMV_N_EQ1_KERN( A_type, B_type, C_type, LP_SFX )
 *
 * Expands to the prototype
 *     void lpgemv_n_one_<LP_SFX>( ... )
 * The invocation supplies the terminating ';'.
 *
 *   A_type, B_type - element types behind the read-only a and b pointers.
 *   C_type         - element type of c, and the type of alpha and beta.
 *   LP_SFX         - suffix pasted onto lpgemv_n_one_ to form the name.
 *
 * Each of a and b is accompanied by an AOCL_MEMORY_TAG (mtag_a / mtag_b).
 * Every scalar parameter is const-qualified; post_op_attr is taken by
 * pointer.
 * No continuation follows the closing parenthesis, so the definition ends
 * there.
 */
#define LPGEMV_N_EQ1_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemv_n_one_ ## LP_SFX( \
    const dim_t m0, const dim_t k, \
    const A_type *a, const dim_t rs_a, const dim_t cs_a, \
    const AOCL_MEMORY_TAG mtag_a, \
    const B_type *b, const dim_t rs_b, const dim_t cs_b, \
    const AOCL_MEMORY_TAG mtag_b, \
    C_type *c, const dim_t rs_c, const dim_t cs_c, \
    const C_type alpha, const C_type beta, \
    const dim_t MR, const dim_t KC, \
    lpgemm_post_op *post_op, \
    lpgemm_post_op_attr *post_op_attr \
)
|
|
|
|
LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32);
|
|
LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32_avx2);
|
|
LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32_avx512_256);
|
|
LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float,bf16bf16f32of32);
|
|
LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
|
|
LPGEMV_N_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
|
|
|
|
|
|
#define LPGEMV_N_EQ1_KERN2(A_type,B_type,C_type,LP_SFX) \
|
|
void lpgemv_n_one_ ## LP_SFX \
|
|
( \
|
|
const dim_t m0, \
|
|
const dim_t k, \
|
|
const A_type *a, \
|
|
const dim_t rs_a, \
|
|
const dim_t cs_a, \
|
|
const AOCL_MEMORY_TAG mtag_a, \
|
|
const B_type *b, \
|
|
const dim_t rs_b, \
|
|
const dim_t cs_b, \
|
|
const AOCL_MEMORY_TAG mtag_b, \
|
|
float *c, \
|
|
const dim_t rs_c, \
|
|
const dim_t cs_c, \
|
|
const C_type alpha, \
|
|
const C_type beta, \
|
|
const dim_t MR, \
|
|
const dim_t KC, \
|
|
lpgemm_grp_post_op_attr grp_post_ops_attr, \
|
|
lpgemm_post_op *post_op, \
|
|
lpgemm_post_op_attr *post_op_attr \
|
|
) \
|
|
|
|
LPGEMV_N_EQ1_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_sym_quant);
|
|
|
|
#endif //BLIS_LPGEMM_KERN_H
|