Implemented reference unreorder bf16 function

Description:

Implemented a C reference for the
aocl_gemm_unreorder_bf16bf16f32of32 function.

The implementation works for row-major storage;
column-major support is yet to be enabled.

AMD-Internal: [ SWLCSG-3279 ]

Change-Id: Ibcce4180bb897a40252140012d8d6886c38cb77a
This commit is contained in:
Nallani Bhaskar
2025-02-06 10:17:40 +00:00
parent ef04388a44
commit 0acb5eb9a4
6 changed files with 1001 additions and 33 deletions

View File

@@ -140,6 +140,102 @@ AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32_reference)
}
AOCL_GEMM_UNREORDER(bfloat16, bf16bf16f32of32_reference)
{
if ( ( output_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) ||
( k <= 0 ) || ( n <= 0 ) )
{
return; // Error.
}
inc_t rs_b, cs_b;
// Check for the validity of strides.
if( ( order == 'r' ) || ( order == 'R' ) )
{
if( ldb < n ) return; // Error
else
{
rs_b = ldb;
cs_b = 1;
}
}
else if( ( order == 'c' ) || ( order == 'C' ) )
{
if( ldb < k ) return; // Error.
else
{
rs_b = 1;
cs_b = ldb;
}
}
else
{
return; // Error.
}
// Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
return; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
AOCL_MATRIX_TYPE input_mat_type;
bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type );
if ( input_mat_type == A_MATRIX )
{
return; // A reorder not supported.
}
#if (defined(BLIS_KERNELS_ZEN4) && (!defined(LPGEMM_BF16_JIT)))
if( n == 1 )
{
if( rs_b == 1 )
{
memcpy( output_buf_addr, reorder_buf_addr, ( k * sizeof( bfloat16 ) ) );
}
else
{
for( dim_t k0 = 0; k0 < k; k0++ )
{
output_buf_addr[k0*rs_b] = reorder_buf_addr[k0];
}
}
return;
}
#endif
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
// create dummy b_reorder obj.
lpgemm_obj_t b_reorder;
b_reorder.storage.aligned_buffer = ( void* )reorder_buf_addr;
// create dummy b obj.
lpgemm_obj_t b;
b.storage.aligned_buffer = ( void* )output_buf_addr;
b.rs = rs_b;
b.cs = cs_b;
b.width = n;
b.length = k;
unreorderb_nr64_bf16bf16f32of32_reference( &b, &b_reorder, &rntm_g, lcntx_g );
}
AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32)
{
if ( ( k <= 0 ) || ( n <= 0 ) )

View File

@@ -108,6 +108,7 @@ BLIS_EXPORT_ADDON void aocl_unreorder_ ## LP_SFX \
) \
AOCL_GEMM_UNREORDER(bfloat16, bf16bf16f32of32);
/* Declares aocl_unreorder_bf16bf16f32of32_reference, the C reference variant of the bf16 unreorder entry point. */
AOCL_GEMM_UNREORDER(bfloat16, bf16bf16f32of32_reference);
#define AOCL_GEMM_MATMUL(A_type,B_type,C_type,Sum_type,LP_SFX) \
BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \

View File

