Files
blis/addon/aocl_gemm/JIT/lpgemm_jit_bf16.cpp
Meghana Vankadari c1e063e65c Fix for offset issue while reading constants from JIT code
Details:
- For a variable x, using the address of x in an instruction throws an
  exception if the difference between &x and the access position is
  larger than 2 GiB. To solve this issue, all variables are stored
  within the JIT code section and are accessed using relative addressing.

- Fixed a bug in B matrix pack function for s8s8s32os32 API.
- Fixed a bug in JIT code to apply bias on col-major matrices.

AMD-Internal: [SWLCSG-2820]
Change-Id: I82f117a0422c794cb9b1a4d65a89d60de4adfd96
2024-06-24 07:14:15 -04:00

1510 lines
47 KiB
C++

/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "lpgemm_jit_bf16.h"
// push callee-save registers to stack
void bli_lpgemm_jit:: preamble()
{
push(rbp);
push(rbx);
push(r12);
push(r13);
push(r14);
push(r15);
}
// pop the callee-save registers before returning from function.
void bli_lpgemm_jit:: postamble()
{
pop(r15);
pop(r14);
pop(r13);
pop(r12);
pop(rbx);
pop(rbp);
vzeroupper();
}
// Emit stores that spill `num_regs` consecutive ZMM registers, starting at
// register index `reg_start_idx`, to the ZMM scratch area on the stack.
// `stack_off` is a byte offset from the top of that area (zmm_stack_top);
// each register occupies one 64-byte slot.
// Nothing is emitted when num_regs <= 0.
void bli_lpgemm_jit:: store_zmms_in_stack( dim_t reg_start_idx,
                                           dim_t num_regs,
                                           dim_t stack_off
                                         )
{
    // rsp-relative byte offset of the first slot to be written.
    const dim_t first_slot_off = zmm_stack_top + stack_off;

    for( dim_t slot = 0; slot < num_regs; ++slot )
    {
        const dim_t src_reg = reg_start_idx + slot;
        vmovups( ptr[ rsp + first_slot_off + slot * 64 ], Zmm( src_reg ) );
    }
}
// Emit loads that refill `num_regs` consecutive ZMM registers, starting at
// register index `reg_start_idx`, from the ZMM scratch area on the stack.
// `stack_off` is a byte offset from the top of that area (zmm_stack_top);
// each register occupies one 64-byte slot.
// Nothing is emitted when num_regs <= 0.
void bli_lpgemm_jit:: get_zmms_from_stack( dim_t reg_start_idx,
                                           dim_t num_regs,
                                           dim_t stack_off
                                         )
{
    // rsp-relative byte offset of the first slot to be read.
    const dim_t first_slot_off = zmm_stack_top + stack_off;

    for( dim_t slot = 0; slot < num_regs; ++slot )
    {
        const dim_t dst_reg = reg_start_idx + slot;
        vmovups( Zmm( dst_reg ), ptr[ rsp + first_slot_off + slot * 64 ] );
    }
}
// Emit code that clears the accumulator registers.
// For an MRxNR row-major micro-kernel, (MR * (NR / elements per register))
// registers hold the accumulated values.
// Note: every register from fma_start_idx up to zmm31 is cleared; the
// m_dim / n_dim arguments are not consulted here.
void bli_lpgemm_jit:: reg_init( dim_t m_dim, dim_t n_dim )
{
    const dim_t first_acc = fma_start_idx;

    // Zero the first accumulator by xor-ing it with itself ...
    vxorps( Zmm( first_acc ), Zmm( first_acc ) );

    // ... and replicate that zero into all remaining accumulators.
    dim_t acc = first_acc + 1;
    while( acc < 32 )
    {
        vmovaps( Zmm( acc ), Zmm( first_acc ) );
        ++acc;
    }
}
// Emit the body of one k-loop iteration of the bf16 micro-kernel.
// This code replicates the existing bf16 kernel.
// Hence unroll factor is hardcoded to be 2.
// TODO: Make the unroll factor a configurable parameter.
//
// Registers used by the emitted code:
//   rax - current position in A; advanced by r9 once per call.
//   rbx - current position in B; advanced by r10 once per call.
//   r8  - used as a scaled index (x1, x2, x4) to reach rows 1, 2 and 4 of A.
//   r13, r15 - offsets used to reach rows 3 and 5 of A
//              (presumably 3*r8 and 5*r8, set up by the caller - confirm).
//   k3  - mask applied to the last, partial B load when n_rem != 0.
// No A broadcast is emitted for row indices above 5, so this generator only
// covers m_dim <= 6. n_dim is not used directly; the load counts come from
// num_full_loads / num_loads / n_rem.
void bli_lpgemm_jit:: kernel_unroll( dim_t m_dim, dim_t n_dim )
{
dim_t reg_num;
// Broadcast elements of A matrix
// (32 bits from row 0 are replicated across the first broadcast register).
vpbroadcastd( Zmm( bcst_start_idx ), ptr[ rax ] );
// load elements of B matrix into registers
for( dim_t n = 0; n < num_full_loads; n++ )
vmovdqu16( Zmm( load_start_idx + n ), ptr[ rbx + n * 64 ] );
// In case of last load with fringe part, use mask
// (zeroing-masking, so lanes outside the mask are 0).
if( n_rem )
vmovdqu16( Zmm( load_start_idx + num_full_loads )
| k3 | T_z, ptr[ rbx + num_full_loads * 64 ] );
// B has been consumed for this iteration: step rbx by r10.
add( rbx, r10 );
for( dim_t m = 0; m < m_dim; m++ )
{
// broadcast elements of A matrix.
// Using 2 ZMM registers for broadcast.
// While row m is being multiplied, row m + 1 is fetched into the other
// broadcast register (index alternates with ( m + 1 ) % 2).
if( m < ( m_dim - 1 ) )
{
switch ( m + 1 )
{
// Rows 1, 2 and 4 can be addressed with r8 and a legal x86 scale factor.
case 1:
case 4:
case 2: vpbroadcastd( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
ptr[ rax + r8 * ( m + 1 ) ] );
break;
// Scale factors 3 and 5 do not exist, so separate offset registers are used.
case 3: vpbroadcastd( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
ptr[ rax + r13 ] );
break;
case 5: vpbroadcastd( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
ptr[ rax + r15 ] );
break;
default:
break;
}
}
// move to next column
// (emitted on the last row only, after all broadcasts that read through
// rax for this iteration have been issued).
if( m == ( m_dim - 1 ) ) add( rax, r9 );
// Generate FMA instructions.
// Accumulator for (row m, load n) is fma_start_idx + m * num_loads + n.
for( dim_t n = 0; n < num_loads; n++ )
{
reg_num = fma_start_idx + ( m * num_loads ) + n;
vdpbf16ps( Zmm( reg_num ), Zmm( bcst_start_idx + m % 2 ),
Zmm( load_start_idx + n ) );
}
}
}
// Emit the k-fringe step of the bf16 micro-kernel.
// The structure mirrors kernel_unroll(), with two differences visible here:
//   - A is broadcast with vpbroadcastw, i.e. a single 16-bit element is
//     replicated instead of a 32-bit pair;
//   - neither rax (A) nor rbx (B) is advanced afterwards.
// NOTE(review): vdpbf16ps still multiplies pairs of bf16 values, so this
// presumably relies on the packed B buffer being zero-padded for the odd k
// element - confirm against the packing routine.
// No A broadcast is emitted for row indices above 5 (m_dim <= 6).
void bli_lpgemm_jit:: k_fringe_loop( dim_t m_dim, dim_t n_dim )
{
dim_t reg_num;
// Broadcast elements of A matrix
vpbroadcastw( Zmm( bcst_start_idx ), ptr[ rax ] );
// load elements of B matrix into registers
for( dim_t n = 0; n < num_full_loads; n++ )
vmovdqu16( Zmm( load_start_idx + n ), ptr[ rbx + n * 64 ] );
// In case of last load with fringe part, use mask
// (zeroing-masking with k3).
if( n_rem )
vmovdqu16( Zmm( load_start_idx + num_full_loads )
| k3 | T_z, ptr[ rbx + num_full_loads * 64 ] );
for( dim_t m = 0; m < m_dim; m++ )
{
if( m < ( m_dim - 1 ) )
{
// broadcast elements of A matrix.
// Using 2 ZMM registers for broadcast.
// Row m + 1 is fetched into the register not used by row m.
switch ( m + 1 )
{
// Rows 1, 2 and 4: r8 with scale factor 1, 2 or 4.
case 1:
case 4:
case 2: vpbroadcastw( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
ptr[ rax + r8 * ( m + 1 ) ] );
break;
// Rows 3 and 5: dedicated offset registers
// (presumably 3*r8 and 5*r8 - confirm with the caller).
case 3: vpbroadcastw( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
ptr[ rax + r13 ] );
break;
case 5: vpbroadcastw( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
ptr[ rax + r15 ] );
break;
default:
break;
}
}
// Generate FMA instructions.
// Accumulator for (row m, load n) is fma_start_idx + m * num_loads + n.
for( dim_t n = 0; n < num_loads; n++ )
{
reg_num = fma_start_idx + ( m * num_loads ) + n;
vdpbf16ps( Zmm( reg_num ), Zmm( bcst_start_idx + m % 2 ),
Zmm( load_start_idx + n ) );
}
}
}
// Emit one multiply per accumulator register to scale it by alpha.
// The alpha value is expected to be already broadcast in Zmm( alpha_reg ).
// Every register from fma_start_idx up to zmm31 is scaled; m_dim and n_dim
// are not consulted.
void bli_lpgemm_jit:: scale_alpha( dim_t m_dim, dim_t n_dim )
{
    dim_t acc = fma_start_idx;
    while( acc < 32 )
    {
        vmulps( Zmm( acc ), Zmm( alpha_reg ), Zmm( acc ) );
        ++acc;
    }
}
// Emit code for the generic-beta update when C is stored as f32:
//     acc(row, col) += C(row, col) * beta
// rcx points at the current row of C and is advanced by rdi between rows;
// beta is expected to be broadcast in Zmm( beta_reg ).
// The load registers (load_start_idx + n) are used as temporaries.
void bli_lpgemm_jit:: f32_f32_beta_op( dim_t m_dim, dim_t n_dim)
{
    for( dim_t row = 0; row < m_dim; ++row )
    {
        // Step to the next row of C (not needed before the first row).
        if( row != 0 ) add( rcx, rdi );

        // First accumulator register belonging to this row.
        const dim_t acc_base = fma_start_idx + ( row * num_loads );

        for( dim_t col = 0; col < num_full_loads; ++col )
        {
            const dim_t c_tmp = load_start_idx + col;
            vmovups( Zmm( c_tmp ) , ptr[ rcx + col * 64 ] );
            vfmadd231ps( Zmm( acc_base + col ), Zmm( c_tmp ),
                         Zmm( beta_reg ) );
        }

        // Partial trailing load: k4 masks the valid f32 lanes, the rest are
        // zeroed.
        if( n_rem )
        {
            const dim_t c_tmp = load_start_idx + num_full_loads;
            vmovups( Zmm( c_tmp ) | k4 | T_z,
                     ptr[ rcx + num_full_loads * 64 ] );
            vfmadd231ps( Zmm( acc_base + num_full_loads ), Zmm( c_tmp ),
                         Zmm( beta_reg ) );
        }
    }
}
// Emit code for the generic-beta update when C lives in the bf16 downscale
// buffer:
//     acc(row, col) += beta * (float) C_bf16(row, col)
// The buffer pointer, its row stride and the (post_op_c_i, post_op_c_j)
// tile origin are read back from the stack. rax, rbx, rcx and rsi are
// clobbered; the load registers are used as temporaries.
void bli_lpgemm_jit:: bf16_f32_beta_op( dim_t m_dim, dim_t n_dim )
{
    // rcx = buffer base, rax = row stride (elements).
    mov( rcx, ptr[ rsp + stack_off_buf_downscale ] );
    mov( rax, ptr[ rsp + stack_off_postop + offsetof( lpgemm_post_op_attr,
                                                       rs_c_downscale ) ] );

    // Row stride in bytes: rax *= sizeof(bfloat16).
    lea( rax, ptr[ rax * 2 ] );

    // rsi = post_op_c_i, rbx = post_op_c_j.
    mov( rsi, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );

    // rsi = post_op_c_i * row_stride_bytes + post_op_c_j * sizeof(bfloat16)
    imul( rsi, rax );
    lea( rsi, ptr[ rsi + rbx * 2 ] );

    // rcx now addresses the first element of this tile.
    add( rcx, rsi );

    for( dim_t row = 0; row < m_dim; ++row )
    {
        const dim_t acc_base = fma_start_idx + ( row * num_loads );

        for( dim_t col = 0; col < num_full_loads; ++col )
        {
            const dim_t tmp = load_start_idx + col;
            // Widen 16 bf16 values (32 bytes) to 32-bit lanes ...
            vpmovsxwd( Zmm( tmp ), ptr[ rcx + col * 32 ] );
            // ... and move the bf16 bits into the upper half of each lane,
            // which yields the corresponding f32 value.
            vpslld( Zmm( tmp ), Zmm( tmp ), 0x10 );
            // acc += beta * C
            vfmadd231ps( Zmm( acc_base + col ), Zmm( beta_reg ),
                         Zmm( tmp ) );
        }

        if( n_rem )
        {
            const dim_t tmp = load_start_idx + num_full_loads;
            // Masked (k4, zeroing) load of the trailing bf16 elements.
            vmovdqu16( Ymm( tmp ) | k4 | T_z,
                       ptr[rcx + num_full_loads * 32 ] );
            // Same bf16 -> f32 widening as above, register to register.
            vpmovsxwd( Zmm( tmp ), Ymm( tmp ) );
            vpslld( Zmm( tmp ), Zmm( tmp ), 0x10 );
            // acc += beta * C
            vfmadd231ps( Zmm( acc_base + num_full_loads ), Zmm( beta_reg ),
                         Zmm( tmp ) );
        }

        // Advance to the next row of the bf16 buffer.
        add( rcx, rax );
    }
}
// Emit the CLIP post-op: clamp every accumulator to [lower, upper].
// rdx points at the current lpgemm_post_op node; op_args2 / op_args3 are
// pointers to the lower and upper bound respectively.
// The first load register and the first broadcast register are reused to
// hold the broadcast bounds; rax and rbx are clobbered.
void bli_lpgemm_jit:: clip_f32( dim_t m_dim, dim_t n_dim )
{
    const dim_t lower_reg = load_start_idx;
    const dim_t upper_reg = bcst_start_idx;

    // Broadcast the lower bound.
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] );
    vbroadcastss( Zmm( lower_reg ), ptr[ rax ] );

    // Broadcast the upper bound.
    mov( rbx, ptr[ rdx + offsetof( lpgemm_post_op, op_args3 ) ] );
    vbroadcastss( Zmm( upper_reg ), ptr[ rbx ] );

    // acc = min( max( acc, lower ), upper ) for all accumulators.
    dim_t acc = fma_start_idx;
    while( acc < 32 )
    {
        vmaxps( Zmm( acc ), Zmm( acc ), Zmm( lower_reg ) );
        vminps( Zmm( acc ), Zmm( acc ), Zmm( upper_reg ) );
        ++acc;
    }
}
// Emit the MATRIX_ADD post-op for a bf16 addend matrix:
//     acc(row, col) += (float) M_bf16(row, col)
// rdx points at the current lpgemm_post_op node: op_args1 is the matrix
// pointer and op_args3 points at its leading dimension (ldm, in elements).
// rax, rbx, rcx and rsi are clobbered; the load registers are temporaries.
void bli_lpgemm_jit:: bf16_f32_matrix_add( dim_t m_dim, dim_t n_dim )
{
    // rcx = matrix base pointer.
    mov( rcx, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] );

    // rax = ldm (op_args3 is a pointer to the value, hence two loads).
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args3 ) ] );
    mov( rax, ptr[ rax ] );

    // ldm in bytes: rax *= sizeof(bfloat16).
    lea( rax, ptr[ rax * 2 ] );

    // rsi = post_op_c_i, rbx = post_op_c_j.
    mov( rsi, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );

    // rsi = post_op_c_i * ldm_bytes + post_op_c_j * sizeof(bfloat16)
    imul( rsi, rax );
    lea( rsi, ptr[ rsi + rbx * 2 ] );

    // rcx now addresses the first addend element of this tile.
    add( rcx, rsi );

    for( dim_t row = 0; row < m_dim; ++row )
    {
        const dim_t acc_base = fma_start_idx + ( row * num_loads );

        for( dim_t col = 0; col < num_full_loads; ++col )
        {
            const dim_t tmp = load_start_idx + col;
            // bf16 -> f32: widen to 32-bit lanes, then shift left by 16.
            vpmovsxwd( Zmm( tmp ), ptr[ rcx + col*32 ] );
            vpslld( Zmm( tmp ), Zmm( tmp ), 0x10 );
            vaddps( Zmm( acc_base + col ), Zmm( acc_base + col ),
                    Zmm( tmp ) );
        }

        if( n_rem )
        {
            const dim_t tmp     = load_start_idx + num_full_loads;
            const dim_t acc_reg = acc_base + num_full_loads;
            // Masked (k4, zeroing) load of the trailing bf16 elements.
            vmovdqu16( Ymm( tmp ) | k4 | T_z,
                       ptr[rcx + num_full_loads * 32 ] );
            // bf16 -> f32: widen to 32-bit lanes, then shift left by 16.
            vpmovsxwd( Zmm( tmp ), Ymm( tmp ) );
            vpslld( Zmm( tmp ), Zmm( tmp ), 0x10 );
            vaddps( Zmm( acc_reg ), Zmm( acc_reg ), Zmm( tmp ) );
        }

        // Advance to the next row of the addend matrix.
        add( rcx, rax );
    }
}
// Emit the MATRIX_ADD post-op for an f32 addend matrix:
//     acc(row, col) += M(row, col)
// rdx points at the current lpgemm_post_op node: op_args1 is the matrix
// pointer and op_args3 points at its leading dimension (ldm, in elements).
// rax, rbx, rcx and rsi are clobbered; the load registers are temporaries.
void bli_lpgemm_jit:: f32_f32_matrix_add( dim_t m_dim, dim_t n_dim )
{
    // rcx = matrix base pointer.
    mov( rcx, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] );

    // rax = ldm (op_args3 is a pointer to the value, hence two loads).
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args3 ) ] );
    mov( rax, ptr[ rax ] );

    // ldm in bytes: rax *= sizeof(float).
    lea( rax, ptr[ rax * 4 ] );

    // rsi = post_op_c_i, rbx = post_op_c_j.
    mov( rsi, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );

    // rsi = post_op_c_i * ldm_bytes + post_op_c_j * sizeof(float)
    imul( rsi, rax );
    lea( rsi, ptr[ rsi + rbx * 4] );

    // rcx now addresses the first addend element of this tile.
    add( rcx, rsi );

    for( dim_t row = 0; row < m_dim; ++row )
    {
        const dim_t acc_base = fma_start_idx + ( row * num_loads );

        for( dim_t col = 0; col < num_full_loads; ++col )
        {
            const dim_t tmp = load_start_idx + col;
            vmovups(Zmm( tmp ), ptr[ rcx + col * 64 ] );
            vaddps( Zmm( acc_base + col ), Zmm( acc_base + col ),
                    Zmm( tmp ) );
        }

        if( n_rem )
        {
            const dim_t tmp     = load_start_idx + num_full_loads;
            const dim_t acc_reg = acc_base + num_full_loads;
            // Masked (k4, zeroing) load of the trailing f32 elements.
            vmovups( Zmm( tmp ) | k4 | T_z,
                     ptr[ rcx + num_full_loads * 64 ] );
            vaddps( Zmm( acc_reg ), Zmm( acc_reg ), Zmm( tmp ) );
        }

        // Advance to the next row of the addend matrix.
        add( rcx, rax );
    }
}
// Emit the BIAS post-op for row-major output: one bias value per column,
// added to every row of the tile.
// rdx points at the current lpgemm_post_op node (op_args1 = bias vector).
// The generated code branches at run time on c_stor_type: the value 4
// selects the bf16 bias path (presumably the bf16 storage-type enum value -
// confirm), anything else the f32 path. Both paths leave the bias for
// column block n in Zmm( load_start_idx + n ) and meet at the common add.
// rax, rbx and rcx are clobbered.
void bli_lpgemm_jit:: bias_row_major( dim_t m_dim, dim_t n_dim )
{
    // rax = bias pointer, rbx = post_op_c_j, rcx = c_stor_type.
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );
    mov( rcx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, c_stor_type ) ] );
    cmp( rcx, 4 );
    je( "BIAS_BF16_ROW_MAJOR", T_NEAR );

    // ---- f32 bias ----
    // Byte offset of the first column: post_op_c_j * sizeof(float).
    lea( rbx, ptr[ rbx * 4 ] );
    add( rax, rbx );
    for( dim_t col = 0; col < num_full_loads; ++col )
    {
        vmovups( Zmm( load_start_idx + col ), ptr[ rax + col * 64 ] );
    }
    if( n_rem )
    {
        // Merge-masked load (no T_z): lanes outside k4 keep their old
        // contents.
        vmovups( Zmm( load_start_idx + num_full_loads ) | k4,
                 ptr[ rax + num_full_loads * 64 ] );
    }
    jmp( "POST_BIAS_BF16_ROW_MAJOR", T_NEAR );

    // ---- bf16 bias ----
    L( "BIAS_BF16_ROW_MAJOR" );
    // Byte offset of the first column: post_op_c_j * sizeof(bfloat16).
    lea( rbx, ptr[ rbx * 2 ] );
    add( rax, rbx );
    for( dim_t col = 0; col < num_full_loads; ++col )
    {
        const dim_t bias_reg = load_start_idx + col;
        // bf16 -> f32: widen to 32-bit lanes, then shift left by 16.
        vpmovsxwd( Zmm( bias_reg ), ptr[ rax + col * 32 ] );
        vpslld( Zmm( bias_reg ), Zmm( bias_reg ), 0x10 );
    }
    if( n_rem )
    {
        const dim_t bias_reg = load_start_idx + num_full_loads;
        // Masked (k4, zeroing) load of the trailing bf16 bias values.
        vmovdqu16( Ymm( bias_reg ) | k4 | T_z,
                   ptr[rax + num_full_loads * 32 ] );
        // bf16 -> f32: widen to 32-bit lanes, then shift left by 16.
        vpmovsxwd( Zmm( bias_reg ), Ymm( bias_reg ) );
        vpslld( Zmm( bias_reg ), Zmm( bias_reg ), 0x10 );
    }

    // ---- common tail: add the per-column bias to every row ----
    L( "POST_BIAS_BF16_ROW_MAJOR" );
    for( dim_t row = 0; row < m_dim; ++row )
    {
        const dim_t acc_base = fma_start_idx + ( row * num_loads );
        for( dim_t col = 0; col < num_loads; ++col )
        {
            vaddps( Zmm( acc_base + col ), Zmm( acc_base + col ),
                    Zmm( load_start_idx + col ) );
        }
    }
}
// Emit the BIAS post-op for column-major output: one bias value per row,
// broadcast and added to every accumulator of that row.
// rdx points at the current lpgemm_post_op node (op_args1 = bias vector).
// The generated code branches at run time on c_stor_type: the value 4
// selects the bf16 bias path (presumably the bf16 storage-type enum value -
// confirm), anything else the f32 path.
// Zmm( alpha_reg ) is reused as the broadcast temporary; rax, rbx and rcx
// are clobbered.
void bli_lpgemm_jit:: bias_col_major( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;

    // rax = bias pointer, rbx = post_op_c_i, rcx = c_stor_type.
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( rcx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, c_stor_type ) ] );
    cmp( rcx, 4 );
    je( "BIAS_BF16_COL_MAJOR", T_NEAR );

    // ---- f32 bias ----
    // postops_c_i *= sizeof(float)
    lea( rbx, ptr[ rbx * 4 ] );
    add( rax, rbx );
    for( dim_t m = 0; m < m_dim; m++ )
    {
        // Bias for row m lives m * sizeof(float) bytes past rax.
        vbroadcastss( Zmm( alpha_reg ), ptr[ rax + m * 4 ] );
        for( dim_t n = 0; n < num_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;
            vaddps( Zmm( reg_num ), Zmm( reg_num ), Zmm( alpha_reg ) );
        }
    }
    jmp( "POST_BIAS_BF16_COL_MAJOR", T_NEAR );

    // ---- bf16 bias ----
    L( "BIAS_BF16_COL_MAJOR" );
    // postops_c_i *= sizeof(bfloat16)
    lea( rbx, ptr[ rbx * 2 ] );
    add( rax, rbx );
    for( dim_t m = 0; m < m_dim; m++ )
    {
        // Bias for row m lives m * sizeof(bfloat16) bytes past rax.
        // (The element stride here must be 2 bytes, matching the scaling of
        // rbx above; a 4-byte stride would skip every other bias value.)
        vpbroadcastw( Zmm( alpha_reg ), ptr[ rax + m * 2 ] );
        // convert from 16 bit elements to 32 bit elements
        vpmovsxwd( Zmm( alpha_reg ), Ymm( alpha_reg ) );
        // Shift left by 16 bits so the bf16 bits form the f32 value.
        vpslld( Zmm( alpha_reg ), Zmm( alpha_reg ), 0x10 );
        for( dim_t n = 0; n < num_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;
            vaddps( Zmm( reg_num ), Zmm( reg_num ), Zmm( alpha_reg ) );
        }
    }
    L( "POST_BIAS_BF16_COL_MAJOR" );
}
// Emit the RELU post-op: acc = max( acc, 0 ) for every accumulator.
// The first broadcast register is cleared and used as the zero operand.
void bli_lpgemm_jit:: relu( dim_t m_dim, dim_t n_dim )
{
    const dim_t zero_reg = bcst_start_idx;
    vpxorq(Zmm( zero_reg ), Zmm( zero_reg ), Zmm( zero_reg ) );

    dim_t acc = fma_start_idx;
    while( acc < 32 )
    {
        vmaxps( Zmm( acc ), Zmm( acc ), Zmm( zero_reg ) );
        ++acc;
    }
}
// Emit the RELU_SCALE post-op: lanes that compare <= 0 are multiplied by a
// scale factor; other lanes are left unchanged.
// rdx points at the current lpgemm_post_op node; op_args2 points at the
// scale. rax and mask register k5 are clobbered, and the first load /
// broadcast registers are used as temporaries.
void bli_lpgemm_jit:: relu_scale( dim_t m_dim, dim_t n_dim )
{
    const dim_t zeros     = load_start_idx;
    const dim_t scale_reg = bcst_start_idx;

    // Broadcast the scale factor.
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] );
    vbroadcastss( Zmm( scale_reg ), ptr[ rax ] );

    // All-zero comparison operand.
    vpxorq( Zmm( zeros ), Zmm( zeros ), Zmm( zeros ) );

    dim_t acc = fma_start_idx;
    while( acc < 32 )
    {
        // Predicate 0x02 is "less than or equal": k5 = ( acc <= 0 ).
        vcmpps( k5, Zmm( acc ), Zmm( zeros ), 0x02 );
        // Scale only the selected lanes (merge-masking keeps the others).
        vmulps( Zmm( acc ) | k5, Zmm( acc ), Zmm( scale_reg ) );
        ++acc;
    }
}
void bli_lpgemm_jit::apply_post_ops_in_high_reg_pressure
(
const dim_t num_post_op_regs,
std::function< void( dim_t ) > op_fn
)
{
dim_t num_push_regs = num_post_op_regs - fma_start_idx ;
// If number of registers required to compute pots op is more than
// registers available, then push some accum registers to stack
// and use them to compute gelu.
store_zmms_in_stack( fma_start_idx, num_push_regs, 0 );
dim_t post_op_start = num_push_regs > 0 ? fma_start_idx + num_push_regs
: fma_start_idx;
// operate on non-pushed regs
for( dim_t reg = post_op_start; reg < 32; reg++ )
{
op_fn( reg );
}
// Push num_push_regs number of registers from last to stack and
// replace them with the items that were pushed earlier
// and compute on them.
store_zmms_in_stack( 32 - num_push_regs, num_push_regs,
num_push_regs * 64 );
get_zmms_from_stack( 32 - num_push_regs, num_push_regs, 0);
for( dim_t reg = 0; reg < num_push_regs; reg++ )
{
op_fn( 32 - num_push_regs + reg );
}
for( dim_t reg = 0; reg < num_push_regs; reg++ )
vmovups( Zmm( fma_start_idx + reg ),
Zmm( 32 - num_push_regs + reg ) );
get_zmms_from_stack( 32 - num_push_regs, num_push_regs,
num_push_regs * 64 );
}
// Emit the evaluation of a polynomial with 6 coefficients in r:
//     r = c0 + c1*r + c2*r^2 + c3*r^3 + c4*r^4 + c5*r^5
// where c_i = get_constant( lpgemm_exp_off, i ).
//r2 and z, q are scratch regs
//r will be passed in and out of parent function.
// const1 and const2 are also overwritten.
void bli_lpgemm_jit:: POLY_EVAL_6_AVX512( )
{
// r2 = r^2
vmulps( Zmm( r2 ), Zmm( r ), Zmm( r ) );
// q = c2 + c3 * r
vbroadcastss( Zmm( const1 ), get_constant(lpgemm_exp_off, 3) );
vbroadcastss( Zmm( const2 ), get_constant(lpgemm_exp_off, 2) );
vmovups( Zmm( q ), Zmm( const2 ) );
vfmadd231ps( Zmm( q ), Zmm( const1 ), Zmm( r ) );
// z = c0 + c1 * r
vbroadcastss( Zmm( const1 ), get_constant(lpgemm_exp_off, 1) );
vbroadcastss( Zmm( const2 ), get_constant(lpgemm_exp_off, 0) );
vmovups( Zmm( z ), Zmm( const2 ) );
vfmadd231ps( Zmm( z ), Zmm( const1 ), Zmm( r ) );
// z += r^2 * q   (adds the c2 and c3 terms)
vfmadd231ps( Zmm( z ), Zmm( r2 ), Zmm( q ) );
// r2 = r^4
vmulps(Zmm( r2 ), Zmm( r2 ), Zmm( r2 ) );
// const2 = c4 + c5 * r
vbroadcastss( Zmm( const1 ), get_constant(lpgemm_exp_off, 5) );
vbroadcastss( Zmm( const2 ), get_constant(lpgemm_exp_off, 4) );
vfmadd231ps( Zmm( const2 ), Zmm( const1 ), Zmm( r ) );
// z += r^4 * const2   (adds the c4 and c5 terms)
vfmadd231ps( Zmm( z ), Zmm( const2 ), Zmm( r2 ) );
// Hand the result back in r.
vmovups(Zmm( r ), Zmm( z ) );
}
// Emit a vectorised exponential approximation.
// z, r, dn is a scratch register
// takes 'x' as input and returns 'q' to the parent
// const1, const2, r2 (via POLY_EVAL_6_AVX512) and mask k5 are overwritten
// as well, and x itself is clobbered.
// K_i below denotes get_constant( gelu_macros_off, i ); the roles given in
// parentheses are inferred from how the values are used - confirm against
// the constant table.
void bli_lpgemm_jit:: EXPF_AVX512()
{
// z = x * K0   (presumably a log2(e)-style scaling)
vbroadcastss( Zmm( const1 ), get_constant(gelu_macros_off, 0) );
vmulps( Zmm( z ), Zmm( x ), Zmm(const1 ) );
// dn = z + K1; r = z - ( dn - K1 )
// (presumably K1 is a large "magic" constant that rounds z to an integer,
//  leaving the fractional remainder in r)
vbroadcastss( Zmm( const2 ), get_constant(gelu_macros_off, 1) );
vaddps( Zmm( dn ), Zmm( z ), Zmm( const2 ) );
vsubps( Zmm( r ), Zmm( dn ), Zmm( const2 ) );
vsubps( Zmm( r ), Zmm( z ), Zmm( r ) );
// r = polynomial( r )
POLY_EVAL_6_AVX512();
// Integer ops: shift dn left by 23 bits (the position of the f32 exponent
// field) and add it to the bit pattern of r; the result is q.
vpslld( Zmm( dn ), Zmm( dn ), 0x17 );
vpaddd( Zmm( q ), Zmm( r ), Zmm( dn ) );
// const2 = 0
vpxorq( Zmm( const2 ), Zmm( const2 ), Zmm( const2 ) );
// Predicate 0x06 is "not less than or equal": k5 = !( K2 <= x ).
// In those lanes q is and-ed with 0, i.e. forced to 0
// (presumably K2 is the underflow threshold).
vpbroadcastd( Zmm( const1 ), get_constant(gelu_macros_off, 2) );
vcmpps( k5, Zmm( const1 ), Zmm( x ), 0x06 );
vpandd( Zmm( q ) | k5, Zmm( q ), Zmm( const2 ) );
// k5 = !( K3 <= x ). x is then overwritten with K4 in all lanes and the
// lanes selected by k5 receive q (q xor 0), so lanes where K3 <= x end up
// holding K4 (presumably the overflow threshold and +inf respectively).
vbroadcastss( Zmm( const1 ), get_constant(gelu_macros_off, 3) );
vcmpps( k5, Zmm( const1 ), Zmm( x ), 0x06 );
vbroadcastss( Zmm( x ), get_constant(gelu_macros_off, 4) );
vpxord( Zmm( x ) | k5, Zmm( q ), Zmm( const2 ) );
// Return the result in q.
vmovups(Zmm( q ), Zmm( x ) );
}
// Emit a vectorised tanh approximation built on EXPF_AVX512.
// uses z, dn, r as scratch regs
// passes x to the child routine (EXPF_AVX512) and gets q back
// takes x_tanh as input and gives back x_tanh
// const1, const2, x, q, mask k5, eax and ebx are overwritten as well.
// C_i below denotes get_constant( gelu_consts_off, i ).
void bli_lpgemm_jit:: TANHF_AVX512()
{
    // x = |x_tanh| * C2 : clear the sign bit with 0x7FFFFFFF, then scale.
    vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 2) );
    mov( ebx, 0x7FFFFFFF );
    vpbroadcastd( Zmm( const2 ), ebx );
    vpandd( Zmm( x ), Zmm( x_tanh ), Zmm( const2 ) );
    vmulps( Zmm( x ), Zmm( x ), Zmm( const1 ) );

    // q = exp-approximation( x )
    EXPF_AVX512();

    // z = q + C4
    // (The original emitted `mov eax, -1` here; eax was overwritten by the
    //  sign-mask load below before ever being read, so that dead
    //  instruction is no longer generated.)
    vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 4) );
    vaddps( Zmm( z ), Zmm( q ), Zmm( const1 ) );

    // r = z + C5 ; z = z / r
    vbroadcastss( Zmm( const2 ), get_constant(gelu_consts_off, 5) );
    vaddps( Zmm( r ), Zmm( z ), Zmm( const2 ) );
    vdivps( Zmm( z ), Zmm( z ), Zmm( r ) );

    // z *= C4 (const1 still holds C4)
    vmulps( Zmm( z ), Zmm( z ), Zmm( const1 ) );

    // Re-apply the sign of the input: isolate the sign bit of x_tanh
    // (mask 0x80000000) and xor it into z.
    mov( eax, -2147483648 );
    vpbroadcastd( Zmm( const1 ), eax );
    vpandd(Zmm( q ), Zmm( x_tanh ), Zmm( const1 ) );
    vpxord( Zmm( x_tanh ), Zmm( q ), Zmm( z ) );
}
// Emit the tanh-based GELU for one accumulator register, in place:
//     reg = C3 * reg * ( tanh( C1 * ( reg + C0 * reg^3 ) ) + C6 )
// where C_i = get_constant( gelu_consts_off, i ) and tanh is the
// approximation emitted by TANHF_AVX512().
// Clobbers r, r2, x_tanh, const1, const2 and everything TANHF_AVX512 uses.
void bli_lpgemm_jit:: GELU_TANH_F32_AVX512_DEF(dim_t reg )
{
// r2 = reg^3
vmulps( Zmm( r2 ), Zmm( reg ), Zmm( reg ) );
vmulps( Zmm( r2 ), Zmm( r2 ), Zmm( reg ) );
// r = reg + C0 * reg^3
vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 0) );
vmovups( Zmm( r ), Zmm( reg ) );
vfmadd231ps( Zmm( r ), Zmm( r2 ), Zmm( const1 ) );
// x_tanh = C1 * r
vbroadcastss( Zmm( const2 ), get_constant(gelu_consts_off, 1) );
vmulps( Zmm( x_tanh ), Zmm( r ), Zmm( const2 ) );
// x_tanh = tanh( x_tanh )
TANHF_AVX512();
// x_tanh = ( x_tanh + C6 ) * reg
vbroadcastss( Zmm( const2 ), get_constant(gelu_consts_off, 6) );
vaddps( Zmm( x_tanh ), Zmm( x_tanh ), Zmm( const2 ) );
vmulps( Zmm( x_tanh ), Zmm( x_tanh ), Zmm( reg ) );
// reg = x_tanh * C3
vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 3) );
vmulps( Zmm( reg ), Zmm( x_tanh ), Zmm( const1 ) );
}
// Emit the GELU (tanh variant) post-op for every accumulator.
// The per-register routine needs num_gelu_regs scratch registers, so the
// spill-aware driver is used to run it.
void bli_lpgemm_jit:: gelu_tanh( dim_t m_dim, dim_t n_dim )
{
    apply_post_ops_in_high_reg_pressure
    (
      num_gelu_regs,
      [ this ]( dim_t acc_reg ) { GELU_TANH_F32_AVX512_DEF( acc_reg ); }
    );
}
// Emit a Horner-scheme evaluation of a polynomial with 16 coefficients
// e_0 .. e_15 = get_constant( lpgemm_erf_off, i ), followed by one extra
// multiplication by r:
//     x = r * ( e_0 + r * ( e_1 + ... + r * e_15 ) )
// Input is r, output is x. const1 and const2 are used alternately as the
// running accumulator so that each newly loaded coefficient can be the
// destination of the next fused multiply-add.
void bli_lpgemm_jit:: POLY_EVAL_HORNER_16_0_AVX512()
{
    auto acc_reg  = const1;   // holds the running Horner value
    auto next_reg = const2;   // receives the next coefficient

    // Seed the recurrence with the highest-order coefficient.
    vbroadcastss( Zmm( acc_reg ), get_constant(lpgemm_erf_off, 15) );

    for( int coeff = 14; coeff >= 0; --coeff )
    {
        // next = e_coeff + r * acc
        vbroadcastss( Zmm( next_reg ), get_constant(lpgemm_erf_off, coeff) );
        vfmadd231ps( Zmm( next_reg ), Zmm( r ), Zmm( acc_reg ) );

        // The freshly computed value becomes the accumulator.
        const auto previous_acc = acc_reg;
        acc_reg  = next_reg;
        next_reg = previous_acc;
    }

    // After the 15 steps above the accumulator is const2.
    vmulps( Zmm( x ), Zmm( acc_reg ), Zmm( r ) );
}
// Emit a vectorised erf approximation.
// Takes x_erf as input and returns the result in x_erf.
// Clobbers r, x, const1, const2, mask k5 and eax.
// E_i below denotes get_constant( erf_consts_off, i ); the roles given in
// parentheses are inferred from usage - confirm against the constant table.
void bli_lpgemm_jit:: ERF_AVX512()
{
// r = |x_erf| (clear the sign bit with 0x7FFFFFFF)
mov( eax, 0x7FFFFFFF );
vpbroadcastd( Zmm( const2 ), eax );
vpandd( Zmm( r ), Zmm( x_erf ), Zmm( const2 ) );
// x = r * polynomial( r )
POLY_EVAL_HORNER_16_0_AVX512();
// Predicate 0x06 is "not less than or equal": k5 = !( E3 <= r ).
// const1 starts as E1 and takes x in the lanes selected by k5, so lanes
// where E3 <= r keep E1 (presumably E3 is the cut-off beyond which the
// result saturates to E1).
vbroadcastss( Zmm( const1 ), get_constant(erf_consts_off, 1) );
vbroadcastss( Zmm( const2 ), get_constant(erf_consts_off, 3) );
vcmpps( k5, Zmm( const2 ), Zmm( r ), 0x06 );
vpxorq( Zmm( const2 ), Zmm( const2 ), Zmm( const2 ) );
vpxord( Zmm( const1 ) | k5, Zmm( x ), Zmm( const2 ) );
vmovups( Zmm( x ), Zmm( const1 ) );
// Second clamp: k5 = !( E1 <= x ); const1 = x in those lanes, E1 elsewhere.
vbroadcastss( Zmm( const1 ), get_constant(erf_consts_off, 1) );
vcmpps( k5, Zmm( const1 ), Zmm( x ), 0x06 );
vpxord( Zmm( const1 ) | k5, Zmm( x ), Zmm( const2 ) );
// Keep only the sign bit of the original input and or it onto the
// magnitude computed above.
mov( eax, ~(0x7FFFFFFF));
vpbroadcastd( Zmm( const2 ), eax );
vpandd( Zmm( x_erf ), Zmm( x_erf ), Zmm( const2 ) );
vpord( Zmm( x_erf ), Zmm( x_erf ), Zmm( const1 ) );
}
// Emit the erf-based GELU for one accumulator register, in place:
//     reg = E2 * reg * ( erf( E0 * reg ) + E1 )
// where E_i = get_constant( erf_consts_off, i ) and erf is the
// approximation emitted by ERF_AVX512().
// Clobbers x_erf, const1, const2 and everything ERF_AVX512 uses.
void bli_lpgemm_jit:: GELU_ERF_F32_AVX512_DEF( dim_t reg )
{
// x_erf = reg * E0
vbroadcastss( Zmm( const1 ), get_constant(erf_consts_off, 0) );
vmulps( Zmm( x_erf ), Zmm( reg ), Zmm( const1 ) );
// x_erf = erf( x_erf )
ERF_AVX512();
// x_erf = ( x_erf + E1 ) * reg
vbroadcastss( Zmm( const2 ), get_constant(erf_consts_off, 1) );
vaddps( Zmm( x_erf ), Zmm( x_erf ), Zmm( const2 ) );
vmulps( Zmm( x_erf ), Zmm( x_erf ), Zmm( reg ) );
// reg = x_erf * E2
vbroadcastss( Zmm( const2 ), get_constant(erf_consts_off, 2) );
vmulps( Zmm( reg ), Zmm( x_erf ), Zmm( const2 ) );
}
// Emit the GELU (erf variant) post-op for every accumulator.
// The per-register routine is run through the spill-aware driver with
// num_gelu_regs scratch registers reserved.
void bli_lpgemm_jit:: gelu_erf( dim_t m_dim, dim_t n_dim )
{
    apply_post_ops_in_high_reg_pressure
    (
      num_gelu_regs,
      [ this ]( dim_t acc_reg ) { GELU_ERF_F32_AVX512_DEF( acc_reg ); }
    );
}
// Emit the SWISH post-op for one accumulator register, in place:
//     reg = reg / ( exp( -( reg * scale ) ) + C6 )
// where scale is expected to be already broadcast in Zmm( x_tanh )
// (swish() does this), exp is the approximation emitted by EXPF_AVX512()
// and C6 = get_constant( gelu_consts_off, 6 ).
// Clobbers x, q, const1 and everything EXPF_AVX512 uses.
void bli_lpgemm_jit::SWISH_F32_AVX512_DEF( dim_t reg )
{
// x = 0, then x = -( reg * x_tanh ) + x
vpxorq( Zmm( x ), Zmm( x ), Zmm( x ) );
vfnmadd231ps( Zmm( x ), Zmm( reg ), Zmm( x_tanh ) );
// Input reg x and output reg q.
EXPF_AVX512();
// q += C6
vbroadcastss( Zmm( const1 ), get_constant(gelu_consts_off, 6) );
vaddps( Zmm( q ), Zmm( q ), Zmm( const1 ) );
// reg = reg / q
vdivps( Zmm( reg ), Zmm( reg ), Zmm( q ) );
}
// Emit the SWISH post-op for every accumulator.
// rdx points at the current lpgemm_post_op node; op_args2 points at the
// scale value, which is broadcast into Zmm( x_tanh ) once and then read by
// the per-register routine. rax is clobbered.
void bli_lpgemm_jit::swish( dim_t m_dim, dim_t n_dim )
{
    // Broadcast the scale value.
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] );
    vbroadcastss( Zmm( x_tanh ), ptr[ rax ] );

    apply_post_ops_in_high_reg_pressure
    (
      num_gelu_regs,
      [ this ]( dim_t acc_reg ) { SWISH_F32_AVX512_DEF( acc_reg ); }
    );
}
// Emit the stores that write the f32 accumulators to C.
// rcx points at the current row of C and is advanced by rdi between rows.
// Full 64-byte stores are used for complete registers; the trailing
// partial register (n_rem != 0) is written under mask k4.
void bli_lpgemm_jit:: store_f32( dim_t m_dim, dim_t n_dim )
{
    for( dim_t row = 0; row < m_dim; ++row )
    {
        // Step to the next row of C (not needed before the first row).
        if( row != 0 ) add( rcx, rdi );

        // First accumulator register belonging to this row.
        const dim_t acc_base = fma_start_idx + ( row * num_loads );

        for( dim_t col = 0; col < num_full_loads; ++col )
        {
            vmovups( ptr[ rcx + col * 64 ], Zmm( acc_base + col ) );
        }

        // Masked store for the n-fringe part.
        if( n_rem )
        {
            vmovups( ptr[ rcx + num_full_loads * 64 ] | k4,
                     Zmm( acc_base + num_full_loads ) );
        }
    }
}
// Emit code that converts the f32 accumulators to bf16 and writes them to
// the downscale buffer.
// The buffer pointer, its row stride and the (post_op_c_i, post_op_c_j)
// tile origin are read back from the stack. The accumulators are converted
// in place (their lower 256 bits then hold the bf16 values). rax, rbx, rcx
// and rsi are clobbered.
void bli_lpgemm_jit:: cvt_store_f32_bf16_mask( dim_t m_dim, dim_t n_dim )
{
    // rcx = buffer base, rax = row stride (elements).
    mov( rcx, ptr[ rsp + stack_off_buf_downscale ] );
    mov( rax, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, rs_c_downscale ) ] );

    // Row stride in bytes: rax *= sizeof(bfloat16).
    lea( rax, ptr[rax * 2 ] );

    // rsi = post_op_c_i, rbx = post_op_c_j.
    mov( rsi, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );

    // rsi = post_op_c_i * row_stride_bytes + post_op_c_j * sizeof(bfloat16)
    imul( rsi, rax );
    lea( rsi, ptr[ rsi + rbx * 2 ] );

    // rcx now addresses the first output element of this tile.
    add( rcx, rsi );

    for( dim_t row = 0; row < m_dim; ++row )
    {
        const dim_t acc_base = fma_start_idx + ( row * num_loads );

        for( dim_t col = 0; col < num_full_loads; ++col )
        {
            const dim_t acc_reg = acc_base + col;
            // 16 x f32 -> 16 x bf16, then store the 32 resulting bytes.
            vcvtneps2bf16( Ymm( acc_reg ), Zmm( acc_reg ) );
            vmovdqu16( ptr[ rcx + col * 32 ], Ymm( acc_reg ) );
        }

        if( n_rem )
        {
            const dim_t acc_reg = acc_base + num_full_loads;
            // Same conversion; only the lanes enabled in k4 are written.
            vcvtneps2bf16( Ymm( acc_reg ), Zmm( acc_reg ) );
            vmovdqu16( ptr[ rcx + num_full_loads * 32 ] | k4, Ymm( acc_reg ) );
        }

        // Advance to the next row of the bf16 buffer.
        add( rcx, rax );
    }
}
// Emit the code that unpacks the kernel arguments into registers and into
// the kernel's stack frame.
//
// At run time the generated code reads:
//   rdi - pointer to a lpgemm_jit_params_t,
//   rsi - pointer to a lpgemm_post_op_attr,
//   rdx - head of the post-op list.
// `params` is consulted only at code-generation time (m_loop and
// generate_mask decide which instructions are emitted).
//
// Statement order matters: rdi is overwritten with rs_c part-way through,
// so every read through rdi must be emitted before that point.
// rbx and rcx are used as temporaries for the memory-to-memory copies.
void bli_lpgemm_jit::initialize_params( lpgemm_jit_inputs_t* params )
{
// params needed in kernel
// a(r14, rax), b(rbx), c(r12, rcx) pointers. To be stored in regs
// rs_a(r8), cs_a(r9), rs_b(r10), rs_c(rdi).
// alpha(rax), beta(rbx) values. To be pushed to stack
// m_iter(r11), ps_a(rax) values. ps_a to be pushed to stack.
// k_iter(rsi), k_left(rsi) value. To be pushed to stack.
// load values from params struct to registers and stack
if( params->m_loop )
{
// move address of a
// (kept in r14 together with the m iteration count in r11)
mov( r14, ptr[ rdi + offsetof( lpgemm_jit_params_t, a ) ] );
mov( r11, ptr[ rdi + offsetof( lpgemm_jit_params_t, m_iter ) ] );
}
else
{
// Without an m loop the address of a goes straight into rax.
mov( rax, ptr[ rdi + offsetof(lpgemm_jit_params_t, a ) ] );
}
if( params->generate_mask )
{
// This mask will be used to load/store bf16 elements
kmovd( k3, ptr[ rdi + offsetof( lpgemm_jit_params_t, mask16 ) ] );
// This mask will be used to load/store f32 elements
kmovw( k4, ptr[ rdi + offsetof(lpgemm_jit_params_t, mask32 ) ] );
}
// c pointer and the a / b strides stay in registers.
mov( r12, ptr[ rdi + offsetof( lpgemm_jit_params_t, c ) ] );
mov( r8, ptr[ rdi + offsetof( lpgemm_jit_params_t, rs_a ) ] );
mov( r9, ptr[ rdi + offsetof( lpgemm_jit_params_t, cs_a ) ] );
mov( r10, ptr [rdi + offsetof( lpgemm_jit_params_t, rs_b ) ] );
// Push all the params that will be required in later stages
// of kernel to stack.
// Pushing in order ps_a2, k_iter (before/after prefetch), k_left,
// alpha, beta, b. Each value is copied through rbx into its fixed
// rsp-relative slot (plain mov, not a push instruction).
mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, ps_a2 ) ] );
mov( ptr[ rsp + stack_off_ps_a ], rbx);
mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t,
k_iter_before_prefetch ) ] );
mov( ptr[ rsp + stack_off_k_iter_before_prefetch ], rbx );
mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t,
k_iter_after_prefetch ) ] );
mov( ptr[ rsp + stack_off_k_iter_after_prefetch ], rbx );
mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, k_left ) ] );
mov( ptr[ rsp + stack_off_k_left ], rbx );
mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, alpha ) ] );
mov( ptr[ rsp + stack_off_alpha ], rbx );
mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, beta ) ] );
mov( ptr[ rsp + stack_off_beta ], rbx );
mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, b ) ] );
mov( ptr[ rsp + stack_off_b_ptr ], rbx );
// once all the params that will be required in
// later stages of kernel are pushed to stack,
// move rs_c into rdi.
// (rdi no longer points at the params struct after this instruction.)
mov( rdi, ptr[ rdi + offsetof( lpgemm_jit_params_t, rs_c ) ] );
// push all members of lpgemm_post_op_attr struct to stack.
// Since this will be passed as 2nd arg to the function, it will be in rsi
// The copy on the stack uses the same member offsets, based at
// stack_off_postop; rbx and rcx alternate as the temporary.
mov( rbx, ptr[ rsi + offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, post_op_c_i ) ], rbx );
mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, post_op_c_j ) ], rcx );
mov( rbx, ptr[ rsi + offsetof( lpgemm_post_op_attr, rs_c_downscale ) ] );
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, rs_c_downscale)], rbx );
mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, cs_c_downscale ) ] );
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, cs_c_downscale)], rcx );
// buf_downscale is stored at its own slot (stack_off_buf_downscale)
// rather than at stack_off_postop + offsetof( ..., buf_downscale ).
mov( rbx, ptr[ rsi + offsetof(lpgemm_post_op_attr, buf_downscale ) ] );
mov( ptr[ rsp + stack_off_buf_downscale ], rbx );
mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, is_first_k ) ] );
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, is_first_k ) ], rcx );
mov( rbx, ptr[ rsi + offsetof(lpgemm_post_op_attr, is_last_k ) ] );
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, is_last_k ) ], rbx );
mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, c_stor_type ) ] );
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, c_stor_type ) ], rcx );
mov( rbx, ptr[ rsi + offsetof(lpgemm_post_op_attr, b_sum_offset)]);
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, b_sum_offset )] , rbx );
mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, b_col_sum_vec ) ] );
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, b_col_sum_vec ) ], rcx );
mov( rbx, ptr[ rsi +
offsetof( lpgemm_post_op_attr, b_col_sum_vec_s16 ) ] );
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, b_col_sum_vec_s16 ) ], rbx );
// Storing the address to the head node of post-op list in stack
// It needs to be restored after every loop of m_iter
mov( ptr[ rsp + stack_off_temp_list ], rdx );
// initialize top of zmm stack
// (generator-side state, used by store_zmms_in_stack / get_zmms_from_stack;
// no instruction is emitted for this assignment).
zmm_stack_top = stack_off_zmm_stack;
}
// Emit prefetcht1 instructions covering the C tile: one prefetch per
// 64-byte block of each row. rcx points at the current row of C and is
// advanced by rdi between rows (so rcx is left pointing at the last row).
void bli_lpgemm_jit:: prefetchC( dim_t m_dim, dim_t n_dim )
{
    for( dim_t row = 0; row < m_dim; ++row )
    {
        // Step to the next row of C (not needed before the first row).
        if( row != 0 ) add( rcx, rdi );

        dim_t block = 0;
        while( block < num_loads )
        {
            prefetcht1( ptr[ rcx + block * 64 ] );
            ++block;
        }
    }
}
// Emit code that advances rdx to the next node of the post-op list
// (rdx = rdx->next) and then dispatches on that node's op code.
void bli_lpgemm_jit:: post_op_label_lastk_safe_jump_with_next_ptr()
{
    // Follow the `next` link of the current post-op node.
    const auto next_link = ptr[rdx+offsetof( lpgemm_post_op, next ) ];
    mov( rdx, next_link );

    // Jump to the handler of the new node (or to the exit label).
    post_op_label_lastk_safe_jump();
}
// Emit the post-op dispatcher: a chain of compare / jump-if-equal pairs
// that transfers control to the handler label matching the op code of the
// post-op node in rdx. A null node, or an op code of POST_OPS_DISABLE,
// jumps to "POST_OPS_6x64_DISABLE". If no op code matches, execution falls
// through to whatever code is emitted next. rax is clobbered.
void bli_lpgemm_jit:: post_op_label_lastk_safe_jump()
{
    // check if post_ops_list_temp != NULL
    cmp( rdx, 0 );
    je( "POST_OPS_6x64_DISABLE", T_NEAR );

    // rax = op code of the current node.
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_code ) ] );

    // Op code -> handler label, tested in this order.
    struct handler_entry
    {
        unsigned int code;
        const char*  label;
    };
    const handler_entry handlers[] =
    {
        { POST_OPS_DISABLE,    "POST_OPS_6x64_DISABLE"    },
        { POST_OPS_BIAS,       "POST_OPS_BIAS_6x64"       },
        { POST_OPS_RELU,       "POST_OPS_RELU_6x64"       },
        { POST_OPS_RELU_SCALE, "POST_OPS_RELU_SCALE_6x64" },
        { POST_OPS_GELU_TANH,  "POST_OPS_GELU_TANH_6x64"  },
        { POST_OPS_GELU_ERF,   "POST_OPS_GELU_ERF_6x64"   },
        { POST_OPS_CLIP,       "POST_OPS_CLIP_6x64"       },
        { POST_OPS_DOWNSCALE,  "POST_OPS_DOWNSCALE_6x64"  },
        { POST_OPS_MATRIX_ADD, "POST_OPS_MATRIX_ADD_6x64" },
        { POST_OPS_SWISH,      "POST_OPS_SWISH_6x64"      },
    };

    for( const handler_entry& entry : handlers )
    {
        cmp( rax, entry.code );
        je( entry.label, T_NEAR );
    }
}
// Constructor
//
// Builds the code generator on top of a caller-supplied buffer instead of
// letting the base class allocate one.
//
//   buffer     - memory that will receive the generated code
//   bufferSize - size of that memory in bytes
bli_lpgemm_jit:: bli_lpgemm_jit( void* buffer, size_t bufferSize )
    : CodeGenerator( bufferSize, buffer )
{
    // Ask for read/write/execute permission on the caller's buffer so the
    // generated code can be both written and run.
    // Note: the result of protect() is not checked here.
    protect( buffer, bufferSize, PROTECT_RWE );
}
// Main kernel function body
//
// Emits the complete micro-kernel for an MR x NR tile, where MR and NR are
// taken from params. The emitted code, in order:
//   1. prologue, 512 bytes of stack scratch space, parameter setup
//   2. stride scaling from elements to bytes
//   3. (optional, params->m_loop) loop over m iterations
//   4. zeroing of accumulators, k loop before the C prefetch, prefetch of C,
//      k loop after the prefetch, k fringe
//   5. (optional, params->alpha_scale) alpha scaling, then beta scaling
//   6. post-op chain (only when is_last_k is non-zero)
//   7. store of the tile as bf16 or f32
//   8. epilogue, followed by the constant tables embedded after the code
//
// Side effect on params: when params->NR is 0, params->generate_mask is set
// to TRUE.
void bli_lpgemm_jit::generate_kernel( lpgemm_jit_inputs_t* params )
{
dim_t m_dim = params->MR;
dim_t n_dim = params->NR;
// In kernel-function pointer array, kernels to handle n < 16
// are stored at col-index 0. Hacking n_dim to some value 0 < value < 16
// so masked instructions are generated.
// This will be removed when we support on-the-fly generation of kernels.
if( n_dim == 0 )
{
n_dim = 2;
params->generate_mask = TRUE;
}
// Number of trailing columns that need a masked load.
// NOTE(review): n_rem uses NUM_F32_ELEMS_PER_ZMM while num_full_loads uses
// num_elems_per_reg; presumably both are the same value - confirm.
n_rem = n_dim % NUM_F32_ELEMS_PER_ZMM;
// Number of loads that do not require a mask
num_full_loads = ( n_dim / num_elems_per_reg );
// Number of loads in total = full loads + mask load (if required)
num_loads = ( num_full_loads ) + ( n_rem > 0 ? 1 : 0 );
// Total number of registers to store accumulated values.
num_fma_regs = m_dim * num_loads;
// calculating start index for accumulation registers.
// If the kernel requires 'x' number of accumulation regs, we use the
// last 'x' ZMMs available on certain architecture.
// 31 is hardcoded here since we only support AVX-512 as of now,
// This needs to be made as a configurable parameter later.
fma_start_idx = 31 - num_fma_regs + 1;
// If a kernel requires x registers for loads, we always use the
// first 'x' ZMM registers available for loads.
// And the immediate registers next to load regs are used for broadcast.
bcst_start_idx = load_start_idx + num_loads;
// While scaling the accumulated registers with beta,
// load regs will be used to load C matrix,
// Hence using broadcast register to store beta value.
beta_reg = bcst_start_idx;
preamble();
// reserve some space on the stack to store params
// (released again by the matching add( rsp, 512 ) before postamble())
sub( rsp, 512 );
// Initialize all the parameters required for execution of kernel.
// load some values to registers and push the rest of them to stack.
initialize_params( params );
/* register usage:
r14, rax - pointer for A matrix
r8 - rs_a
r9 - cs_a
r13 - 3 * rs_a
r15 - 5 * rs_a
rbx - pointer to B matrix, beta
r10 - rs_b
r12, rcx - pointer for C matrix
rdi - rs_c
r11 - m_iter
rsi - k_iter, k_left
rax - ps_a2, alpha
*/
lea( rdi, ptr[ rdi * 4 ] ); // rs_c *= sizeof(float) => rs_c *= 4
lea( r8, ptr[ r8 * 2 ] ); // rs_a *= sizeof(dt) => rs_a *= 2
lea( r9, ptr[ r9 * 2 ] ); // cs_a *= sizeof(dt) => cs_a *= 2
// The 3*rs_a and 5*rs_a multiples are only materialized for tiles with
// at least 4 and 6 rows respectively.
if ( m_dim >= 4)
lea( r13, ptr[r8 + r8 * 2 ] ); // r13 = 3 * rs_a
if( m_dim >= 6 )
lea( r15, ptr[r8 + r8 * 4 ] ); // r15 = 5 * rs_a
lea( r10, ptr[ r10 * 2 ] ); // rs_b *= sizeof(dt) => rs_b *= 2
mov( rcx, r12 ); // rcx = working copy of the C pointer
// The m loop label and the per-iteration reset of rax are only emitted
// when the kernel iterates over m itself.
if( params->m_loop )
{
L( "BLOOP6X64I" );
mov( rax, r14 ); // reset rax to current upanel of a.
}
mov( rbx, ptr[ rsp + stack_off_b_ptr ] ); // move address of b
// Zero all the registers that will be used for accumulation.
reg_init( m_dim, n_dim );
// load k_iter
mov( rsi, ptr[ rsp + stack_off_k_iter_before_prefetch ] );
// skip the loop entirely when the iteration count is zero
test( rsi, rsi );
je( "BPREFETCH", T_NEAR );
L( "BLOOPKITER" );
// Main k-unroll loop
kernel_unroll( m_dim, n_dim );
dec( rsi ); // i -= 1
jne("BLOOPKITER", T_NEAR );
L( "BPREFETCH" );
// Prefetch the C tile. Note that prefetchC() advances rcx while stepping
// through the rows, so rcx is reloaded from r12 before it is used as the
// C pointer again.
prefetchC( m_dim, n_dim );
// k iterations that run after the C prefetch has been issued
mov( rsi, ptr[ rsp + stack_off_k_iter_after_prefetch ] );
test( rsi, rsi );
je( "BCONSIDKLEFT", T_NEAR );
L( "AFTERPREFETCH" );
kernel_unroll( m_dim, n_dim );
dec( rsi );
jne( "AFTERPREFETCH", T_NEAR );
L( "BCONSIDKLEFT" );
// load k_left
mov( rsi, ptr[ rsp + stack_off_k_left ] );
test( rsi, rsi );
je( "BPOSTACCUM", T_NEAR );
// k_fringe
k_fringe_loop( m_dim, n_dim );
L( "BPOSTACCUM" );
// Generate alpha scaling code only when required.
if( params->alpha_scale )
{
mov( rax, ptr[ rsp + stack_off_alpha ] ); // load address of alpha
vbroadcastss( Zmm( alpha_reg ), ptr[ rax ] );
scale_alpha( m_dim, n_dim );
}
mov( rbx, ptr[ rsp + stack_off_beta ] ); // load address of beta
vbroadcastss( Xmm( beta_reg ), ptr[ rbx ] ); // beta value for the scalar compare below
// Zero out a register (alpha_reg is reused as scratch from here on)
vxorps( Xmm( alpha_reg ), Xmm( alpha_reg ) );
// cmp beta value with zero
vucomiss( Xmm( beta_reg ), Xmm( alpha_reg ) );
// if beta=0, skip beta scaling
// (vucomiss also sets ZF for an unordered compare, so a NaN beta takes
// this branch as well)
je( "BPOSTBETAOP", T_NEAR );
// Choose the beta-scaling path: the bf16 path is used only when a
// downscale buffer exists and is_first_k is non-zero; otherwise the
// f32 path is used.
// check if buf_downscale is NULL
mov( rax, ptr[ rsp + stack_off_buf_downscale ] );
cmp( rax, 0 );
je( "BETAOP", T_NEAR );
// Check if is_first_k is 0
mov( rcx, ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, is_first_k ) ] );
test( rcx, rcx );
je( "BETAOP", T_NEAR );
L( "DOWNSCALEBETAOP" );
// full-width broadcast of beta (only the XMM form was loaded above)
vbroadcastss( Zmm( beta_reg ), ptr[ rbx ] );
bf16_f32_beta_op( m_dim, n_dim );
jmp( "BPOSTBETAOP", T_NEAR );
L( "BETAOP" );
// rcx was clobbered above (advanced by prefetchC and reused as scratch);
// reload the C pointer.
mov( rcx, r12 );
vbroadcastss( Zmm( beta_reg ), ptr[ rbx ] );
f32_f32_beta_op( m_dim, n_dim );
L( "BPOSTBETAOP" );
// Check if is_last_k is 0
// (post-ops are skipped entirely in that case)
mov( rcx, ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, is_last_k ) ] );
test(rcx, rcx);
je( "POST_OPS_6x64_DISABLE", T_NEAR );
// Dispatch on the first post-op node (rdx). Each handler below ends by
// advancing rdx to the next node and dispatching again, so the handlers
// form a chain that terminates at POST_OPS_6x64_DISABLE.
post_op_label_lastk_safe_jump();
L( "POST_OPS_BIAS_6x64" );
// op_args2 points to the storage-order character of the bias;
// read its first byte.
mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] );
mov( bl, ptr[ rax ] );
// check if op_args2 == 'R'
cmp( bl, 0x52 );
je("BIAS_ROW_MAJOR", T_NEAR );
// check if op_args2 == 'r'
cmp( bl, 0x72 );
je( "BIAS_ROW_MAJOR", T_NEAR );
// anything else is handled as column-major bias
bias_col_major( m_dim, n_dim );
jmp( "POST_BIAS", T_NEAR );
L( "BIAS_ROW_MAJOR" );
bias_row_major( m_dim, n_dim );
L( "POST_BIAS" );
post_op_label_lastk_safe_jump_with_next_ptr();
L( "POST_OPS_RELU_6x64" );
relu( m_dim, n_dim );
post_op_label_lastk_safe_jump_with_next_ptr();
L( "POST_OPS_RELU_SCALE_6x64" );
relu_scale( m_dim, n_dim );
post_op_label_lastk_safe_jump_with_next_ptr();
L( "POST_OPS_GELU_TANH_6x64" );
gelu_tanh( m_dim, n_dim );
post_op_label_lastk_safe_jump_with_next_ptr();
L( "POST_OPS_GELU_ERF_6x64" );
gelu_erf( m_dim, n_dim );
post_op_label_lastk_safe_jump_with_next_ptr();
L( "POST_OPS_CLIP_6x64" );
clip_f32( m_dim, n_dim );
post_op_label_lastk_safe_jump_with_next_ptr();
L( "POST_OPS_DOWNSCALE_6x64" );
// No code is emitted for the downscale post-op in this kernel; it only
// moves on to the next node in the list.
post_op_label_lastk_safe_jump_with_next_ptr();
L( "POST_OPS_MATRIX_ADD_6x64" );
// Pick the matrix-add variant based on c_stor_type.
// NOTE(review): 4 is presumably the storage-type value for bf16 - confirm
// against the storage-type enum.
mov( rcx, ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, c_stor_type ) ] );
cmp( rcx, 4 );
je( "BF16_MATADD", T_NEAR );
f32_f32_matrix_add( m_dim, n_dim );
jmp( "POST_MATADD", T_NEAR );
L( "BF16_MATADD" );
bf16_f32_matrix_add( m_dim, n_dim );
L( "POST_MATADD" );
post_op_label_lastk_safe_jump_with_next_ptr();
L( "POST_OPS_SWISH_6x64" );
swish( m_dim, n_dim );
post_op_label_lastk_safe_jump_with_next_ptr();
L( "POST_OPS_6x64_DISABLE" );
// End of the post-op chain: store the tile. The bf16 store is used only
// when a downscale buffer exists and is_last_k is non-zero; otherwise the
// tile is stored as f32.
// check if buf_downscale is NULL
mov( rax, ptr[ rsp + stack_off_buf_downscale ] );
cmp( rax, 0 );
je( "F32_STORE", T_NEAR );
// Check if is_last_k is 0
mov( rcx, ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, is_last_k ) ] );
test( rcx, rcx );
je( "F32_STORE", T_NEAR );
L( "BF16_STORE" );
//mov( rcx, ptr[rsp + stack_off_buf_downscale]);
cvt_store_f32_bf16_mask( m_dim, n_dim );
jmp( "END", T_NEAR );
L( "F32_STORE" );
// reload the C pointer; rcx was used as scratch above
mov( rcx, r12 );
store_f32( m_dim, n_dim );
L( "END" );
// Loop tail, only emitted when the kernel iterates over m itself.
if( params->m_loop )
{
mov(rax, ptr[ rsp + stack_off_ps_a ] );
// r12 += 4*rs_c, then r12 += 2*rs_c (rdi holds rs_c in bytes).
// NOTE(review): the row step is hard-coded to 6 while post_op_c_i below
// advances by m_dim; presumably m_loop is only enabled for MR == 6 -
// confirm.
lea( r12, ptr[ r12 + rdi * 4 ] );
lea( r12, ptr[ r12 + rdi * 2 ] ); // c_ii = r12 += 6*rs_c;
lea(r14, ptr[ r14 + rax ] ); // a_ii = r14 += ps_a2
//add(, m_dim );
// post_op_c_i += m_dim (kept in the stack copy of the post-op attributes)
mov( rax, ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
add( rax, m_dim);
mov( ptr[ rsp + stack_off_postop +
offsetof( lpgemm_post_op_attr, post_op_c_i ) ], rax );
// restore rdx to the head of the post-op list for the next m iteration
mov( rdx, ptr[ rsp + stack_off_temp_list ] );
dec(r11);
jne("BLOOP6X64I", T_NEAR);
}
// release the space that is requested from stack
add( rsp, 512 );
// restore the callee-save registers.
postamble();
ret();
// Constant tables are emitted into the code buffer right after the kernel,
// starting at the 'tables' label on a 64-byte boundary. Keeping them next
// to the code lets the kernel reach them with relative addressing.
align(64);
L(tables);
db(reinterpret_cast<uint8_t*>( &gelu_consts ), sizeof( gelu_consts ) );
db(reinterpret_cast<uint8_t*>( &gelu_macros ), sizeof( gelu_macros ) );
db(reinterpret_cast<uint8_t*>( &lpgemm_exp ), sizeof( lpgemm_exp ) );
db(reinterpret_cast<uint8_t*>( &erf_consts ), sizeof( erf_consts ) );
db(reinterpret_cast<uint8_t*>( &lpgemm_erf ), sizeof( lpgemm_erf ) );
}
// Returns the generated code as a callable kernel pointer taking the JIT
// parameter block, the post-op attributes and the post-op list.
const void (* bli_lpgemm_jit:: get_function ()const)( lpgemm_jit_params_t*,
                                                      lpgemm_post_op_attr*,
                                                      lpgemm_post_op* )
{
    // Signature of the generated kernel.
    typedef const void (*jit_kernel_fn)( lpgemm_jit_params_t*,
                                         lpgemm_post_op_attr*,
                                         lpgemm_post_op* );

    return getCode<jit_kernel_fn>();
}
const void* bli_lpgemm_jit:: get_code ()const
{
return getCode<const void (*)>();
}
// Returns the size reported by the generator's getSize(), converted to
// dim_t.
dim_t bli_lpgemm_jit:: get_size ()
{
    const dim_t code_size = getSize();
    return code_size;
}