Details: For a variable x, using the address of x in an instruction raises an
exception if the distance between &x and the access position exceeds 2 GiB.
To avoid this, all variables are now stored within the JIT code section and
accessed via relative addressing. Also fixes a bug in the B-matrix pack
function for the s8s8s32os32 API, and a bug in the JIT code that applies bias
to col-major matrices.
AMD-Internal: [SWLCSG-2820]
Change-Id: I82f117a0422c794cb9b1a4d65a89d60de4adfd96
/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name(s) of the copyright holder(s) nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "lpgemm_jit_bf16.h"
// Push the callee-saved registers to the stack.
void bli_lpgemm_jit:: preamble()
{
    push( rbp );
    push( rbx );
    push( r12 );
    push( r13 );
    push( r14 );
    push( r15 );
}

// Pop the callee-saved registers before returning from the function.
void bli_lpgemm_jit:: postamble()
{
    pop( r15 );
    pop( r14 );
    pop( r13 );
    pop( r12 );
    pop( rbx );
    pop( rbp );
    vzeroupper();
}
void bli_lpgemm_jit:: store_zmms_in_stack( dim_t reg_start_idx,
                                           dim_t num_regs,
                                           dim_t stack_off
                                         )
{
    for( dim_t idx = 0; idx < num_regs; idx++ )
    {
        vmovups( ptr[ rsp + zmm_stack_top + stack_off + idx * 64 ],
                 Zmm( reg_start_idx + idx ) );
    }
}

void bli_lpgemm_jit:: get_zmms_from_stack( dim_t reg_start_idx,
                                           dim_t num_regs,
                                           dim_t stack_off
                                         )
{
    for( dim_t idx = 0; idx < num_regs; idx++ )
    {
        vmovups( Zmm( reg_start_idx + idx ),
                 ptr[ rsp + zmm_stack_top + stack_off + idx * 64 ] );
    }
}
// Zero out the registers that will be used for storing accumulated values.
// For a given micro-kernel dimension MRxNR,
// considering a row-major kernel, we need ( MR * ( NR / num_elems per reg ) )
// registers to store accumulated values.
void bli_lpgemm_jit:: reg_init( dim_t m_dim, dim_t n_dim )
{
    vxorps( Zmm( fma_start_idx ), Zmm( fma_start_idx ) );
    for( dim_t m = fma_start_idx + 1; m < 32; m++ )
    {
        vmovaps( Zmm( m ), Zmm( fma_start_idx ) );
    }
}
// This code replicates the existing bf16 kernel,
// hence the unroll factor is hardcoded to 2.
// TODO: Make the unroll factor a configurable parameter.
void bli_lpgemm_jit:: kernel_unroll( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;

    // Broadcast elements of A matrix.
    vpbroadcastd( Zmm( bcst_start_idx ), ptr[ rax ] );

    // Load elements of B matrix into registers.
    for( dim_t n = 0; n < num_full_loads; n++ )
        vmovdqu16( Zmm( load_start_idx + n ), ptr[ rbx + n * 64 ] );

    // In case the last load has a fringe part, use a mask.
    if( n_rem )
        vmovdqu16( Zmm( load_start_idx + num_full_loads )
                   | k3 | T_z, ptr[ rbx + num_full_loads * 64 ] );

    add( rbx, r10 );

    for( dim_t m = 0; m < m_dim; m++ )
    {
        // Broadcast elements of A matrix.
        // Two ZMM registers are cycled for the broadcasts.
        if( m < ( m_dim - 1 ) )
        {
            switch( m + 1 )
            {
                case 1:
                case 4:
                case 2: vpbroadcastd( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
                                      ptr[ rax + r8 * ( m + 1 ) ] );
                        break;
                case 3: vpbroadcastd( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
                                      ptr[ rax + r13 ] );
                        break;
                case 5: vpbroadcastd( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
                                      ptr[ rax + r15 ] );
                        break;
                default:
                        break;
            }
        }

        // Move to the next column.
        if( m == ( m_dim - 1 ) ) add( rax, r9 );

        // Generate FMA instructions.
        for( dim_t n = 0; n < num_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;

            vdpbf16ps( Zmm( reg_num ), Zmm( bcst_start_idx + m % 2 ),
                       Zmm( load_start_idx + n ) );
        }
    }
}
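// For reference, a scalar sketch of what each generated vdpbf16ps
// contributes ( a simplification; lane indexing and bf16 rounding details
// are omitted ): every 32-bit lane holds a pair of bf16 values, and the
// instruction accumulates their dot product into the f32 accumulator:
//
//   // acc, a, b: one 32-bit lane viewed as two bf16 elements each.
//   // acc += (float)a.lo * (float)b.lo + (float)a.hi * (float)b.hi;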
void bli_lpgemm_jit:: k_fringe_loop( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;

    // Broadcast elements of A matrix.
    vpbroadcastw( Zmm( bcst_start_idx ), ptr[ rax ] );

    // Load elements of B matrix into registers.
    for( dim_t n = 0; n < num_full_loads; n++ )
        vmovdqu16( Zmm( load_start_idx + n ), ptr[ rbx + n * 64 ] );

    // In case the last load has a fringe part, use a mask.
    if( n_rem )
        vmovdqu16( Zmm( load_start_idx + num_full_loads )
                   | k3 | T_z, ptr[ rbx + num_full_loads * 64 ] );

    for( dim_t m = 0; m < m_dim; m++ )
    {
        if( m < ( m_dim - 1 ) )
        {
            // Broadcast elements of A matrix.
            // Two ZMM registers are cycled for the broadcasts.
            switch( m + 1 )
            {
                case 1:
                case 4:
                case 2: vpbroadcastw( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
                                      ptr[ rax + r8 * ( m + 1 ) ] );
                        break;
                case 3: vpbroadcastw( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
                                      ptr[ rax + r13 ] );
                        break;
                case 5: vpbroadcastw( Zmm( bcst_start_idx + ( m + 1 ) % 2 ),
                                      ptr[ rax + r15 ] );
                        break;
                default:
                        break;
            }
        }

        // Generate FMA instructions.
        for( dim_t n = 0; n < num_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;

            vdpbf16ps( Zmm( reg_num ), Zmm( bcst_start_idx + m % 2 ),
                       Zmm( load_start_idx + n ) );
        }
    }
}
// Generate the required number of mul instructions for scaling with alpha.
void bli_lpgemm_jit:: scale_alpha( dim_t m_dim, dim_t n_dim )
{
    for( dim_t reg_num = fma_start_idx; reg_num < 32; reg_num++ )
        vmulps( Zmm( reg_num ), Zmm( alpha_reg ), Zmm( reg_num ) );
}
// Scale C by beta and accumulate, for a generic beta value.
void bli_lpgemm_jit:: f32_f32_beta_op( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;
    for( dim_t m = 0; m < m_dim; m++ )
    {
        if( m > 0 ) add( rcx, rdi );

        for( dim_t n = 0; n < num_full_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;

            vmovups( Zmm( load_start_idx + n ), ptr[ rcx + n * 64 ] );

            vfmadd231ps( Zmm( reg_num ), Zmm( load_start_idx + n ),
                         Zmm( beta_reg ) );
        }

        // Use a mask in case of an n fringe.
        if( n_rem )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads;

            vmovups( Zmm( load_start_idx + num_full_loads ) | k4 | T_z,
                     ptr[ rcx + num_full_loads * 64 ] );

            vfmadd231ps( Zmm( reg_num ),
                         Zmm( load_start_idx + num_full_loads ),
                         Zmm( beta_reg ) );
        }
    }
}
void bli_lpgemm_jit:: bf16_f32_beta_op( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;
    mov( rcx, ptr[ rsp + stack_off_buf_downscale ] );
    mov( rax, ptr[ rsp + stack_off_postop + offsetof( lpgemm_post_op_attr,
                                                      rs_c_downscale ) ] );

    // rs_c_downscale *= sizeof(bfloat16)
    lea( rax, ptr[ rax * 2 ] );
    mov( rsi, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );

    // rsi = post_op_c_i * ( rs_c_downscale * sizeof(bfloat16) )
    imul( rsi, rax );

    // rsi = post_op_c_i * ( rs_c_downscale * sizeof(bfloat16) )
    //       + post_op_c_j * sizeof(bfloat16)
    lea( rsi, ptr[ rsi + rbx * 2 ] );

    add( rcx, rsi );

    for( dim_t m = 0; m < m_dim; m++ )
    {
        for( dim_t n = 0; n < num_full_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;

            // Convert from 16-bit elements to 32-bit elements.
            vpmovsxwd( Zmm( load_start_idx + n ), ptr[ rcx + n * 32 ] );

            // Shift left by 16 bits.
            vpslld( Zmm( load_start_idx + n ), Zmm( load_start_idx + n ),
                    0x10 );

            // FMA with beta.
            vfmadd231ps( Zmm( reg_num ), Zmm( beta_reg ),
                         Zmm( load_start_idx + n ) );
        }
        if( n_rem )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads;

            // Load the bf16 elements from the downscale buffer using a mask.
            vmovdqu16( Ymm( load_start_idx + num_full_loads ) | k4 | T_z,
                       ptr[ rcx + num_full_loads * 32 ] );

            // Convert from 16-bit elements to 32-bit elements.
            vpmovsxwd( Zmm( load_start_idx + num_full_loads ),
                       Ymm( load_start_idx + num_full_loads ) );

            // Shift left by 16 bits.
            vpslld( Zmm( load_start_idx + num_full_loads ),
                    Zmm( load_start_idx + num_full_loads ), 0x10 );

            // FMA with beta.
            vfmadd231ps( Zmm( reg_num ), Zmm( beta_reg ),
                         Zmm( load_start_idx + num_full_loads ) );
        }

        // Move to the next row.
        add( rcx, rax );
    }
}
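// For reference, a scalar sketch of the bf16 -> f32 widening used above
// ( an illustration, not part of the generated code ): a bfloat16 value is
// the upper 16 bits of an IEEE-754 float, so moving the 16-bit pattern into
// the high half of a 32-bit lane reconstructs the float exactly.
//
//   // float bf16_to_f32( uint16_t b )
//   // {
//   //     uint32_t u = ( uint32_t )b << 16; // vpmovsxwd + vpslld, per lane
//   //     float    f;
//   //     memcpy( &f, &u, sizeof( f ) );
//   //     return f;
//   // }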
void bli_lpgemm_jit:: clip_f32( dim_t m_dim, dim_t n_dim )
{
    dim_t min_reg = load_start_idx;
    dim_t max_reg = bcst_start_idx;

    // min reg
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] );
    vbroadcastss( Zmm( min_reg ), ptr[ rax ] );

    // max reg
    mov( rbx, ptr[ rdx + offsetof( lpgemm_post_op, op_args3 ) ] );
    vbroadcastss( Zmm( max_reg ), ptr[ rbx ] );

    for( dim_t m = fma_start_idx; m < 32; m++ )
    {
        vmaxps( Zmm( m ), Zmm( m ), Zmm( min_reg ) );
        vminps( Zmm( m ), Zmm( m ), Zmm( max_reg ) );
    }
}
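// Per element, the max/min pair above computes the usual clamp:
//
//   // x = min( max( x, min_val ), max_val );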
void bli_lpgemm_jit:: bf16_f32_matrix_add( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;

    // rcx = matrix ptr
    mov( rcx, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] );

    // rax = ldm
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args3 ) ] );
    mov( rax, ptr[ rax ] );

    // ldm *= sizeof(bfloat16)
    lea( rax, ptr[ rax * 2 ] );

    mov( rsi, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );

    // rsi = post_op_c_i * ( ldm * sizeof(bfloat16) )
    imul( rsi, rax );

    // rsi = post_op_c_i * ( ldm * sizeof(bfloat16) )
    //       + post_op_c_j * sizeof(bfloat16)
    lea( rsi, ptr[ rsi + rbx * 2 ] );

    add( rcx, rsi );

    for( dim_t m = 0; m < m_dim; m++ )
    {
        for( dim_t n = 0; n < num_full_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;

            // Convert from 16-bit elements to 32-bit elements.
            vpmovsxwd( Zmm( load_start_idx + n ), ptr[ rcx + n * 32 ] );

            // Shift left by 16 bits.
            vpslld( Zmm( load_start_idx + n ), Zmm( load_start_idx + n ),
                    0x10 );

            vaddps( Zmm( reg_num ), Zmm( reg_num ),
                    Zmm( load_start_idx + n ) );
        }
        if( n_rem )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads;

            // Load the bf16 elements from the matrix using a mask.
            vmovdqu16( Ymm( load_start_idx + num_full_loads ) | k4 | T_z,
                       ptr[ rcx + num_full_loads * 32 ] );

            // Convert from 16-bit elements to 32-bit elements.
            vpmovsxwd( Zmm( load_start_idx + num_full_loads ),
                       Ymm( load_start_idx + num_full_loads ) );

            // Shift left by 16 bits.
            vpslld( Zmm( load_start_idx + num_full_loads ),
                    Zmm( load_start_idx + num_full_loads ), 0x10 );

            vaddps( Zmm( reg_num ), Zmm( reg_num ),
                    Zmm( load_start_idx + num_full_loads ) );
        }

        // Move to the next row.
        add( rcx, rax );
    }
}
void bli_lpgemm_jit:: f32_f32_matrix_add( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;

    // rcx = matrix ptr
    mov( rcx, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] );
    // rax = ldm
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args3 ) ] );
    mov( rax, ptr[ rax ] );

    // ldm *= sizeof(float)
    lea( rax, ptr[ rax * 4 ] );

    mov( rsi, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );

    // rsi = post_op_c_i * ( ldm * sizeof(float) )
    imul( rsi, rax );

    // rsi = post_op_c_i * ( ldm * sizeof(float) )
    //       + post_op_c_j * sizeof(float)
    lea( rsi, ptr[ rsi + rbx * 4 ] );

    add( rcx, rsi );

    for( dim_t m = 0; m < m_dim; m++ )
    {
        for( dim_t n = 0; n < num_full_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;
            vmovups( Zmm( load_start_idx + n ), ptr[ rcx + n * 64 ] );
            vaddps( Zmm( reg_num ), Zmm( reg_num ),
                    Zmm( load_start_idx + n ) );
        }
        if( n_rem )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads;
            vmovups( Zmm( load_start_idx + num_full_loads ) | k4 | T_z,
                     ptr[ rcx + num_full_loads * 64 ] );
            vaddps( Zmm( reg_num ), Zmm( reg_num ),
                    Zmm( load_start_idx + num_full_loads ) );
        }

        // Move to the next row.
        add( rcx, rax );
    }
}
void bli_lpgemm_jit:: bias_row_major( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );

    mov( rcx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, c_stor_type ) ] );
    cmp( rcx, 4 );
    je( "BIAS_BF16_ROW_MAJOR", T_NEAR );

    // post_op_c_j *= sizeof(float)
    lea( rbx, ptr[ rbx * 4 ] );
    add( rax, rbx );
    for( dim_t n = 0; n < num_full_loads; n++ )
    {
        vmovups( Zmm( load_start_idx + n ), ptr[ rax + n * 64 ] );
    }
    if( n_rem )
    {
        vmovups( Zmm( load_start_idx + num_full_loads ) | k4,
                 ptr[ rax + num_full_loads * 64 ] );
    }
    jmp( "POST_BIAS_BF16_ROW_MAJOR", T_NEAR );

    L( "BIAS_BF16_ROW_MAJOR" );
    // post_op_c_j *= sizeof(bfloat16)
    lea( rbx, ptr[ rbx * 2 ] );
    add( rax, rbx );
    for( dim_t n = 0; n < num_full_loads; n++ )
    {
        // Convert from 16-bit elements to 32-bit elements.
        vpmovsxwd( Zmm( load_start_idx + n ), ptr[ rax + n * 32 ] );

        // Shift left by 16 bits.
        vpslld( Zmm( load_start_idx + n ), Zmm( load_start_idx + n ), 0x10 );
    }
    if( n_rem )
    {
        // Load the bf16 bias elements using a mask.
        vmovdqu16( Ymm( load_start_idx + num_full_loads ) | k4 | T_z,
                   ptr[ rax + num_full_loads * 32 ] );

        // Convert from 16-bit elements to 32-bit elements.
        vpmovsxwd( Zmm( load_start_idx + num_full_loads ),
                   Ymm( load_start_idx + num_full_loads ) );

        // Shift left by 16 bits.
        vpslld( Zmm( load_start_idx + num_full_loads ),
                Zmm( load_start_idx + num_full_loads ), 0x10 );
    }
    L( "POST_BIAS_BF16_ROW_MAJOR" );

    for( dim_t m = 0; m < m_dim; m++ )
    {
        for( dim_t n = 0; n < num_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;
            vaddps( Zmm( reg_num ), Zmm( reg_num ),
                    Zmm( load_start_idx + n ) );
        }
    }
}
void bli_lpgemm_jit:: bias_col_major( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;

    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args1 ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );

    mov( rcx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, c_stor_type ) ] );
    cmp( rcx, 4 );
    je( "BIAS_BF16_COL_MAJOR", T_NEAR );

    // post_op_c_i *= sizeof(float)
    lea( rbx, ptr[ rbx * 4 ] );
    add( rax, rbx );
    for( dim_t m = 0; m < m_dim; m++ )
    {
        vbroadcastss( Zmm( alpha_reg ), ptr[ rax + m * 4 ] );
        for( dim_t n = 0; n < num_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;
            vaddps( Zmm( reg_num ), Zmm( reg_num ), Zmm( alpha_reg ) );
        }
    }
    jmp( "POST_BIAS_BF16_COL_MAJOR", T_NEAR );

    L( "BIAS_BF16_COL_MAJOR" );
    // post_op_c_i *= sizeof(bfloat16)
    lea( rbx, ptr[ rbx * 2 ] );
    add( rax, rbx );
    for( dim_t m = 0; m < m_dim; m++ )
    {
        // Broadcast the bf16 bias element ( stride = sizeof(bfloat16) ).
        vpbroadcastw( Zmm( alpha_reg ), ptr[ rax + m * 2 ] );

        // Convert from 16-bit elements to 32-bit elements.
        vpmovsxwd( Zmm( alpha_reg ), Ymm( alpha_reg ) );

        // Shift left by 16 bits.
        vpslld( Zmm( alpha_reg ), Zmm( alpha_reg ), 0x10 );

        for( dim_t n = 0; n < num_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;
            vaddps( Zmm( reg_num ), Zmm( reg_num ), Zmm( alpha_reg ) );
        }
    }
    L( "POST_BIAS_BF16_COL_MAJOR" );
}
void bli_lpgemm_jit:: relu( dim_t m_dim, dim_t n_dim )
{
    dim_t scratch_reg = bcst_start_idx;

    vpxorq( Zmm( scratch_reg ), Zmm( scratch_reg ), Zmm( scratch_reg ) );

    for( dim_t m = fma_start_idx; m < 32; m++ )
    {
        vmaxps( Zmm( m ), Zmm( m ), Zmm( scratch_reg ) );
    }
}
void bli_lpgemm_jit:: relu_scale( dim_t m_dim, dim_t n_dim )
{
    dim_t zero_reg = load_start_idx;
    dim_t scale_factor = bcst_start_idx;

    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] );
    vbroadcastss( Zmm( scale_factor ), ptr[ rax ] );
    vpxorq( Zmm( zero_reg ), Zmm( zero_reg ), Zmm( zero_reg ) );

    for( dim_t m = fma_start_idx; m < 32; m++ )
    {
        vcmpps( k5, Zmm( m ), Zmm( zero_reg ), 0x02 );
        vmulps( Zmm( m ) | k5, Zmm( m ), Zmm( scale_factor ) );
    }
}
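// For reference, the masked compare ( 0x02 = "less than or equal" ) and
// multiply pair above is a leaky-ReLU style scaling; per element
// ( a simplification of the generated code ):
//
//   // x = ( x <= 0.0f ) ? x * scale : x;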
void bli_lpgemm_jit::apply_post_ops_in_high_reg_pressure
     (
       const dim_t num_post_op_regs,
       std::function< void( dim_t ) > op_fn
     )
{
    dim_t num_push_regs = num_post_op_regs - fma_start_idx;

    // If the number of registers required to compute the post-op exceeds
    // the number of registers available, push some accumulation registers
    // to the stack and reuse them to compute the post-op.
    store_zmms_in_stack( fma_start_idx, num_push_regs, 0 );

    dim_t post_op_start = num_push_regs > 0 ? fma_start_idx + num_push_regs
                                            : fma_start_idx;

    // Operate on the non-pushed regs.
    for( dim_t reg = post_op_start; reg < 32; reg++ )
    {
        op_fn( reg );
    }

    // Push num_push_regs registers from the end to the stack,
    // replace them with the items that were pushed earlier,
    // and compute on them.
    store_zmms_in_stack( 32 - num_push_regs, num_push_regs,
                         num_push_regs * 64 );
    get_zmms_from_stack( 32 - num_push_regs, num_push_regs, 0 );

    for( dim_t reg = 0; reg < num_push_regs; reg++ )
    {
        op_fn( 32 - num_push_regs + reg );
    }

    for( dim_t reg = 0; reg < num_push_regs; reg++ )
        vmovups( Zmm( fma_start_idx + reg ),
                 Zmm( 32 - num_push_regs + reg ) );

    get_zmms_from_stack( 32 - num_push_regs, num_push_regs,
                         num_push_regs * 64 );
}
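// A worked example of the spill scheme above ( the numbers are
// illustrative ): with fma_start_idx = 8, accumulators live in
// zmm8..zmm31, and a post-op that needs scratch regs zmm0..zmm9 gives
// num_push_regs = 10 - 8 = 2. zmm8..zmm9 are spilled, the post-op runs on
// zmm10..zmm31, then zmm30..zmm31 are spilled, reloaded with the saved
// zmm8..zmm9 values, processed, copied back to zmm8..zmm9, and finally
// zmm30..zmm31 are restored.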
// r2, z, and q are scratch regs.
// r is passed in and out of the parent function.
void bli_lpgemm_jit:: POLY_EVAL_6_AVX512()
{
    vmulps( Zmm( r2 ), Zmm( r ), Zmm( r ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_exp_off, 3 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_exp_off, 2 ) );

    vmovups( Zmm( q ), Zmm( const2 ) );
    vfmadd231ps( Zmm( q ), Zmm( const1 ), Zmm( r ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_exp_off, 1 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_exp_off, 0 ) );

    vmovups( Zmm( z ), Zmm( const2 ) );
    vfmadd231ps( Zmm( z ), Zmm( const1 ), Zmm( r ) );

    vfmadd231ps( Zmm( z ), Zmm( r2 ), Zmm( q ) );

    vmulps( Zmm( r2 ), Zmm( r2 ), Zmm( r2 ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_exp_off, 5 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_exp_off, 4 ) );

    vfmadd231ps( Zmm( const2 ), Zmm( const1 ), Zmm( r ) );

    vfmadd231ps( Zmm( z ), Zmm( const2 ), Zmm( r2 ) );
    vmovups( Zmm( r ), Zmm( z ) );
}
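// For reference, the instruction sequence above evaluates the degree-5
// polynomial with Estrin-style splitting ( c0..c5 denote the lpgemm_exp
// constants ):
//
//   r = ( c0 + c1*r ) + r^2 * ( c2 + c3*r ) + r^4 * ( c4 + c5*r )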
// z, r, and dn are scratch registers.
// Takes 'x' as input and returns 'q' to the parent.
void bli_lpgemm_jit:: EXPF_AVX512()
{
    vbroadcastss( Zmm( const1 ), get_constant( gelu_macros_off, 0 ) );

    vmulps( Zmm( z ), Zmm( x ), Zmm( const1 ) );

    vbroadcastss( Zmm( const2 ), get_constant( gelu_macros_off, 1 ) );

    vaddps( Zmm( dn ), Zmm( z ), Zmm( const2 ) );

    vsubps( Zmm( r ), Zmm( dn ), Zmm( const2 ) );
    vsubps( Zmm( r ), Zmm( z ), Zmm( r ) );

    POLY_EVAL_6_AVX512();

    vpslld( Zmm( dn ), Zmm( dn ), 0x17 );

    vpaddd( Zmm( q ), Zmm( r ), Zmm( dn ) );

    vpxorq( Zmm( const2 ), Zmm( const2 ), Zmm( const2 ) );

    vpbroadcastd( Zmm( const1 ), get_constant( gelu_macros_off, 2 ) );

    vcmpps( k5, Zmm( const1 ), Zmm( x ), 0x06 );

    vpandd( Zmm( q ) | k5, Zmm( q ), Zmm( const2 ) );

    vbroadcastss( Zmm( const1 ), get_constant( gelu_macros_off, 3 ) );

    vcmpps( k5, Zmm( const1 ), Zmm( x ), 0x06 );

    vbroadcastss( Zmm( x ), get_constant( gelu_macros_off, 4 ) );

    vpxord( Zmm( x ) | k5, Zmm( q ), Zmm( const2 ) );
    vmovups( Zmm( q ), Zmm( x ) );
}
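// For reference, this is the usual table-free exp range reduction
// ( a sketch; the gelu_macros constants supply the scale and rounding
// bias ):
//
//   z  = x * ( 1 / ln(2) );
//   dn = round( z );          // via the add/sub rounding-bias trick above
//   r  = z - dn;              // reduced argument
//   q  = poly( r ) * 2^dn;    // 2^dn built by shifting dn into the exponent
//
// with out-of-range inputs handled by the masked ops at the end.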
// Uses z, dn, and r as scratch regs.
// Passes r to the child macro and gets back q.
// Takes x_tanh as input and gives back x_tanh.
void bli_lpgemm_jit:: TANHF_AVX512()
{
    vbroadcastss( Zmm( const1 ), get_constant( gelu_consts_off, 2 ) );

    mov( ebx, 0x7FFFFFFF );
    vpbroadcastd( Zmm( const2 ), ebx );
    vpandd( Zmm( x ), Zmm( x_tanh ), Zmm( const2 ) );

    vmulps( Zmm( x ), Zmm( x ), Zmm( const1 ) );

    EXPF_AVX512();

    mov( eax, -1 );
    vbroadcastss( Zmm( const1 ), get_constant( gelu_consts_off, 4 ) );

    vaddps( Zmm( z ), Zmm( q ), Zmm( const1 ) );

    vbroadcastss( Zmm( const2 ), get_constant( gelu_consts_off, 5 ) );

    vaddps( Zmm( r ), Zmm( z ), Zmm( const2 ) );

    vdivps( Zmm( z ), Zmm( z ), Zmm( r ) );

    vmulps( Zmm( z ), Zmm( z ), Zmm( const1 ) );

    mov( eax, -2147483648 );
    vpbroadcastd( Zmm( const1 ), eax );

    vpandd( Zmm( q ), Zmm( x_tanh ), Zmm( const1 ) );

    vpxord( Zmm( x_tanh ), Zmm( q ), Zmm( z ) );
}
void bli_lpgemm_jit:: GELU_TANH_F32_AVX512_DEF( dim_t reg )
{
    vmulps( Zmm( r2 ), Zmm( reg ), Zmm( reg ) );
    vmulps( Zmm( r2 ), Zmm( r2 ), Zmm( reg ) );

    vbroadcastss( Zmm( const1 ), get_constant( gelu_consts_off, 0 ) );
    vmovups( Zmm( r ), Zmm( reg ) );
    vfmadd231ps( Zmm( r ), Zmm( r2 ), Zmm( const1 ) );

    vbroadcastss( Zmm( const2 ), get_constant( gelu_consts_off, 1 ) );
    vmulps( Zmm( x_tanh ), Zmm( r ), Zmm( const2 ) );

    TANHF_AVX512();

    vbroadcastss( Zmm( const2 ), get_constant( gelu_consts_off, 6 ) );
    vaddps( Zmm( x_tanh ), Zmm( x_tanh ), Zmm( const2 ) );
    vmulps( Zmm( x_tanh ), Zmm( x_tanh ), Zmm( reg ) );

    vbroadcastss( Zmm( const1 ), get_constant( gelu_consts_off, 3 ) );
    vmulps( Zmm( reg ), Zmm( x_tanh ), Zmm( const1 ) );
}
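// For reference, the sequence above computes the tanh-based GELU
// approximation ( the gelu_consts table supplies 0.044715, sqrt(2/pi),
// 1.0 and 0.5; treat the exact index mapping as illustrative ):
//
//   gelu(x) = 0.5 * x * ( 1 + tanh( sqrt(2/pi) * ( x + 0.044715 * x^3 ) ) )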
void bli_lpgemm_jit:: gelu_tanh( dim_t m_dim, dim_t n_dim )
{
    apply_post_ops_in_high_reg_pressure
    (
      num_gelu_regs,
      std::bind
      (
        &bli_lpgemm_jit::GELU_TANH_F32_AVX512_DEF,
        this,
        std::placeholders::_1
      )
    );
}
void bli_lpgemm_jit:: POLY_EVAL_HORNER_16_0_AVX512()
{
    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_erf_off, 15 ) );
    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_erf_off, 14 ) );

    vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_erf_off, 13 ) );
    vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_erf_off, 12 ) );
    vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_erf_off, 11 ) );
    vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_erf_off, 10 ) );
    vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_erf_off, 9 ) );
    vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_erf_off, 8 ) );
    vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_erf_off, 7 ) );
    vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_erf_off, 6 ) );
    vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_erf_off, 5 ) );
    vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_erf_off, 4 ) );
    vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_erf_off, 3 ) );
    vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_erf_off, 2 ) );
    vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) );

    vbroadcastss( Zmm( const1 ), get_constant( lpgemm_erf_off, 1 ) );
    vfmadd231ps( Zmm( const1 ), Zmm( r ), Zmm( const2 ) );

    vbroadcastss( Zmm( const2 ), get_constant( lpgemm_erf_off, 0 ) );
    vfmadd231ps( Zmm( const2 ), Zmm( r ), Zmm( const1 ) );

    vmulps( Zmm( x ), Zmm( const2 ), Zmm( r ) );
}
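// For reference, the chain of fused multiply-adds above is Horner's rule
// over the 16 lpgemm_erf coefficients, with one extra multiply by r at
// the end:
//
//   x = r * ( c0 + r*( c1 + r*( c2 + ... + r*( c14 + r*c15 ) ... ) ) )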
void bli_lpgemm_jit:: ERF_AVX512()
{
    mov( eax, 0x7FFFFFFF );
    vpbroadcastd( Zmm( const2 ), eax );
    vpandd( Zmm( r ), Zmm( x_erf ), Zmm( const2 ) );

    POLY_EVAL_HORNER_16_0_AVX512();

    vbroadcastss( Zmm( const1 ), get_constant( erf_consts_off, 1 ) );

    vbroadcastss( Zmm( const2 ), get_constant( erf_consts_off, 3 ) );

    vcmpps( k5, Zmm( const2 ), Zmm( r ), 0x06 );

    vpxorq( Zmm( const2 ), Zmm( const2 ), Zmm( const2 ) );

    vpxord( Zmm( const1 ) | k5, Zmm( x ), Zmm( const2 ) );
    vmovups( Zmm( x ), Zmm( const1 ) );

    vbroadcastss( Zmm( const1 ), get_constant( erf_consts_off, 1 ) );

    vcmpps( k5, Zmm( const1 ), Zmm( x ), 0x06 );

    vpxord( Zmm( const1 ) | k5, Zmm( x ), Zmm( const2 ) );

    mov( eax, ~( 0x7FFFFFFF ) );
    vpbroadcastd( Zmm( const2 ), eax );

    vpandd( Zmm( x_erf ), Zmm( x_erf ), Zmm( const2 ) );

    vpord( Zmm( x_erf ), Zmm( x_erf ), Zmm( const1 ) );
}
void bli_lpgemm_jit:: GELU_ERF_F32_AVX512_DEF( dim_t reg )
{
    vbroadcastss( Zmm( const1 ), get_constant( erf_consts_off, 0 ) );
    vmulps( Zmm( x_erf ), Zmm( reg ), Zmm( const1 ) );

    ERF_AVX512();

    vbroadcastss( Zmm( const2 ), get_constant( erf_consts_off, 1 ) );
    vaddps( Zmm( x_erf ), Zmm( x_erf ), Zmm( const2 ) );

    vmulps( Zmm( x_erf ), Zmm( x_erf ), Zmm( reg ) );
    vbroadcastss( Zmm( const2 ), get_constant( erf_consts_off, 2 ) );
    vmulps( Zmm( reg ), Zmm( x_erf ), Zmm( const2 ) );
}
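// For reference, the sequence above computes the exact ( erf-based ) GELU
// definition ( the erf_consts table supplies 1/sqrt(2), 1.0 and 0.5;
// treat the exact index mapping as illustrative ):
//
//   gelu(x) = 0.5 * x * ( 1 + erf( x / sqrt(2) ) )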
void bli_lpgemm_jit:: gelu_erf( dim_t m_dim, dim_t n_dim )
{
    apply_post_ops_in_high_reg_pressure
    (
      num_gelu_regs,
      std::bind
      (
        &bli_lpgemm_jit::GELU_ERF_F32_AVX512_DEF,
        this,
        std::placeholders::_1
      )
    );
}
void bli_lpgemm_jit::SWISH_F32_AVX512_DEF( dim_t reg )
{
    vpxorq( Zmm( x ), Zmm( x ), Zmm( x ) );
    vfnmadd231ps( Zmm( x ), Zmm( reg ), Zmm( x_tanh ) );

    // Input reg x and output reg q.
    EXPF_AVX512();

    vbroadcastss( Zmm( const1 ), get_constant( gelu_consts_off, 6 ) );
    vaddps( Zmm( q ), Zmm( q ), Zmm( const1 ) );
    vdivps( Zmm( reg ), Zmm( reg ), Zmm( q ) );
}
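// For reference, with the swish scale factor 'a' broadcast into x_tanh by
// swish() below, the sequence computes ( gelu_consts[6] supplying the
// 1.0 constant ):
//
//   swish(x) = x / ( 1 + e^( -a*x ) )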
void bli_lpgemm_jit::swish( dim_t m_dim, dim_t n_dim )
{
    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] );
    vbroadcastss( Zmm( x_tanh ), ptr[ rax ] );

    apply_post_ops_in_high_reg_pressure
    (
      num_gelu_regs,
      std::bind
      (
        &bli_lpgemm_jit::SWISH_F32_AVX512_DEF,
        this,
        std::placeholders::_1
      )
    );
}
void bli_lpgemm_jit:: store_f32( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;
    for( dim_t m = 0; m < m_dim; m++ )
    {
        if( m > 0 ) add( rcx, rdi );

        for( dim_t n = 0; n < num_full_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;
            vmovups( ptr[ rcx + n * 64 ], Zmm( reg_num ) );
        }

        // Use a mask in case of an n fringe.
        if( n_rem )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads;
            vmovups( ptr[ rcx + num_full_loads * 64 ] | k4, Zmm( reg_num ) );
        }
    }
}
void bli_lpgemm_jit:: cvt_store_f32_bf16_mask( dim_t m_dim, dim_t n_dim )
{
    dim_t reg_num;

    mov( rcx, ptr[ rsp + stack_off_buf_downscale ] );
    mov( rax, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, rs_c_downscale ) ] );

    // rs_c_downscale *= sizeof(bfloat16)
    lea( rax, ptr[ rax * 2 ] );
    mov( rsi, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( rbx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );

    imul( rsi, rax );
    lea( rsi, ptr[ rsi + rbx * 2 ] );
    add( rcx, rsi );

    for( dim_t m = 0; m < m_dim; m++ )
    {
        for( dim_t n = 0; n < num_full_loads; n++ )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + n;
            // Convert from 32-bit elements to 16-bit elements.
            vcvtneps2bf16( Ymm( reg_num ), Zmm( reg_num ) );
            vmovdqu16( ptr[ rcx + n * 32 ], Ymm( reg_num ) );
        }
        if( n_rem )
        {
            reg_num = fma_start_idx + ( m * num_loads ) + num_full_loads;
            // Convert from 32-bit elements to 16-bit elements.
            vcvtneps2bf16( Ymm( reg_num ), Zmm( reg_num ) );
            vmovdqu16( ptr[ rcx + num_full_loads * 32 ] | k4, Ymm( reg_num ) );
        }
        // Move to the next row.
        add( rcx, rax );
    }
}
void bli_lpgemm_jit::initialize_params( lpgemm_jit_inputs_t* params )
{
    // Params needed in the kernel:
    // a(r14, rax), b(rbx), c(r12, rcx) pointers. To be stored in regs.
    // rs_a(r8), cs_a(r9), rs_b(r10), rs_c(rdi).
    // alpha(rax), beta(rbx) values. To be pushed to the stack.
    // m_iter(r11), ps_a(rax) values. ps_a to be pushed to the stack.
    // k_iter(rsi), k_left(rsi) values. To be pushed to the stack.

    // Load values from the params struct to registers and the stack.
    if( params->m_loop )
    {
        // Move the address of a.
        mov( r14, ptr[ rdi + offsetof( lpgemm_jit_params_t, a ) ] );
        mov( r11, ptr[ rdi + offsetof( lpgemm_jit_params_t, m_iter ) ] );
    }
    else
    {
        mov( rax, ptr[ rdi + offsetof( lpgemm_jit_params_t, a ) ] );
    }

    if( params->generate_mask )
    {
        // This mask will be used to load/store bf16 elements.
        kmovd( k3, ptr[ rdi + offsetof( lpgemm_jit_params_t, mask16 ) ] );
        // This mask will be used to load/store f32 elements.
        kmovw( k4, ptr[ rdi + offsetof( lpgemm_jit_params_t, mask32 ) ] );
    }

    mov( r12, ptr[ rdi + offsetof( lpgemm_jit_params_t, c ) ] );
    mov( r8, ptr[ rdi + offsetof( lpgemm_jit_params_t, rs_a ) ] );
    mov( r9, ptr[ rdi + offsetof( lpgemm_jit_params_t, cs_a ) ] );
    mov( r10, ptr[ rdi + offsetof( lpgemm_jit_params_t, rs_b ) ] );

    // Push all the params that will be required in later stages
    // of the kernel to the stack.
    // Pushing, in order: ps_a2, k_iter, k_left, alpha, beta, b.
    mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, ps_a2 ) ] );
    mov( ptr[ rsp + stack_off_ps_a ], rbx );

    mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t,
                                   k_iter_before_prefetch ) ] );
    mov( ptr[ rsp + stack_off_k_iter_before_prefetch ], rbx );

    mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t,
                                   k_iter_after_prefetch ) ] );
    mov( ptr[ rsp + stack_off_k_iter_after_prefetch ], rbx );

    mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, k_left ) ] );
    mov( ptr[ rsp + stack_off_k_left ], rbx );

    mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, alpha ) ] );
    mov( ptr[ rsp + stack_off_alpha ], rbx );

    mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, beta ) ] );
    mov( ptr[ rsp + stack_off_beta ], rbx );

    mov( rbx, ptr[ rdi + offsetof( lpgemm_jit_params_t, b ) ] );
    mov( ptr[ rsp + stack_off_b_ptr ], rbx );

    // Once all the params that will be required in
    // later stages of the kernel are pushed to the stack,
    // move rs_c to rdi.
    mov( rdi, ptr[ rdi + offsetof( lpgemm_jit_params_t, rs_c ) ] );

    // Push all members of the lpgemm_post_op_attr struct to the stack.
    // Since it is passed as the 2nd arg to the function, it will be in rsi.

    mov( rbx, ptr[ rsi + offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, post_op_c_i ) ], rbx );

    mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, post_op_c_j ) ] );
    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, post_op_c_j ) ], rcx );

    mov( rbx, ptr[ rsi + offsetof( lpgemm_post_op_attr, rs_c_downscale ) ] );
    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, rs_c_downscale ) ], rbx );

    mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, cs_c_downscale ) ] );
    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, cs_c_downscale ) ], rcx );

    mov( rbx, ptr[ rsi + offsetof( lpgemm_post_op_attr, buf_downscale ) ] );
    mov( ptr[ rsp + stack_off_buf_downscale ], rbx );

    mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, is_first_k ) ] );
    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, is_first_k ) ], rcx );

    mov( rbx, ptr[ rsi + offsetof( lpgemm_post_op_attr, is_last_k ) ] );
    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, is_last_k ) ], rbx );

    mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, c_stor_type ) ] );
    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, c_stor_type ) ], rcx );

    mov( rbx, ptr[ rsi + offsetof( lpgemm_post_op_attr, b_sum_offset ) ] );
    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, b_sum_offset ) ], rbx );

    mov( rcx, ptr[ rsi + offsetof( lpgemm_post_op_attr, b_col_sum_vec ) ] );
    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, b_col_sum_vec ) ], rcx );

    mov( rbx, ptr[ rsi +
                   offsetof( lpgemm_post_op_attr, b_col_sum_vec_s16 ) ] );

    mov( ptr[ rsp + stack_off_postop +
              offsetof( lpgemm_post_op_attr, b_col_sum_vec_s16 ) ], rbx );

    // Store the address of the head node of the post-op list on the stack.
    // It needs to be restored after every iteration of the m_iter loop.
    mov( ptr[ rsp + stack_off_temp_list ], rdx );

    // Initialize the top of the zmm stack.
    zmm_stack_top = stack_off_zmm_stack;
}
void bli_lpgemm_jit:: prefetchC( dim_t m_dim, dim_t n_dim )
{
    for( dim_t m = 0; m < m_dim; m++ )
    {
        if( m > 0 ) add( rcx, rdi );
        for( dim_t n = 0; n < num_loads; n++ )
        {
            prefetcht1( ptr[ rcx + n * 64 ] );
        }
    }
}
void bli_lpgemm_jit:: post_op_label_lastk_safe_jump_with_next_ptr()
{
    mov( rdx, ptr[ rdx + offsetof( lpgemm_post_op, next ) ] );
    post_op_label_lastk_safe_jump();
}

void bli_lpgemm_jit:: post_op_label_lastk_safe_jump()
{
    // Check if post_ops_list_temp != NULL.
    cmp( rdx, 0 );
    je( "POST_OPS_6x64_DISABLE", T_NEAR );

    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_code ) ] );
    cmp( rax, POST_OPS_DISABLE );
    je( "POST_OPS_6x64_DISABLE", T_NEAR );
    cmp( rax, POST_OPS_BIAS );
    je( "POST_OPS_BIAS_6x64", T_NEAR );
    cmp( rax, POST_OPS_RELU );
    je( "POST_OPS_RELU_6x64", T_NEAR );
    cmp( rax, POST_OPS_RELU_SCALE );
    je( "POST_OPS_RELU_SCALE_6x64", T_NEAR );
    cmp( rax, POST_OPS_GELU_TANH );
    je( "POST_OPS_GELU_TANH_6x64", T_NEAR );
    cmp( rax, POST_OPS_GELU_ERF );
    je( "POST_OPS_GELU_ERF_6x64", T_NEAR );
    cmp( rax, POST_OPS_CLIP );
    je( "POST_OPS_CLIP_6x64", T_NEAR );
    cmp( rax, POST_OPS_DOWNSCALE );
    je( "POST_OPS_DOWNSCALE_6x64", T_NEAR );
    cmp( rax, POST_OPS_MATRIX_ADD );
    je( "POST_OPS_MATRIX_ADD_6x64", T_NEAR );
    cmp( rax, POST_OPS_SWISH );
    je( "POST_OPS_SWISH_6x64", T_NEAR );
}
// Constructor
bli_lpgemm_jit:: bli_lpgemm_jit( void* buffer, size_t bufferSize )
                 : CodeGenerator( bufferSize, buffer )
{
    protect( buffer, bufferSize, PROTECT_RWE );
}
// Main kernel function body.
void bli_lpgemm_jit::generate_kernel( lpgemm_jit_inputs_t* params )
{
    dim_t m_dim = params->MR;
    dim_t n_dim = params->NR;

    // In the kernel-function pointer array, kernels that handle n < 16
    // are stored at col-index 0. n_dim is hacked to some value
    // 0 < value < 16 so that masked instructions are generated.
    // This will be removed when we support on-the-fly generation of kernels.
    if( n_dim == 0 )
    {
        n_dim = 2;
        params->generate_mask = TRUE;
    }

    n_rem = n_dim % NUM_F32_ELEMS_PER_ZMM;

    // Number of loads that don't require a mask.
    num_full_loads = ( n_dim / num_elems_per_reg );

    // Number of loads in total = full loads + mask load ( if required ).
    num_loads = ( num_full_loads ) + ( n_rem > 0 ? 1 : 0 );

    // Total number of registers needed to store accumulated values.
    num_fma_regs = m_dim * num_loads;

    // Calculate the start index for the accumulation registers.
    // If the kernel requires 'x' accumulation regs, we use the
    // last 'x' ZMMs available on the target architecture.
    // 31 is hardcoded here since we only support AVX-512 as of now;
    // this needs to be made a configurable parameter later.
    fma_start_idx = 31 - num_fma_regs + 1;

    // If a kernel requires 'x' registers for loads, we always use the
    // first 'x' ZMM registers available for loads,
    // and the registers immediately after the load regs for broadcasts.
    bcst_start_idx = load_start_idx + num_loads;

    // While scaling the accumulated registers with beta,
    // the load regs will be used to load the C matrix,
    // hence a broadcast register is used to hold the beta value.
    beta_reg = bcst_start_idx;
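    // A worked example of the register budget above ( illustrative, and
    // assuming load_start_idx = 0 ): for MR = 6, NR = 64 with 16 f32
    // elements per ZMM, n_rem = 0, num_full_loads = num_loads = 4 and
    // num_fma_regs = 24, so accumulators occupy zmm8..zmm31, loads use
    // zmm0..zmm3, and broadcasts start at zmm4.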
    preamble();
    // Reserve some space on the stack to store params.
    sub( rsp, 512 );
    // Initialize all the parameters required for execution of the kernel:
    // load some values to registers and push the rest of them to the stack.
    initialize_params( params );

    /* Register usage:
       r14, rax - pointer to A matrix
       r8       - rs_a
       r9       - cs_a
       r13      - 3 * rs_a
       r15      - 5 * rs_a
       rbx      - pointer to B matrix, beta
       r10      - rs_b
       r12, rcx - pointer to C matrix
       rdi      - rs_c
       r11      - m_iter
       rsi      - k_iter, k_left
       rax      - ps_a2, alpha
    */

    lea( rdi, ptr[ rdi * 4 ] ); // rs_c *= sizeof(float) => rs_c *= 4

    lea( r8, ptr[ r8 * 2 ] );   // rs_a *= sizeof(dt) => rs_a *= 2
    lea( r9, ptr[ r9 * 2 ] );   // cs_a *= sizeof(dt) => cs_a *= 2
    if( m_dim >= 4 )
        lea( r13, ptr[ r8 + r8 * 2 ] ); // r13 = 3 * rs_a
    if( m_dim >= 6 )
        lea( r15, ptr[ r8 + r8 * 4 ] ); // r15 = 5 * rs_a

    lea( r10, ptr[ r10 * 2 ] ); // rs_b *= sizeof(dt) => rs_b *= 2

    mov( rcx, r12 );

    if( params->m_loop )
    {
        L( "BLOOP6X64I" );
        mov( rax, r14 ); // Reset rax to the current upanel of a.
    }
    mov( rbx, ptr[ rsp + stack_off_b_ptr ] ); // Move the address of b.

    // Zero all the registers that will be used for accumulation.
    reg_init( m_dim, n_dim );

    // Load k_iter.
    mov( rsi, ptr[ rsp + stack_off_k_iter_before_prefetch ] );
    test( rsi, rsi );
    je( "BPREFETCH", T_NEAR );
    L( "BLOOPKITER" );

    // Main k-unroll loop.
    kernel_unroll( m_dim, n_dim );

    dec( rsi ); // i -= 1
    jne( "BLOOPKITER", T_NEAR );

    L( "BPREFETCH" );

    prefetchC( m_dim, n_dim );

    mov( rsi, ptr[ rsp + stack_off_k_iter_after_prefetch ] );
    test( rsi, rsi );
    je( "BCONSIDKLEFT", T_NEAR );

    L( "AFTERPREFETCH" );

    kernel_unroll( m_dim, n_dim );

    dec( rsi );
    jne( "AFTERPREFETCH", T_NEAR );

    L( "BCONSIDKLEFT" );
    // Load k_left.
    mov( rsi, ptr[ rsp + stack_off_k_left ] );
    test( rsi, rsi );
    je( "BPOSTACCUM", T_NEAR );

    // k fringe.
    k_fringe_loop( m_dim, n_dim );

    L( "BPOSTACCUM" );
|
// Generate alpha scaling code only when required.
|
|
if( params->alpha_scale )
|
|
{
|
|
mov( rax, ptr[ rsp + stack_off_alpha ] ); // load address of alpha
|
|
vbroadcastss( Zmm( alpha_reg ), ptr[ rax ] );
|
|
|
|
scale_alpha( m_dim, n_dim );
|
|
|
|
}
|
|
|
|
mov( rbx, ptr[ rsp + stack_off_beta ] );
|
|
vbroadcastss( Xmm( beta_reg ), ptr[ rbx ] ); // load address of beta
|
|
|
|
// Zero out a register
|
|
vxorps( Xmm( alpha_reg ), Xmm( alpha_reg ) );
|
|
// cmp beta value with zero
|
|
vucomiss( Xmm( beta_reg ), Xmm( alpha_reg ) );
|
|
// if beta=0, skip beta scaling
|
|
je( "BPOSTBETAOP", T_NEAR );
|
|
|
|
// check if buf_downscale is NULL
|
|
mov( rax, ptr[ rsp + stack_off_buf_downscale ] );
|
|
cmp( rax, 0 );
|
|
je( "BETAOP", T_NEAR );
|
|
|
|
// Check if is_first_k is 0
|
|
mov( rcx, ptr[ rsp + stack_off_postop +
|
|
offsetof( lpgemm_post_op_attr, is_first_k ) ] );
|
|
test( rcx, rcx );
|
|
je( "BETAOP", T_NEAR );
|
|
|
|
L( "DOWNSCALEBETAOP" );
|
|
vbroadcastss( Zmm( beta_reg ), ptr[ rbx ] );
|
|
bf16_f32_beta_op( m_dim, n_dim );
|
|
jmp( "BPOSTBETAOP", T_NEAR );
|
|
|
|
L( "BETAOP" );
|
|
mov( rcx, r12 );
|
|
vbroadcastss( Zmm( beta_reg ), ptr[ rbx ] );
|
|
f32_f32_beta_op( m_dim, n_dim );
|
|
|
|
L( "BPOSTBETAOP" );
|
|
|
|
// Check if is_last_k is 0
|
|
mov( rcx, ptr[ rsp + stack_off_postop +
|
|
offsetof( lpgemm_post_op_attr, is_last_k ) ] );
|
|
test(rcx, rcx);
|
|
je( "POST_OPS_6x64_DISABLE", T_NEAR );
|
|
|
|
post_op_label_lastk_safe_jump();
|
|
|
|
|
|
    L( "POST_OPS_BIAS_6x64" );

    mov( rax, ptr[ rdx + offsetof( lpgemm_post_op, op_args2 ) ] );
    mov( bl, ptr[ rax ] );

    // Check if op_args2 == 'R'.
    cmp( bl, 0x52 );
    je( "BIAS_ROW_MAJOR", T_NEAR );
    // Check if op_args2 == 'r'.
    cmp( bl, 0x72 );
    je( "BIAS_ROW_MAJOR", T_NEAR );

    bias_col_major( m_dim, n_dim );
    jmp( "POST_BIAS", T_NEAR );

    L( "BIAS_ROW_MAJOR" );
    bias_row_major( m_dim, n_dim );

    L( "POST_BIAS" );
    post_op_label_lastk_safe_jump_with_next_ptr();

    L( "POST_OPS_RELU_6x64" );
    relu( m_dim, n_dim );
    post_op_label_lastk_safe_jump_with_next_ptr();

    L( "POST_OPS_RELU_SCALE_6x64" );
    relu_scale( m_dim, n_dim );
    post_op_label_lastk_safe_jump_with_next_ptr();

    L( "POST_OPS_GELU_TANH_6x64" );
    gelu_tanh( m_dim, n_dim );
    post_op_label_lastk_safe_jump_with_next_ptr();

    L( "POST_OPS_GELU_ERF_6x64" );
    gelu_erf( m_dim, n_dim );
    post_op_label_lastk_safe_jump_with_next_ptr();

    L( "POST_OPS_CLIP_6x64" );
    clip_f32( m_dim, n_dim );
    post_op_label_lastk_safe_jump_with_next_ptr();

    L( "POST_OPS_DOWNSCALE_6x64" );
    post_op_label_lastk_safe_jump_with_next_ptr();

    L( "POST_OPS_MATRIX_ADD_6x64" );

    mov( rcx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, c_stor_type ) ] );
    cmp( rcx, 4 );
    je( "BF16_MATADD", T_NEAR );
    f32_f32_matrix_add( m_dim, n_dim );
    jmp( "POST_MATADD", T_NEAR );
    L( "BF16_MATADD" );
    bf16_f32_matrix_add( m_dim, n_dim );
    L( "POST_MATADD" );

    post_op_label_lastk_safe_jump_with_next_ptr();

    L( "POST_OPS_SWISH_6x64" );
    swish( m_dim, n_dim );
    post_op_label_lastk_safe_jump_with_next_ptr();

    L( "POST_OPS_6x64_DISABLE" );
    // Check if buf_downscale is NULL.
    mov( rax, ptr[ rsp + stack_off_buf_downscale ] );
    cmp( rax, 0 );
    je( "F32_STORE", T_NEAR );

    // Check if is_last_k is 0.
    mov( rcx, ptr[ rsp + stack_off_postop +
                   offsetof( lpgemm_post_op_attr, is_last_k ) ] );
    test( rcx, rcx );
    je( "F32_STORE", T_NEAR );

    L( "BF16_STORE" );
    //mov( rcx, ptr[ rsp + stack_off_buf_downscale ] );
    cvt_store_f32_bf16_mask( m_dim, n_dim );
    jmp( "END", T_NEAR );

    L( "F32_STORE" );
    mov( rcx, r12 );
    store_f32( m_dim, n_dim );

    L( "END" );

    if( params->m_loop )
    {
        mov( rax, ptr[ rsp + stack_off_ps_a ] );

        lea( r12, ptr[ r12 + rdi * 4 ] );
        lea( r12, ptr[ r12 + rdi * 2 ] ); // c_ii = r12 += 6 * rs_c

        lea( r14, ptr[ r14 + rax ] );     // a_ii = r14 += ps_a2

        mov( rax, ptr[ rsp + stack_off_postop +
                       offsetof( lpgemm_post_op_attr, post_op_c_i ) ] );
        add( rax, m_dim );

        mov( ptr[ rsp + stack_off_postop +
                  offsetof( lpgemm_post_op_attr, post_op_c_i ) ], rax );

        mov( rdx, ptr[ rsp + stack_off_temp_list ] );

        dec( r11 );
        jne( "BLOOP6X64I", T_NEAR );
    }
    // Release the space that was requested from the stack.
    add( rsp, 512 );

    // Restore the callee-saved registers.
    postamble();

    ret();

    align( 64 );
    L( tables );

    db( reinterpret_cast<uint8_t*>( &gelu_consts ), sizeof( gelu_consts ) );
    db( reinterpret_cast<uint8_t*>( &gelu_macros ), sizeof( gelu_macros ) );
    db( reinterpret_cast<uint8_t*>( &lpgemm_exp ), sizeof( lpgemm_exp ) );
    db( reinterpret_cast<uint8_t*>( &erf_consts ), sizeof( erf_consts ) );
    db( reinterpret_cast<uint8_t*>( &lpgemm_erf ), sizeof( lpgemm_erf ) );
}
const void (* bli_lpgemm_jit:: get_function() const )( lpgemm_jit_params_t*,
                                                       lpgemm_post_op_attr*,
                                                       lpgemm_post_op* )
{
    return getCode<const void (*)( lpgemm_jit_params_t*,
                                   lpgemm_post_op_attr*,
                                   lpgemm_post_op* )>();
}

const void* bli_lpgemm_jit:: get_code() const
{
    return getCode<const void (*)>();
}

dim_t bli_lpgemm_jit:: get_size()
{
    return getSize();
}
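// A minimal usage sketch ( illustrative only; the buffer size, the MR/NR
// values, and the exact set of lpgemm_jit_inputs_t fields populated here
// are assumptions, not part of this file ):
//
//   // void* buf = allocate_executable_buffer( 4096 ); // hypothetical helper
//   // bli_lpgemm_jit jit( buf, 4096 );
//   //
//   // lpgemm_jit_inputs_t in = {};
//   // in.MR = 6; in.NR = 64;          // request a 6x64 bf16 micro-kernel
//   // jit.generate_kernel( &in );
//   //
//   // auto kern = jit.get_function(); // JIT-compiled kernel entry point
//   // kern( &params, &post_op_attr, post_op_list );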