@@ -170,6 +170,91 @@ void reorderb_nr64_bf16bf16f32of32_reference
b_reorder->mtag = REORDERED;
}
/*
   Walks a reordered bf16 B buffer in NC x KC blocks and unpacks every
   block into the plain destination matrix using the reference unpack
   routine.

   b           - describes the destination: buffer, rs, cs, width (n) and
                 length (k) are read from it.
   b_unreorder - only storage.aligned_buffer is read; it is used as the
                 reordered source buffer.
   rntm        - supplies the thread count used when OpenMP is enabled.
   lcntx       - supplies the NC, KC and NR block sizes.
*/
void unreorderb_nr64_bf16bf16f32of32_reference
     (
       lpgemm_obj_t * b,
       lpgemm_obj_t * b_unreorder,
       rntm_t* rntm,
       lpgemm_cntx_t* lcntx
     )
{
    // Block sizes from the lpgemm context.
    const dim_t NC = lcntx->blksz.NC;
    const dim_t KC = lcntx->blksz.KC;
    const dim_t NR = lcntx->blksz.NR;

    // Geometry of the destination matrix.
    const dim_t rs_b = b->rs;
    const dim_t cs_b = b->cs;
    const dim_t n = b->width;
    const dim_t k = b->length;

    bfloat16* const plain_base = ( bfloat16* )b->storage.aligned_buffer;
    bfloat16* const reordered_base = ( bfloat16* )b_unreorder->storage.aligned_buffer;

    // The reordered buffer holds k rounded up to an even value.
    const dim_t k_updated = k + ( k & 0x1 );

    dim_t n_threads = bli_rntm_num_threads( rntm );
    if ( n_threads <= 0 )
    {
        n_threads = 1;
    }

#ifdef BLIS_ENABLE_OPENMP
    _Pragma( "omp parallel num_threads(n_threads)" )
    {
        // Per-thread thrinfo used to split the n dimension across threads.
        thrinfo_t thread_jc;
        bli_thrinfo_set_n_way( n_threads, &thread_jc );
        bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc );
#else
    {
        // Single-threaded build: one worker owns the whole n range.
        thrinfo_t thread_jc;
        bli_thrinfo_set_n_way( 1, &thread_jc );
        bli_thrinfo_set_work_id( 0, &thread_jc );
#endif
        // Range of columns handled by this thread.
        dim_t jc_start, jc_end;
        bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );

        for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
        {
            dim_t nc0 = bli_min( ( jc_end - jc ), NC );

            // Location of this column block inside the reordered buffer.
            dim_t panel_jc = jc;
            dim_t panel_jc_rem = 0;
            dim_t panel_n;

            get_B_panel_reordered_start_offset_width
            (
              jc, n, NC, 16,
              &panel_jc, &panel_jc_rem,
              &nc0, &panel_n
            );

            for ( dim_t pc = 0; pc < k; pc += KC )
            {
                const dim_t kc0 = bli_min( ( k - pc ), KC );

                // The reordered buffer pads an odd kc0 to the next even
                // value, so its k offset uses the padded size.
                const dim_t kc0_updated = kc0 + ( kc0 & 0x1 );

                bfloat16* src_block = reordered_base +
                                      ( panel_jc * k_updated ) +
                                      ( panel_n * pc ) +
                                      ( panel_jc_rem * kc0_updated );
                bfloat16* dst_block = plain_base +
                                      ( rs_b * pc ) + ( jc * cs_b );

                unpackb_nr64_bf16bf16f32of32_reference
                (
                  src_block, dst_block,
                  nc0, kc0, rs_b, cs_b
                );
            }

            adjust_B_panel_reordered_jc( &jc, panel_jc );
        }
    }
}
void reorderb_nr64_bf16bf16f32of32
(
lpgemm_obj_t* b,

View File

@@ -49,6 +49,24 @@ void packb_nr64_bf16bf16f32of32_reference
dim_t* cs_p
);
// Reference unpack of one NC x KC block of reordered bf16 B data.
//   b               - source block in the reordered layout.
//   unpack_b_buffer - destination block in the plain layout.
//   NC, KC          - number of columns / rows in the block.
//   rs_b, cs_b      - row / column stride of the destination.
// The definition added with this change only unpacks when cs_b == 1 and
// returns without writing otherwise.
void unpackb_nr64_bf16bf16f32of32_reference
(
bfloat16* b,
bfloat16* unpack_b_buffer,
const dim_t NC,
const dim_t KC,
dim_t rs_b,
dim_t cs_b
);
// Reference driver that unpacks a whole reordered bf16 B matrix.
// NOTE(review): the parameter names in this prototype do not match how the
// definition uses them. The definition reads rs/cs/width/length and the
// destination buffer from the FIRST argument, and reads the reordered
// source buffer from the SECOND argument; the caller passes the plain
// output object first. Consider renaming to avoid confusion.
//   rntm  - runtime object providing the thread count.
//   lcntx - lpgemm context providing NC, KC and NR.
void unreorderb_nr64_bf16bf16f32of32_reference
(
lpgemm_obj_t* b_reorder,
lpgemm_obj_t* b_unreorder,
rntm_t* rntm,
lpgemm_cntx_t* lcntx
);
void reorderb_nr64_bf16bf16f32of32_reference
(
lpgemm_obj_t* b,

View File

@@ -0,0 +1,768 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include <string.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
/*
Below are the reference unpackb functions which are
varied based on block size NR (64, 48, 32, 16, lt) and
order (row / column (transpose)).
*/
/*
   Unpacks a 48-column fringe panel of reordered B into row-major storage.

   b        - packed source panel. Every pair of k rows occupies 2 * 48
              consecutive elements, stored column by column as
              { row kr, row kr + 1 } pairs.
   unpack_b - destination; row kr starts at unpack_b + ldb * kr.
   KC       - number of k rows to produce.
   ldb      - row stride of the destination.

   When KC is odd the last packed pair holds the final row in its even
   slots and padding in its odd slots, so only the even slots are copied
   and nothing is written past row KC - 1.
*/
void unpackb_nr48_bf16bf16f32of32_row_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t KC,
       dim_t ldb
     )
{
    const dim_t NR = 48;

    const dim_t k_full_pieces = ( KC / 2 ) * 2;
    const dim_t k_partial_pieces = KC % 2;

    for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
    {
        // Start of the packed { kr, kr + 1 } row pair.
        const bfloat16* inp = b + ( kr * NR );

        bfloat16* outp0 = unpack_b + ( ldb * ( kr + 0 ) );
        bfloat16* outp1 = unpack_b + ( ldb * ( kr + 1 ) );

        for ( dim_t jr = 0; jr < NR; jr++ )
        {
            outp0[jr] = inp[( 2 * jr ) + 0];
            outp1[jr] = inp[( 2 * jr ) + 1];
        }
    }

    // Handle k remainder: take the data slots, skip the padding slots.
    if ( k_partial_pieces > 0 )
    {
        const bfloat16* inp = b + ( k_full_pieces * NR );
        bfloat16* outp0 = unpack_b + ( ldb * k_full_pieces );

        for ( dim_t jr = 0; jr < NR; jr++ )
        {
            outp0[jr] = inp[2 * jr];
        }
    }
}
/*
   Unpacks a 32-column fringe panel of reordered B into row-major storage.

   b        - packed source panel. Every pair of k rows occupies 2 * 32
              consecutive elements, stored column by column as
              { row kr, row kr + 1 } pairs.
   unpack_b - destination; row kr starts at unpack_b + ldb * kr.
   KC       - number of k rows to produce.
   ldb      - row stride of the destination.

   When KC is odd the last packed pair holds the final row in its even
   slots and padding in its odd slots; only the even slots are copied.
*/
void unpackb_nr32_bf16bf16f32of32_row_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t KC,
       dim_t ldb
     )
{
    const dim_t NR = 32;

    const dim_t k_full_pieces = ( KC / 2 ) * 2;
    const dim_t k_partial_pieces = KC % 2;

    for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
    {
        // Start of the packed { kr, kr + 1 } row pair.
        const bfloat16* inp = b + ( kr * NR );

        bfloat16* outp0 = unpack_b + ( ldb * ( kr + 0 ) );
        bfloat16* outp1 = unpack_b + ( ldb * ( kr + 1 ) );

        for ( dim_t jr = 0; jr < NR; jr++ )
        {
            outp0[jr] = inp[( 2 * jr ) + 0];
            outp1[jr] = inp[( 2 * jr ) + 1];
        }
    }

    // Handle k remainder: take the data slots, skip the padding slots.
    if ( k_partial_pieces > 0 )
    {
        const bfloat16* inp = b + ( k_full_pieces * NR );
        bfloat16* outp0 = unpack_b + ( ldb * k_full_pieces );

        for ( dim_t jr = 0; jr < NR; jr++ )
        {
            outp0[jr] = inp[2 * jr];
        }
    }
}
/*
   Unpacks a 16-column fringe panel of reordered B into row-major storage.

   b        - packed source panel. Every pair of k rows occupies 2 * 16
              consecutive elements, stored column by column as
              { row kr, row kr + 1 } pairs.
   unpack_b - destination; row kr starts at unpack_b + ldb * kr.
   KC       - number of k rows to produce.
   ldb      - row stride of the destination.

   When KC is odd the last packed pair holds the final row in its even
   slots and padding in its odd slots; only the even slots are copied.
*/
void unpackb_nr16_bf16bf16f32of32_row_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t KC,
       dim_t ldb
     )
{
    const dim_t NR = 16;

    const dim_t k_full_pieces = ( KC / 2 ) * 2;
    const dim_t k_partial_pieces = KC % 2;

    for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
    {
        // Start of the packed { kr, kr + 1 } row pair.
        const bfloat16* inp = b + ( kr * NR );

        bfloat16* outp0 = unpack_b + ( ldb * ( kr + 0 ) );
        bfloat16* outp1 = unpack_b + ( ldb * ( kr + 1 ) );

        for ( dim_t jr = 0; jr < NR; jr++ )
        {
            outp0[jr] = inp[( 2 * jr ) + 0];
            outp1[jr] = inp[( 2 * jr ) + 1];
        }
    }

    // Handle k remainder: take the data slots, skip the padding slots.
    if ( k_partial_pieces > 0 )
    {
        const bfloat16* inp = b + ( k_full_pieces * NR );
        bfloat16* outp0 = unpack_b + ( ldb * k_full_pieces );

        for ( dim_t jr = 0; jr < NR; jr++ )
        {
            outp0[jr] = inp[2 * jr];
        }
    }
}
/*
   Unpacks a fringe panel narrower than 16 columns into row-major storage.

   b              - packed source panel. The panel is addressed as 16
                    columns wide: every pair of k rows occupies 2 * 16
                    consecutive elements, stored column by column as
                    { row kr, row kr + 1 } pairs.
   unpack_b       - destination; row kr starts at unpack_b + ldb * kr.
   KC             - number of k rows to produce.
   ldb            - row stride of the destination.
   n0_partial_rem - number of valid columns; only columns
                    0 .. n0_partial_rem - 1 of each destination row are
                    written.

   When KC is odd the last packed pair holds the final row in its even
   slots and padding in its odd slots; only the even slots are copied.
*/
void unpackb_nrlt16_bf16bf16f32of32_row_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t KC,
       dim_t ldb,
       dim_t n0_partial_rem
     )
{
    // Width of the packed panel, independent of the valid column count.
    const dim_t NR = 16;

    const dim_t k_full_pieces = ( KC / 2 ) * 2;
    const dim_t k_partial_pieces = KC % 2;

    for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
    {
        // Start of the packed { kr, kr + 1 } row pair.
        const bfloat16* inp = b + ( kr * NR );

        bfloat16* outp0 = unpack_b + ( ldb * ( kr + 0 ) );
        bfloat16* outp1 = unpack_b + ( ldb * ( kr + 1 ) );

        // Only the valid columns are copied so that nothing is written
        // beyond the fringe in the destination rows.
        for ( dim_t jr = 0; jr < n0_partial_rem; jr++ )
        {
            outp0[jr] = inp[( 2 * jr ) + 0];
            outp1[jr] = inp[( 2 * jr ) + 1];
        }
    }

    // Handle k remainder: take the data slots, skip the padding slots.
    if ( k_partial_pieces > 0 )
    {
        const bfloat16* inp = b + ( k_full_pieces * NR );
        bfloat16* outp0 = unpack_b + ( ldb * k_full_pieces );

        for ( dim_t jr = 0; jr < n0_partial_rem; jr++ )
        {
            outp0[jr] = inp[2 * jr];
        }
    }
}
/*
   Unpacks an NC x KC block of reordered B into row-major storage.

   b        - packed source block, organised as consecutive column panels.
              A full panel is 64 columns wide and KC_updated rows long,
              where KC_updated is KC rounded up to an even value.
   unpack_b - destination; element (kr, jc) lives at
              unpack_b + ldb * kr + jc.
   NC, KC   - number of columns / rows in the block.
   ldb      - row stride of the destination.

   Full 64-column panels are unpacked here. A trailing fringe (NC % 64) is
   split into a 48/32/16-column piece plus a sub-16 remainder, each handled
   by the matching fringe routine.
*/
void unpackb_nr64_bf16bf16f32of32_row_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t NC,
       const dim_t KC,
       dim_t ldb
     )
{
    dim_t NR = 64;

    dim_t n_full_pieces = NC / NR;
    dim_t n_full_pieces_loop_limit = n_full_pieces * NR;
    dim_t n_partial_pieces = NC % NR;

    dim_t k_full_pieces_blks = KC / 2;
    dim_t k_full_pieces = k_full_pieces_blks * 2;
    dim_t k_partial_pieces = KC % 2;

    // KC when not multiple of 2 will have padding to make it multiple of 2
    // in packed buffer.
    dim_t KC_updated = KC;
    if ( k_partial_pieces > 0 )
    {
        KC_updated += ( 2 - k_partial_pieces );
    }

    for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
    {
        // Start of the packed 64-column panel for columns jc .. jc + 63.
        const bfloat16* panel = b + ( jc * KC_updated );

        for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
        {
            // Packed { kr, kr + 1 } row pair, stored column by column.
            const bfloat16* inp = panel + ( kr * NR );

            bfloat16* outp0 = unpack_b + ( ldb * ( kr + 0 ) ) + jc;
            bfloat16* outp1 = unpack_b + ( ldb * ( kr + 1 ) ) + jc;

            for ( dim_t jr = 0; jr < NR; jr++ )
            {
                outp0[jr] = inp[( 2 * jr ) + 0];
                outp1[jr] = inp[( 2 * jr ) + 1];
            }
        }

        // Handle k remainder: the last packed pair carries the final row
        // in its even slots and padding in its odd slots. Copy the data
        // slots only, and write nothing beyond the 64 panel columns.
        if ( k_partial_pieces > 0 )
        {
            const bfloat16* inp = panel + ( k_full_pieces * NR );
            bfloat16* outp0 = unpack_b + ( ldb * k_full_pieces ) + jc;

            for ( dim_t jr = 0; jr < NR; jr++ )
            {
                outp0[jr] = inp[2 * jr];
            }
        }
    }

    if ( n_partial_pieces > 0 )
    {
        dim_t n0_partial_rem = n_partial_pieces % 16;
        dim_t n0_partial_unpack = 0;

        // Any n0 < NR(64) can be expressed as n0 = 48 + n` / n0 = 32 + n` /
        // n0 = 16 + n`, where n` < 16. The multiple-of-16 part is handled
        // first, then the n` remainder.
        dim_t n0_48 = n_partial_pieces / 48;
        dim_t n0_32 = n_partial_pieces / 32;
        dim_t n0_16 = n_partial_pieces / 16;

        if ( n0_48 == 1 )
        {
            unpackb_nr48_bf16bf16f32of32_row_major_ref
            (
              ( b + ( n_full_pieces_loop_limit * KC_updated ) ),
              ( unpack_b + n_full_pieces_loop_limit ), KC, ldb
            );
            n0_partial_unpack = 48;
        }
        else if ( n0_32 == 1 )
        {
            unpackb_nr32_bf16bf16f32of32_row_major_ref
            (
              ( b + ( n_full_pieces_loop_limit * KC_updated ) ),
              ( unpack_b + n_full_pieces_loop_limit ), KC, ldb
            );
            n0_partial_unpack = 32;
        }
        else if ( n0_16 == 1 )
        {
            unpackb_nr16_bf16bf16f32of32_row_major_ref
            (
              ( b + ( n_full_pieces_loop_limit * KC_updated ) ),
              ( unpack_b + n_full_pieces_loop_limit ), KC, ldb
            );
            n0_partial_unpack = 16;
        }

        if ( n0_partial_rem > 0 )
        {
            unpackb_nrlt16_bf16bf16f32of32_row_major_ref
            (
              ( b + ( n_full_pieces_loop_limit * KC_updated ) +
                ( n0_partial_unpack * KC_updated ) ),
              ( unpack_b + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
              n0_partial_rem
            );
        }
    }
}
/*
   Unpacks a fringe panel narrower than 16 columns into column-major
   storage.

   b        - packed source panel, stored as { row kr, row kr + 1 } pairs
              column by column.
   unpack_b - destination; column jr starts at unpack_b + ldb * jr and its
              k entries are contiguous.
   KC       - number of k rows to produce.
   NR       - number of valid columns in the fringe (expected < 16).
   ldb      - column stride of the destination.

   NOTE(review): the packed panel is assumed to be 16 columns wide, as the
   row-major nrlt16 reference routine in this file reads it - confirm
   against the column-major pack routine before enabling this path.

   The previous body strode the destination by 2 * NR while reading the
   source contiguously, always touched 16 columns, and wrote a zero one row
   past the end for odd KC. Each destination element is now read from its
   packed position and only KC rows of NR columns are written.
*/
void unpackb_nrlt16_bf16bf16f32of32_col_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t KC,
       dim_t NR,
       dim_t ldb
     )
{
    // Width used to address the packed panel.
    const dim_t PANEL_NR = 16;

    const dim_t k_full_pieces = ( KC / 2 ) * 2;

    for ( dim_t jr = 0; jr < NR; jr++ )
    {
        // Column jr sits at offset 2 * jr inside every packed row pair.
        const bfloat16* inp = b + ( 2 * jr );
        bfloat16* outp = unpack_b + ( ldb * jr );

        for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
        {
            outp[kr + 0] = inp[( kr * PANEL_NR ) + 0];
            outp[kr + 1] = inp[( kr * PANEL_NR ) + 1];
        }

        // Odd KC: the final row is in the even slot of the last pair; the
        // odd slot is padding and is not copied.
        if ( k_full_pieces < KC )
        {
            outp[k_full_pieces] = inp[k_full_pieces * PANEL_NR];
        }
    }
}
/*
   Unpacks a panel whose width is a multiple of 16 columns into
   column-major storage.

   b        - packed source panel, NR columns wide. Every pair of k rows
              occupies 2 * NR consecutive elements, stored column by column
              as { row kr, row kr + 1 } pairs (the same addressing the
              row-major reference routines in this file use).
   unpack_b - destination; column jr starts at unpack_b + ldb * jr and its
              k entries are contiguous.
   NR       - panel width in columns.
   KC       - number of k rows to produce.
   ldb      - column stride of the destination.

   The previous body strode the destination by 2 * NR while reading the
   source contiguously, ignored the column index in its first loop, and
   wrote a zero one row past the end for odd KC. Each destination element
   is now read from its packed position and only KC rows are written.
*/
void unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t NR,
       const dim_t KC,
       dim_t ldb
     )
{
    const dim_t k_full_pieces = ( KC / 2 ) * 2;

    for ( dim_t jr = 0; jr < NR; jr++ )
    {
        // Column jr sits at offset 2 * jr inside every packed row pair.
        const bfloat16* inp = b + ( 2 * jr );
        bfloat16* outp = unpack_b + ( ldb * jr );

        for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
        {
            outp[kr + 0] = inp[( kr * NR ) + 0];
            outp[kr + 1] = inp[( kr * NR ) + 1];
        }

        // Odd KC: the final row is in the even slot of the last pair; the
        // odd slot is padding and is not copied.
        if ( k_full_pieces < KC )
        {
            outp[k_full_pieces] = inp[k_full_pieces * NR];
        }
    }
}
/*
   Unpacks an NC x KC block of reordered B into column-major storage by
   dispatching to the column-major panel routines.

   b        - packed source block, organised as consecutive column panels;
              each panel is KC_updated rows long, where KC_updated is KC
              rounded up to an even value.
   unpack_b - destination; column jc starts at unpack_b + ldb * jc.
   NC, KC   - number of columns / rows in the block.
   ldb      - column stride of the destination.

   Full 64-column panels go to the multiple-of-16 routine. A trailing
   fringe (NC % 64) is split into a 48/32/16-column piece plus a sub-16
   remainder.

   NOTE(review): nothing in this change calls this routine yet; the
   dispatcher still returns early for column-major destinations.
*/
void unpackb_nr64_bf16bf16f32of32_col_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t NC,
       const dim_t KC,
       dim_t ldb
     )
{
    dim_t NR = 64;

    dim_t n_full_pieces = NC / NR;
    dim_t n_full_pieces_loop_limit = n_full_pieces * NR;
    dim_t n_partial_pieces = NC % NR;

    dim_t k_partial_pieces = KC % 2;

    // Odd KC is padded to the next even value in the packed buffer.
    dim_t KC_updated = KC;
    if ( k_partial_pieces > 0 )
    {
        KC_updated += ( 2 - k_partial_pieces );
    }

    for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
    {
        unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
        (
          b + ( jc * KC_updated ),
          unpack_b + ( jc * ldb ), 64, KC, ldb
        );
    }

    if ( n_partial_pieces > 0 )
    {
        dim_t n0_partial_rem = n_partial_pieces % 16;
        dim_t n0_partial_pack = 0;

        // Any n0 < NR(64) can be expressed as n0 = 48 + n` / n0 = 32 + n` /
        // n0 = 16 + n`, where n` < 16. The multiple-of-16 part is handled
        // first, then the n` remainder.
        dim_t n0_48 = n_partial_pieces / 48;
        dim_t n0_32 = n_partial_pieces / 32;
        dim_t n0_16 = n_partial_pieces / 16;

        if ( n0_48 == 1 )
        {
            unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
            (
              ( b + ( n_full_pieces_loop_limit * KC_updated ) ),
              ( unpack_b + n_full_pieces_loop_limit * ldb ), 48, KC, ldb
            );
            n0_partial_pack = 48;
        }
        else if ( n0_32 == 1 )
        {
            unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
            (
              ( b + ( n_full_pieces_loop_limit * KC_updated ) ),
              ( unpack_b + n_full_pieces_loop_limit * ldb ), 32, KC, ldb
            );
            n0_partial_pack = 32;
        }
        else if ( n0_16 == 1 )
        {
            unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
            (
              ( b + ( n_full_pieces_loop_limit * KC_updated ) ),
              ( unpack_b + n_full_pieces_loop_limit * ldb ), 16, KC, ldb
            );
            n0_partial_pack = 16;
        }

        if ( n0_partial_rem > 0 )
        {
            // The nrlt16 helper is declared as ( b, unpack_b, KC, NR, ldb ):
            // the fringe width goes in the NR slot and the stride goes last.
            unpackb_nrlt16_bf16bf16f32of32_col_major_ref
            (
              ( b + ( n_full_pieces_loop_limit * KC_updated ) +
                ( n0_partial_pack * KC_updated ) ),
              ( unpack_b + ( n_full_pieces_loop_limit + n0_partial_pack ) * ldb ), KC,
              n0_partial_rem, ldb
            );
        }
    }
}
/*
   Reference unpack dispatcher for one NC x KC block of reordered B.

   b          - packed source block.
   unpack_b   - destination block.
   NC, KC     - number of columns / rows in the block.
   rs_b, cs_b - row / column stride of the destination.

   Only destinations with unit column stride are unpacked (rs_b is passed
   on as the row stride). For any other cs_b the function returns without
   writing to unpack_b.
*/
void unpackb_nr64_bf16bf16f32of32_reference
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t NC,
       const dim_t KC,
       dim_t rs_b,
       dim_t cs_b
     )
{
    if ( cs_b != 1 )
    {
        //TODO: Implement column major unpacking
        //unpackb_nr64_bf16bf16f32of32_col_major_ref( b, unpack_b, NC, KC, cs_b );
        return;
    }

    unpackb_nr64_bf16bf16f32of32_row_major_ref( b, unpack_b, NC, KC, rs_b );
}
#endif

View File

@@ -41,7 +41,7 @@
void unpackb_nr48_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t KC,
dim_t ldb
)
@@ -73,8 +73,8 @@ void unpackb_nr48_bf16bf16f32of32_row_major
a01 = _mm512_permutex2var_epi16( b0, selector_even, a0 );
b0 = _mm512_permutex2var_epi16( b0, selector_odd, a0 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), b0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ), b0 );
c0 = _mm512_loadu_si512( b + ( ( kr_new + 2 ) * NR1 ) );
d0 = _mm512_setzero_si512();
@@ -82,8 +82,8 @@ void unpackb_nr48_bf16bf16f32of32_row_major
c01 = _mm512_permutex2var_epi16( d0, selector_even, c0 );
d0 = _mm512_permutex2var_epi16( d0, selector_odd, c0 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + NR1, 0xFFFF, c01 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + NR1, 0xFFFF, d0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + NR1, 0xFFFF, c01 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + NR1, 0xFFFF, d0 );
kr_new += 3;
}
@@ -96,18 +96,18 @@ void unpackb_nr48_bf16bf16f32of32_row_major
a01 = _mm512_permutex2var_epi16( b0, selector_even, a0 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), a01 );
c0 = _mm512_loadu_si512( b + ( ( kr_new + 2 ) * NR1 ) );
c01 = _mm512_permutex2var_epi16( c0, selector_even, c0 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + NR1, 0xFFFF, c01 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + NR1, 0xFFFF, c01 );
}
}
void unpackb_nr32_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t KC,
dim_t ldb
)
@@ -138,8 +138,8 @@ void unpackb_nr32_bf16bf16f32of32_row_major
c0 = _mm512_permutex2var_epi16( c0, selector_odd, a0 );
// Store to unpack buffer
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), c0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ), c0 );
}
if( k_partial_pieces > 0 )
@@ -150,13 +150,13 @@ void unpackb_nr32_bf16bf16f32of32_row_major
a0 = _mm512_permutex2var_epi16( c0, selector_even, a0 );
// Store to unpack buffer
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), a0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), a0 );
}
}
void unpackb_nr16_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t KC,
dim_t ldb
)
@@ -187,8 +187,8 @@ void unpackb_nr16_bf16bf16f32of32_row_major
c0 = _mm512_permutex2var_epi16( a0, selector_odd, a0 );
// Store to unpack buffer
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), 0xFFFF, a01 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), 0xFFFF, c0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ), 0xFFFF, a01 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ), 0xFFFF, c0 );
}
if( k_partial_pieces > 0 )
{
@@ -197,13 +197,13 @@ void unpackb_nr16_bf16bf16f32of32_row_major
a0 = _mm512_permutex2var_epi16( a0, selector_even, a0 );
// Store to unpack buffer
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), 0xFFFF, a0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), 0xFFFF, a0 );
}
}
void unpackb_nrlt16_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t KC,
dim_t ldb,
dim_t n0_partial_rem
@@ -237,8 +237,8 @@ void unpackb_nrlt16_bf16bf16f32of32_row_major
c0 = _mm512_permutex2var_epi16( a0, selector_odd, a0 );
// Store to unpack buffer
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), store_mask, a01 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), store_mask, c0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ), store_mask, a01 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ), store_mask, c0 );
}
if( k_partial_pieces > 0 )
{
@@ -247,14 +247,14 @@ void unpackb_nrlt16_bf16bf16f32of32_row_major
a0 = _mm512_permutex2var_epi16( a0, selector_even, a0 );
// Store to unpack buffer
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), store_mask, a0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), store_mask, a0 );
}
}
void unpackb_nr64_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t NC,
const dim_t KC,
dim_t ldb
@@ -304,10 +304,10 @@ void unpackb_nr64_bf16bf16f32of32_row_major
d0 = _mm512_permutex2var_epi16( d0, selector_odd, c0 );
// Store to unpack buffer
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + jc, a01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + jc + 32, c01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + jc, b0 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + jc + 32, d0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + jc, a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + jc + 32, c01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + jc, b0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + jc + 32, d0 );
}
if( k_partial_pieces > 0 )
@@ -322,8 +322,8 @@ void unpackb_nr64_bf16bf16f32of32_row_major
c01 = _mm512_permutex2var_epi16( d0, selector_even, c0 );
// Store to unpack buffer
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + jc, a01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32, c01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + jc, a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32, c01 );
}
}
@@ -344,7 +344,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
unpackb_nr48_bf16bf16f32of32_row_major
(
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
);
n0_partial_unpack = 48;
@@ -354,7 +354,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
unpackb_nr32_bf16bf16f32of32_row_major
(
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
);
n0_partial_unpack = 32;
@@ -364,7 +364,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
unpackb_nr16_bf16bf16f32of32_row_major
(
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
);
n0_partial_unpack = 16;
@@ -376,7 +376,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
(
( b + ( n_full_pieces_loop_limit * KC_updated ) +
( n0_partial_unpack * KC_updated ) ),
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
( unpack_b_buffer + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
n0_partial_rem
);
}
@@ -895,7 +895,7 @@ void unpackb_nr64_bf16bf16f32of32_col_major
void unpackb_nr64_bf16bf16f32of32
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t NC,
const dim_t KC,
dim_t rs_b,
@@ -904,11 +904,11 @@ void unpackb_nr64_bf16bf16f32of32
{
if( cs_b == 1 )
{
unpackb_nr64_bf16bf16f32of32_row_major( b, unpack_b_buffer_bf16bf16f32of32, NC, KC, rs_b );
unpackb_nr64_bf16bf16f32of32_row_major( b, unpack_b_buffer, NC, KC, rs_b );
}
else
{
unpackb_nr64_bf16bf16f32of32_col_major( b, unpack_b_buffer_bf16bf16f32of32, NC, KC, cs_b );
unpackb_nr64_bf16bf16f32of32_col_major( b, unpack_b_buffer, NC, KC, cs_b );
}
}
#endif // BLIS_ADDON_LPGEMM