mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Implemented reference unreorder bf16 function
Description: Implemented a C reference for the aocl_gemm_unreorder_bf16bf16f32of32 function. The implementation works for row-major layouts; the column-major path is implemented but not yet enabled. AMD-Internal: [ SWLCSG-3279 ] Change-Id: Ibcce4180bb897a40252140012d8d6886c38cb77a
This commit is contained in:
@@ -140,6 +140,102 @@ AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32_reference)
|
||||
|
||||
}
|
||||
|
||||
// Reference C implementation of the bf16 unreorder API
// ( aocl_gemm_unreorder_bf16bf16f32of32_reference ). Converts a reordered
// ( blocked ) B buffer back into a plain k x n matrix laid out according to
// 'order' ( 'r'/'R' or 'c'/'C' ) with leading dimension 'ldb'.
// NOTE: the parameter names ( output_buf_addr, reorder_buf_addr, k, n, ldb,
// order, mat_type ) are supplied by the AOCL_GEMM_UNREORDER macro expansion.
// Errors are reported by returning early ( the API is void ).
AOCL_GEMM_UNREORDER(bfloat16, bf16bf16f32of32_reference)
{
    // Basic argument validation.
    if ( ( output_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) ||
         ( k <= 0 ) || ( n <= 0 ) )
    {
        return; // Error.
    }

    inc_t rs_b, cs_b;

    // Check for the validity of strides.
    if( ( order == 'r' ) || ( order == 'R' ) )
    {
        if( ldb < n ) return; // Error
        else
        {
            rs_b = ldb;
            cs_b = 1;
        }
    }
    else if( ( order == 'c' ) || ( order == 'C' ) )
    {
        if( ldb < k ) return; // Error.
        else
        {
            rs_b = 1;
            cs_b = ldb;
        }
    }
    else
    {
        return; // Error.
    }

    // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it.
    if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
    {
        bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
                "cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
        return; // Error.
    }

    /* Initialize BLIS. */
    bli_init_auto();

    // Set MC, NC, KC, NR, MR.
    aocl_lpgemm_init_global_cntx();

    AOCL_MATRIX_TYPE input_mat_type;
    bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type );

    if ( input_mat_type == A_MATRIX )
    {
        return; // A reorder not supported.
    }

#if (defined(BLIS_KERNELS_ZEN4) && (!defined(LPGEMM_BF16_JIT)))
    // Fast path: an n == 1 ( single column ) B matrix is reordered as a
    // contiguous k-vector, so unreordering reduces to a ( possibly strided )
    // copy back into the output buffer.
    if( n == 1 )
    {
        if( rs_b == 1 )
        {
            memcpy( output_buf_addr, reorder_buf_addr, ( k * sizeof( bfloat16 ) ) );
        }
        else
        {
            for( dim_t k0 = 0; k0 < k; k0++ )
            {
                output_buf_addr[k0*rs_b] = reorder_buf_addr[k0];
            }
        }
        return;
    }
#endif
    // Initialize a local runtime with global settings if necessary. Note
    // that in the case that a runtime is passed in, we make a local copy.
    rntm_t rntm_g;
    bli_rntm_init_from_global( &rntm_g );
    bli_pba_rntm_set_pba( &rntm_g );

    lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );

    // create dummy b_reorder obj.
    lpgemm_obj_t b_reorder;
    b_reorder.storage.aligned_buffer = ( void* )reorder_buf_addr;

    // create dummy b obj.
    lpgemm_obj_t b;
    b.storage.aligned_buffer = ( void* )output_buf_addr;
    b.rs = rs_b;
    b.cs = cs_b;
    b.width = n;
    b.length = k;

    unreorderb_nr64_bf16bf16f32of32_reference( &b, &b_reorder, &rntm_g, lcntx_g );
}
|
||||
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32)
|
||||
{
|
||||
if ( ( k <= 0 ) || ( n <= 0 ) )
|
||||
|
||||
@@ -108,6 +108,7 @@ BLIS_EXPORT_ADDON void aocl_unreorder_ ## LP_SFX \
|
||||
) \
|
||||
|
||||
AOCL_GEMM_UNREORDER(bfloat16, bf16bf16f32of32);
|
||||
AOCL_GEMM_UNREORDER(bfloat16, bf16bf16f32of32_reference);
|
||||
|
||||
#define AOCL_GEMM_MATMUL(A_type,B_type,C_type,Sum_type,LP_SFX) \
|
||||
BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \
|
||||
|
||||
@@ -170,6 +170,91 @@ void reorderb_nr64_bf16bf16f32of32_reference
|
||||
b_reorder->mtag = REORDERED;
|
||||
}
|
||||
|
||||
// Reference driver that walks the reordered ( blocked ) B buffer and unpacks
// each ( nc0 x kc0 ) panel back into the strided output matrix described by
// the lpgemm object 'b'. Work is split across threads along the n ( JC )
// dimension when OpenMP is enabled.
//
// b           - destination ( unreordered ) matrix descriptor.
// b_unreorder - source descriptor holding the reordered buffer.
// rntm        - runtime object supplying the thread count.
// lcntx       - lpgemm context supplying the NC/KC/NR block sizes.
void unreorderb_nr64_bf16bf16f32of32_reference
     (
       lpgemm_obj_t * b,
       lpgemm_obj_t * b_unreorder,
       rntm_t* rntm,
       lpgemm_cntx_t* lcntx
     )
{
    dim_t NC = lcntx->blksz.NC;
    dim_t KC = lcntx->blksz.KC;
    dim_t NR = lcntx->blksz.NR;

    // Extracting the matrix properties from the lpgemm object
    dim_t rs_b = b->rs;
    dim_t cs_b = b->cs;
    dim_t n = b->width;
    dim_t k = b->length;

    // The reordered buffer stores k rounded up to a multiple of 2
    // ( bf16 pair granularity required by the dpbf16 instruction ).
    dim_t k_updated = k;
    k_updated += (k_updated & 0x1);

    dim_t n_threads = bli_rntm_num_threads( rntm );
    n_threads = ( n_threads > 0 ) ? n_threads : 1;

#ifdef BLIS_ENABLE_OPENMP
    _Pragma( "omp parallel num_threads(n_threads)" )
    {
        // Initialise a local thrinfo obj for work split across threads.
        thrinfo_t thread_jc;
        bli_thrinfo_set_n_way( n_threads, &thread_jc );
        bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc );
#else
    {
        // Initialise a local thrinfo obj for work split across threads.
        thrinfo_t thread_jc;
        bli_thrinfo_set_n_way( 1, &thread_jc );
        bli_thrinfo_set_work_id( 0, &thread_jc );
#endif

        // Compute the JC loop thread range for the current thread.
        dim_t jc_start, jc_end;
        bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );

        for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
        {
            dim_t nc0 = bli_min( ( jc_end - jc ), NC );

            dim_t jc_cur_loop = jc;
            dim_t jc_cur_loop_rem = 0;
            dim_t n_sub_updated;

            // Translate this thread's jc offset into blocked-buffer
            // coordinates ( panel start, remainder within the panel, and the
            // padded panel width ).
            // NOTE(review): the literal '16' is the column sub-block multiple
            // assumed by the reorder layout — confirm it matches the value
            // used on the reorder path.
            get_B_panel_reordered_start_offset_width
            (
              jc, n, NC, 16,
              &jc_cur_loop, &jc_cur_loop_rem,
              &nc0, &n_sub_updated
            );

            for ( dim_t pc = 0; pc < k; pc += KC )
            {
                dim_t kc0 = bli_min( ( k - pc ), KC );

                // k needs to be a multiple of 2 so that it can be used with dpbf
                // instruction. Padding is added in cases this condition is not
                // satisfied, and therefore the k offset used for packed/reordered
                // buffer needs to be updated.
                dim_t kc0_updated = kc0;
                kc0_updated += (kc0_updated & 0x1);

                // Unpack one ( nc0 x kc0 ) panel from the reordered buffer
                // into the strided output matrix.
                unpackb_nr64_bf16bf16f32of32_reference
                (
                  ( ( bfloat16* )b_unreorder->storage.aligned_buffer ) +
                  ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) +
                  ( jc_cur_loop_rem * kc0_updated ),
                  ( ( ( bfloat16* )b->storage.aligned_buffer ) +
                  ( rs_b * pc ) + (jc * cs_b)),
                  nc0, kc0, rs_b, cs_b
                );
            }

            adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
        }
    }
}
|
||||
|
||||
void reorderb_nr64_bf16bf16f32of32
|
||||
(
|
||||
lpgemm_obj_t* b,
|
||||
|
||||
@@ -49,6 +49,24 @@ void packb_nr64_bf16bf16f32of32_reference
|
||||
dim_t* cs_p
|
||||
);
|
||||
|
||||
void unpackb_nr64_bf16bf16f32of32_reference
|
||||
(
|
||||
bfloat16* b,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t rs_b,
|
||||
dim_t cs_b
|
||||
);
|
||||
|
||||
void unreorderb_nr64_bf16bf16f32of32_reference
|
||||
(
|
||||
lpgemm_obj_t* b_reorder,
|
||||
lpgemm_obj_t* b_unreorder,
|
||||
rntm_t* rntm,
|
||||
lpgemm_cntx_t* lcntx
|
||||
);
|
||||
|
||||
void reorderb_nr64_bf16bf16f32of32_reference
|
||||
(
|
||||
lpgemm_obj_t* b,
|
||||
|
||||
768
addon/aocl_gemm/frame/bf16bf16f32/lpgemm_unreorder_bf16_ref.c
Normal file
768
addon/aocl_gemm/frame/bf16bf16f32/lpgemm_unreorder_bf16_ref.c
Normal file
@@ -0,0 +1,768 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <string.h>
|
||||
#include "blis.h"
|
||||
|
||||
#ifdef BLIS_ADDON_LPGEMM
|
||||
|
||||
/*
|
||||
Below are the reference unpackb functions which are
|
||||
varied based on block size NR (64, 48, 32, 16, lt) and
|
||||
order (row / column (transpose)).
|
||||
*/
|
||||
|
||||
void unpackb_nr48_bf16bf16f32of32_row_major_ref
|
||||
(
|
||||
bfloat16* b,
|
||||
bfloat16* unpack_b,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
{
|
||||
dim_t NR1 = 32;
|
||||
dim_t NR2 = 16;
|
||||
|
||||
dim_t k_full_pieces_blks = KC / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = KC % 2;
|
||||
|
||||
dim_t kr_new = 0;
|
||||
|
||||
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
|
||||
{
|
||||
bfloat16* outp0 = ( unpack_b + ( ldb * ( kr + 0 ) ));
|
||||
bfloat16* outp1 = ( unpack_b + ( ldb * ( kr + 1 ) ));
|
||||
bfloat16* outp2 = ( unpack_b + ( ldb * ( kr + 0 ) ) + NR2);
|
||||
bfloat16* outp3 = ( unpack_b + ( ldb * ( kr + 1 ) ) + NR2);
|
||||
|
||||
bfloat16* inp0 = ( b + ( ( kr_new + 0 ) * NR1 ));
|
||||
bfloat16* inp1 = ( b + ( ( kr_new + 1 ) * NR1 ));
|
||||
|
||||
for(dim_t i = 0; i < 16; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp1++ = *inp0++;
|
||||
*outp2++ = *inp1++;
|
||||
*outp3++ = *inp1++;
|
||||
}
|
||||
|
||||
outp0 = ( unpack_b + ( ldb * ( kr + 0 ) ) + NR1);
|
||||
outp1 = ( unpack_b + ( ldb * ( kr + 1 ) ) + NR1);
|
||||
outp2 = ( unpack_b + ( ldb * ( kr + 0 ) ) + NR1 + 8);
|
||||
outp3 = ( unpack_b + ( ldb * ( kr + 1 ) ) + NR1 + 8);
|
||||
|
||||
inp0 = ( b + ( ( kr_new + 2 ) * NR1 ));
|
||||
inp1 = ( b + ( ( kr_new + 2 ) * NR1 + NR2));
|
||||
|
||||
for(dim_t i = 0; i < 8; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp1++ = *inp0++;
|
||||
*outp2++ = *inp1++;
|
||||
*outp3++ = *inp1++;
|
||||
}
|
||||
kr_new += 3;
|
||||
}
|
||||
|
||||
// Handle k remainder.
|
||||
if ( k_partial_pieces > 0 )
|
||||
{
|
||||
bfloat16* inp0 = ( b + ( ldb * ( k_full_pieces + 0 ) ));
|
||||
bfloat16* inp2 = ( b + ( ldb * ( k_full_pieces + 0 ) ) + NR2);
|
||||
|
||||
bfloat16* outp0 = ( unpack_b + ( ( kr_new + 0 ) * NR1 ));
|
||||
bfloat16* outp1 = ( unpack_b + ( ( kr_new + 1 ) * NR1 ));
|
||||
|
||||
for(dim_t i = 0; i < 16; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp0++ = 0;
|
||||
*outp1++ = *inp2++;
|
||||
*outp1++ = 0;
|
||||
}
|
||||
|
||||
inp0 = ( b + ( ldb * ( k_full_pieces + 0 ) ) + NR1);
|
||||
inp2 = ( b + ( ldb * ( k_full_pieces + 0 ) ) + NR1 + 8);
|
||||
|
||||
outp0 = ( unpack_b + ( ( kr_new + 2 ) * NR1 ));
|
||||
outp1 = ( unpack_b + ( ( kr_new + 2 ) * NR1 + NR2));
|
||||
|
||||
for(dim_t i = 0; i < 8; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp0++ = 0;
|
||||
*outp1++ = *inp2++;
|
||||
*outp1++ = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void unpackb_nr32_bf16bf16f32of32_row_major_ref
|
||||
(
|
||||
bfloat16* b,
|
||||
bfloat16* unpack_b,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
{
|
||||
dim_t NR = 32;
|
||||
dim_t NR2 = 16;
|
||||
|
||||
dim_t k_full_pieces_blks = KC / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = KC % 2;
|
||||
|
||||
dim_t kr_new = 0;
|
||||
|
||||
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
|
||||
{
|
||||
bfloat16* outp0 = ( unpack_b + ( ldb * ( kr + 0 ) ));
|
||||
bfloat16* outp1 = ( unpack_b + ( ldb * ( kr + 1 ) ));
|
||||
bfloat16* outp2 = ( unpack_b + ( ldb * ( kr + 0 ) ) + NR2);
|
||||
bfloat16* outp3 = ( unpack_b + ( ldb * ( kr + 1 ) ) + NR2);
|
||||
|
||||
bfloat16* inp0 = ( b + ( ( kr_new + 0 ) * NR ));
|
||||
bfloat16* inp1 = ( b + ( ( kr_new + 1 ) * NR ));
|
||||
|
||||
for(dim_t i = 0; i < 16; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp1++ = *inp0++;
|
||||
*outp2++ = *inp1++;
|
||||
*outp3++ = *inp1++;
|
||||
}
|
||||
|
||||
kr_new += 2;
|
||||
}
|
||||
|
||||
// Handle k remainder.
|
||||
if ( k_partial_pieces > 0 )
|
||||
{
|
||||
bfloat16* outp0 = ( unpack_b + ( ldb * ( k_full_pieces + 0 ) ));
|
||||
bfloat16* outp2 = ( unpack_b + ( ldb * ( k_full_pieces + 0 ) ) + NR2);
|
||||
|
||||
bfloat16* inp0 = ( b + ( ( kr_new + 0 ) * NR ));
|
||||
bfloat16* inp1 = ( b + ( ( kr_new + 1 ) * NR ));
|
||||
|
||||
for(dim_t i = 0; i < 16; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp2++ = *inp1++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void unpackb_nr16_bf16bf16f32of32_row_major_ref
|
||||
(
|
||||
bfloat16* b,
|
||||
bfloat16* unpack_b,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
{
|
||||
dim_t NR = 16;
|
||||
dim_t NRBY2 = 8;
|
||||
|
||||
dim_t k_full_pieces_blks = KC / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = KC % 2;
|
||||
|
||||
dim_t kr_new = 0;
|
||||
|
||||
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
|
||||
{
|
||||
bfloat16* outp0 = ( unpack_b + ( ldb * ( kr + 0 ) ));
|
||||
bfloat16* outp1 = ( unpack_b + ( ldb * ( kr + 1 ) ));
|
||||
bfloat16* outp2 = ( unpack_b + ( ldb * ( kr + 0 ) ) + NRBY2);
|
||||
bfloat16* outp3 = ( unpack_b + ( ldb * ( kr + 1 ) ) + NRBY2);
|
||||
|
||||
bfloat16* inp0 = ( b + ( ( kr_new + 0 ) * NR ));
|
||||
bfloat16* inp1 = ( b + ( ( kr_new + 1 ) * NR ));
|
||||
|
||||
for(dim_t i = 0; i < NRBY2; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp1++ = *inp0++;
|
||||
*outp2++ = *inp1++;
|
||||
*outp3++ = *inp1++;
|
||||
}
|
||||
kr_new += 2;
|
||||
}
|
||||
|
||||
// Handle k remainder.
|
||||
if ( k_partial_pieces > 0 )
|
||||
{
|
||||
bfloat16* outp0 = ( unpack_b + ( ldb * ( k_full_pieces + 0 ) ));
|
||||
bfloat16* outp2 = ( unpack_b + ( ldb * ( k_full_pieces + 0 ) ) + NRBY2);
|
||||
|
||||
bfloat16* inp0 = ( b + ( ( kr_new + 0 ) * NR ));
|
||||
bfloat16* inp1 = ( b + ( ( kr_new + 1 ) * NR ));
|
||||
|
||||
for(dim_t i = 0; i < NRBY2; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp2++ = *inp1++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unpacks a fringe panel of fewer than 16 columns ( n0_partial_rem < 16 )
// of reordered bf16 B back to row major. Mirrors the NR=16 kernel but with
// the inner loop trip count derived from the actual column count.
//
// b              - source ( reordered ) buffer.
// unpack_b       - destination, row major with leading dimension ldb.
// KC             - number of k rows to unpack.
// ldb            - row stride of the destination.
// n0_partial_rem - number of valid columns in this fringe panel ( < 16 ).
//
// NOTE(review): both half-loops iterate n0_partial_rem/2 times while the
// second half writes starting at column NRBY2 ( 8 ); for n0_partial_rem
// != 16 this does not cover columns 0..n0_partial_rem-1 contiguously, and
// an odd n0_partial_rem drops a column entirely. Confirm against the
// packb_nrlt16 packed layout before enabling.
// NOTE(review): the k-remainder loops read packed elements consecutively;
// if the packed layout interleaves the odd k value with zero padding ( as
// the other kernels assume ), this copies padding into the output —
// verify.
void unpackb_nrlt16_bf16bf16f32of32_row_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t KC,
       dim_t ldb,
       dim_t n0_partial_rem
     )
{
    dim_t NR = 16;
    dim_t NRBY2 = 8;

    dim_t k_full_pieces_blks = KC / 2;
    dim_t k_full_pieces = k_full_pieces_blks * 2;
    dim_t k_partial_pieces = KC % 2;

    dim_t kr_new = 0;

    for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
    {
        bfloat16* outp0 = ( unpack_b + ( ldb * ( kr + 0 ) ));
        bfloat16* outp1 = ( unpack_b + ( ldb * ( kr + 1 ) ));
        bfloat16* outp2 = ( unpack_b + ( ldb * ( kr + 0 ) ) + NRBY2);
        bfloat16* outp3 = ( unpack_b + ( ldb * ( kr + 1 ) ) + NRBY2);

        bfloat16* inp0 = ( b + ( ( kr_new + 0 ) * NR ));
        bfloat16* inp1 = ( b + ( ( kr_new + 1 ) * NR ));

        // De-interleave the ( k, k+1 ) pairs into the two output rows.
        for(dim_t i = 0; i < (n0_partial_rem/2); i++)
        {
            *outp0++ = *inp0++;
            *outp1++ = *inp0++;
            *outp2++ = *inp1++;
            *outp3++ = *inp1++;
        }
        kr_new += 2;
    }

    // Handle k remainder.
    if ( k_partial_pieces > 0 )
    {
        bfloat16* outp0 = ( unpack_b + ( ldb * ( k_full_pieces + 0 ) ));
        bfloat16* outp2 = ( unpack_b + ( ldb * ( k_full_pieces + 0 ) ) + NRBY2);

        bfloat16* inp0 = ( b + ( ( kr_new + 0 ) * NR ));
        bfloat16* inp1 = ( b + ( ( kr_new + 1 ) * NR ));

        for(dim_t i = 0; i < (n0_partial_rem/2); i++)
        {
            *outp0++ = *inp0++;
            *outp2++ = *inp1++;
        }
    }
}
|
||||
|
||||
void unpackb_nr64_bf16bf16f32of32_row_major_ref
|
||||
(
|
||||
bfloat16* b,
|
||||
bfloat16* unpack_b,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
{
|
||||
dim_t NR = 64;
|
||||
|
||||
dim_t n_full_pieces = NC / NR;
|
||||
dim_t n_full_pieces_loop_limit = n_full_pieces * NR;
|
||||
dim_t n_partial_pieces = NC % NR;
|
||||
|
||||
dim_t k_full_pieces_blks = KC / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = KC % 2;
|
||||
|
||||
// KC when not multiple of 2 will have padding to make it multiple of 2 in packed buffer.
|
||||
dim_t KC_updated = KC;
|
||||
if ( k_partial_pieces > 0 )
|
||||
{
|
||||
KC_updated += ( 2 - k_partial_pieces );
|
||||
}
|
||||
|
||||
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
|
||||
{
|
||||
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
|
||||
{
|
||||
bfloat16* outp0 = ( unpack_b + ( ldb * ( kr + 0 ) ) + jc );
|
||||
bfloat16* outp1 = ( unpack_b + ( ldb * ( kr + 0 ) ) + jc + 32 );
|
||||
bfloat16* outp2 = ( unpack_b + ( ldb * ( kr + 1 ) ) + jc );
|
||||
bfloat16* outp3 = ( unpack_b + ( ldb * ( kr + 1 ) ) + jc + 32 );
|
||||
|
||||
//load from b reordered buffer
|
||||
bfloat16* inp0 = ( b + ( jc * KC_updated ) + ( ( kr + 0 ) * NR ));
|
||||
bfloat16* inp1 = ( b + ( jc * KC_updated ) + ( ( kr + 1 ) * NR ));
|
||||
|
||||
for(dim_t i = 0; i < 32; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp2++ = *inp0++;
|
||||
*outp1++ = *inp1++;
|
||||
*outp3++ = *inp1++;
|
||||
}
|
||||
}
|
||||
|
||||
if( k_partial_pieces > 0 )
|
||||
{
|
||||
bfloat16* outp0 = ( unpack_b + ( ldb * ( k_full_pieces + 0 ) ) + jc );
|
||||
bfloat16* outp1 = ( unpack_b + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32 );
|
||||
|
||||
//load from b reordered buffer
|
||||
bfloat16* inp0 = ( b + ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) );
|
||||
bfloat16* inp1 = ( b + ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) );
|
||||
for(dim_t i = 0; i < 32; i++)
|
||||
{
|
||||
*outp0++ = *inp0++;
|
||||
*outp0++ = 0;
|
||||
*outp1++ = *inp1++;
|
||||
*outp1++ = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if( n_partial_pieces > 0 )
|
||||
{
|
||||
dim_t n0_partial_rem = n_partial_pieces % 16;
|
||||
dim_t n0_partial_unpack = 0;
|
||||
|
||||
// Split into multiple smaller fringe kernels, so as to maximize
|
||||
// vectorization after packing. Any n0 < NR(64) can be expressed
|
||||
// as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16.
|
||||
dim_t n0_48 = n_partial_pieces / 48;
|
||||
dim_t n0_32 = n_partial_pieces / 32;
|
||||
dim_t n0_16 = n_partial_pieces / 16;
|
||||
|
||||
if ( n0_48 == 1 )
|
||||
{
|
||||
unpackb_nr48_bf16bf16f32of32_row_major_ref
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b + n_full_pieces_loop_limit ), KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_unpack = 48;
|
||||
}
|
||||
else if ( n0_32 == 1 )
|
||||
{
|
||||
unpackb_nr32_bf16bf16f32of32_row_major_ref
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b + n_full_pieces_loop_limit ), KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_unpack = 32;
|
||||
}
|
||||
else if ( n0_16 == 1 )
|
||||
{
|
||||
unpackb_nr16_bf16bf16f32of32_row_major_ref
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b + n_full_pieces_loop_limit ), KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_unpack = 16;
|
||||
}
|
||||
|
||||
if ( n0_partial_rem > 0 )
|
||||
{
|
||||
unpackb_nrlt16_bf16bf16f32of32_row_major_ref
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) +
|
||||
( n0_partial_unpack * KC_updated ) ),
|
||||
( unpack_b + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
|
||||
n0_partial_rem
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Column-major ( transpose ) unpack for a fringe panel of fewer than 16
// columns. The packed k dimension is peeled in power-of-two chunks
// ( 32, 16, 8, 4, 2, 1 ) with the same inner de-interleave pattern.
//
// NOTE(review): this path is currently disabled — the dispatcher
// unpackb_nr64_bf16bf16f32of32_reference returns early for column major.
// NOTE(review): the parameter order ( b, unpack_b, KC, NR, ldb ) does NOT
// match its only caller in unpackb_nr64_bf16bf16f32of32_col_major_ref,
// which passes ( b, unpack_b, KC, ldb, n0_partial_rem ) — i.e. 'NR'
// receives ldb and 'ldb' receives n0_partial_rem. Align the signature
// with the row-major lt16 kernel before enabling.
// NOTE(review): the output offset '*(outp + ( j * 2 * NR))' advances by NR
// rather than a destination stride — verify against the intended
// column-major layout.
void unpackb_nrlt16_bf16bf16f32of32_col_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t KC,
       dim_t NR,
       dim_t ldb
     )
{
    // Used for permuting the mm512i elements for use in dpbf16_ps instruction.
    dim_t kr = 0;
    // 32-wide k chunks.
    for ( kr = 0; ( kr + 31 ) < KC; kr += 32 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 16; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // 16-wide k chunks.
    for ( ; ( kr + 15 ) < KC; kr += 16 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 8; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // 8-wide k chunks.
    for( ; ( kr +7 ) < KC; kr += 8 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 4; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // 4-wide k chunks.
    for( ; ( kr +3 ) < KC; kr += 4 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 2; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // 2-wide k chunks ( one full packed pair ).
    for( ; ( kr +1 ) < KC; kr += 2 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 1; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // Final odd k: only the first of each packed pair is data; the second
    // output slot is zeroed.
    for( ; kr < KC; kr += 1 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 1; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = 0;
                }
            }
        }
    }
}
|
||||
|
||||
// Column-major ( transpose ) unpack for a panel whose width NR is a
// multiple of 16. The packed k dimension is peeled in power-of-two chunks
// ( 32, 16, 8, 4, 2, 1 ).
//
// NOTE(review): this path is currently disabled — the dispatcher
// unpackb_nr64_bf16bf16f32of32_reference returns early for column major.
// NOTE(review): in the first ( 32-wide ) chunk the loop nest order differs
// from the later chunks ( 'i' outer, 'jr' inner ) and 'inp' is computed as
// b + ( jr * KC ) + ( ( kr + 0 ) * NR ) — it does not depend on 'i', so all
// 16 inner iterations re-read the same packed data, and it uses 'KC' where
// the packed buffer is padded to an even KC. Suspected bug; verify before
// enabling.
// NOTE(review): the trailing '};' has a stray semicolon after the function
// body ( harmless, but non-standard at file scope ).
void unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
     (
       bfloat16* b,
       bfloat16* unpack_b,
       const dim_t NR,
       const dim_t KC,
       dim_t ldb
     )
{
    dim_t kr = 0;
    // 32-wide k chunks.
    for ( kr = 0; ( kr + 31 ) < KC; kr += 32 )
    {
        for( dim_t i = 0; i < 16; i++ )
        {
            bfloat16 *inp, *outp;
            for( dim_t jr = 0; jr < NR; jr += 16 )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = ( b + ( jr * KC ) + ( ( kr + 0 ) * NR ));
                for( dim_t j = 0; j < 16; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // 16-wide k chunks.
    for ( ; ( kr + 15 ) < KC; kr += 16 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 8; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // 8-wide k chunks.
    for( ; ( kr +7 ) < KC; kr += 8 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 4; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // 4-wide k chunks.
    for( ; ( kr +3 ) < KC; kr += 4 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 2; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // 2-wide k chunks ( one full packed pair ).
    for( ; ( kr +1 ) < KC; kr += 2 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 1; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = *inp++;
                }
            }
        }
    }

    // Final odd k: only the first of each packed pair is data; the second
    // output slot is zeroed.
    for( ; kr < KC; kr += 1 )
    {
        for( dim_t jr = 0; jr < NR; jr += 16 )
        {
            bfloat16 *inp, *outp;
            for( dim_t i = 0; i < 16; i++ )
            {
                outp = (unpack_b + ( ldb * ( jr + i ) ) + kr);
                inp = b + ( jr * 2 ) + (kr * NR) + i * 2;
                for( dim_t j = 0; j < 1; j++ )
                {
                    *(outp + ( j * 2 * NR)) = *inp++;
                    *(outp + (( j * 2 * NR) + 1)) = 0;
                }
            }
        }
    }
};
|
||||
|
||||
void unpackb_nr64_bf16bf16f32of32_col_major_ref
|
||||
(
|
||||
bfloat16* b,
|
||||
bfloat16* unpack_b,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
{
|
||||
dim_t NR = 64;
|
||||
|
||||
dim_t n_full_pieces = NC / NR;
|
||||
dim_t n_full_pieces_loop_limit = n_full_pieces * NR;
|
||||
dim_t n_partial_pieces = NC % NR;
|
||||
|
||||
dim_t k_partial_pieces = KC % 2;
|
||||
|
||||
dim_t KC_updated = KC;
|
||||
if ( k_partial_pieces > 0 )
|
||||
{
|
||||
KC_updated += ( 2 - k_partial_pieces );
|
||||
}
|
||||
|
||||
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
|
||||
{
|
||||
unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
|
||||
( b + (jc * KC_updated),
|
||||
unpack_b + (jc * ldb), 64, KC, ldb
|
||||
);
|
||||
}
|
||||
|
||||
if(n_partial_pieces > 0)
|
||||
{
|
||||
dim_t n0_partial_rem = n_partial_pieces % 16;
|
||||
dim_t n0_partial_pack = 0;
|
||||
|
||||
// Split into multiple smaller fringe kernels, so as to maximize
|
||||
// vectorization after packing. Any n0 < NR(64) can be expressed
|
||||
// as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16.
|
||||
dim_t n0_48 = n_partial_pieces / 48;
|
||||
dim_t n0_32 = n_partial_pieces / 32;
|
||||
dim_t n0_16 = n_partial_pieces / 16;
|
||||
|
||||
if ( n0_48 == 1 )
|
||||
{
|
||||
unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b + n_full_pieces_loop_limit * ldb ), 48, KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_pack = 48;
|
||||
}
|
||||
else if ( n0_32 == 1 )
|
||||
{
|
||||
unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b + n_full_pieces_loop_limit * ldb ), 32, KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_pack = 32;
|
||||
}
|
||||
else if ( n0_16 == 1 )
|
||||
{
|
||||
unpackb_nr_mult_16_bf16bf16f32of32_col_major_ref
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b + n_full_pieces_loop_limit * ldb ), 16, KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_pack = 16;
|
||||
}
|
||||
|
||||
if ( n0_partial_rem > 0 )
|
||||
{
|
||||
unpackb_nrlt16_bf16bf16f32of32_col_major_ref
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) +
|
||||
( n0_partial_pack * KC_updated ) ),
|
||||
( unpack_b + ( n_full_pieces_loop_limit + n0_partial_pack ) * ldb ), KC, ldb,
|
||||
n0_partial_rem
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void unpackb_nr64_bf16bf16f32of32_reference
|
||||
(
|
||||
bfloat16* b,
|
||||
bfloat16* unpack_b,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t rs_b,
|
||||
dim_t cs_b
|
||||
)
|
||||
{
|
||||
if( cs_b == 1 )
|
||||
{
|
||||
unpackb_nr64_bf16bf16f32of32_row_major_ref( b, unpack_b, NC, KC, rs_b );
|
||||
}
|
||||
else
|
||||
{
|
||||
return;
|
||||
//TODO: Implement column major unpacking
|
||||
//unpackb_nr64_bf16bf16f32of32_col_major_ref( b, unpack_b, NC, KC, cs_b );
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -41,7 +41,7 @@
|
||||
void unpackb_nr48_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
@@ -73,8 +73,8 @@ void unpackb_nr48_bf16bf16f32of32_row_major
|
||||
a01 = _mm512_permutex2var_epi16( b0, selector_even, a0 );
|
||||
b0 = _mm512_permutex2var_epi16( b0, selector_odd, a0 );
|
||||
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), b0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ), b0 );
|
||||
|
||||
c0 = _mm512_loadu_si512( b + ( ( kr_new + 2 ) * NR1 ) );
|
||||
d0 = _mm512_setzero_si512();
|
||||
@@ -82,8 +82,8 @@ void unpackb_nr48_bf16bf16f32of32_row_major
|
||||
c01 = _mm512_permutex2var_epi16( d0, selector_even, c0 );
|
||||
d0 = _mm512_permutex2var_epi16( d0, selector_odd, c0 );
|
||||
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + NR1, 0xFFFF, c01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + NR1, 0xFFFF, d0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + NR1, 0xFFFF, c01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + NR1, 0xFFFF, d0 );
|
||||
|
||||
kr_new += 3;
|
||||
}
|
||||
@@ -96,18 +96,18 @@ void unpackb_nr48_bf16bf16f32of32_row_major
|
||||
|
||||
a01 = _mm512_permutex2var_epi16( b0, selector_even, a0 );
|
||||
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), a01 );
|
||||
|
||||
c0 = _mm512_loadu_si512( b + ( ( kr_new + 2 ) * NR1 ) );
|
||||
c01 = _mm512_permutex2var_epi16( c0, selector_even, c0 );
|
||||
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + NR1, 0xFFFF, c01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + NR1, 0xFFFF, c01 );
|
||||
}
|
||||
}
|
||||
void unpackb_nr32_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
@@ -138,8 +138,8 @@ void unpackb_nr32_bf16bf16f32of32_row_major
|
||||
c0 = _mm512_permutex2var_epi16( c0, selector_odd, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), c0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ), c0 );
|
||||
|
||||
}
|
||||
if( k_partial_pieces > 0 )
|
||||
@@ -150,13 +150,13 @@ void unpackb_nr32_bf16bf16f32of32_row_major
|
||||
a0 = _mm512_permutex2var_epi16( c0, selector_even, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), a0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), a0 );
|
||||
}
|
||||
}
|
||||
void unpackb_nr16_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
@@ -187,8 +187,8 @@ void unpackb_nr16_bf16bf16f32of32_row_major
|
||||
c0 = _mm512_permutex2var_epi16( a0, selector_odd, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), 0xFFFF, a01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), 0xFFFF, c0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ), 0xFFFF, a01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ), 0xFFFF, c0 );
|
||||
}
|
||||
if( k_partial_pieces > 0 )
|
||||
{
|
||||
@@ -197,13 +197,13 @@ void unpackb_nr16_bf16bf16f32of32_row_major
|
||||
a0 = _mm512_permutex2var_epi16( a0, selector_even, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), 0xFFFF, a0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), 0xFFFF, a0 );
|
||||
}
|
||||
}
|
||||
void unpackb_nrlt16_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t KC,
|
||||
dim_t ldb,
|
||||
dim_t n0_partial_rem
|
||||
@@ -237,8 +237,8 @@ void unpackb_nrlt16_bf16bf16f32of32_row_major
|
||||
c0 = _mm512_permutex2var_epi16( a0, selector_odd, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), store_mask, a01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), store_mask, c0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ), store_mask, a01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ), store_mask, c0 );
|
||||
}
|
||||
if( k_partial_pieces > 0 )
|
||||
{
|
||||
@@ -247,14 +247,14 @@ void unpackb_nrlt16_bf16bf16f32of32_row_major
|
||||
a0 = _mm512_permutex2var_epi16( a0, selector_even, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), store_mask, a0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), store_mask, a0 );
|
||||
}
|
||||
}
|
||||
|
||||
void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
@@ -304,10 +304,10 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
d0 = _mm512_permutex2var_epi16( d0, selector_odd, c0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + jc, a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + jc + 32, c01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + jc, b0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + jc + 32, d0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + jc, a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + jc + 32, c01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + jc, b0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + jc + 32, d0 );
|
||||
|
||||
}
|
||||
if( k_partial_pieces > 0 )
|
||||
@@ -322,8 +322,8 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
c01 = _mm512_permutex2var_epi16( d0, selector_even, c0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + jc, a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32, c01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + jc, a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32, c01 );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -344,7 +344,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
unpackb_nr48_bf16bf16f32of32_row_major
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
|
||||
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_unpack = 48;
|
||||
@@ -354,7 +354,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
unpackb_nr32_bf16bf16f32of32_row_major
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
|
||||
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_unpack = 32;
|
||||
@@ -364,7 +364,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
unpackb_nr16_bf16bf16f32of32_row_major
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
|
||||
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_unpack = 16;
|
||||
@@ -376,7 +376,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) +
|
||||
( n0_partial_unpack * KC_updated ) ),
|
||||
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
|
||||
( unpack_b_buffer + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
|
||||
n0_partial_rem
|
||||
);
|
||||
}
|
||||
@@ -895,7 +895,7 @@ void unpackb_nr64_bf16bf16f32of32_col_major
|
||||
void unpackb_nr64_bf16bf16f32of32
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t rs_b,
|
||||
@@ -904,11 +904,11 @@ void unpackb_nr64_bf16bf16f32of32
|
||||
{
|
||||
if( cs_b == 1 )
|
||||
{
|
||||
unpackb_nr64_bf16bf16f32of32_row_major( b, unpack_b_buffer_bf16bf16f32of32, NC, KC, rs_b );
|
||||
unpackb_nr64_bf16bf16f32of32_row_major( b, unpack_b_buffer, NC, KC, rs_b );
|
||||
}
|
||||
else
|
||||
{
|
||||
unpackb_nr64_bf16bf16f32of32_col_major( b, unpack_b_buffer_bf16bf16f32of32, NC, KC, cs_b );
|
||||
unpackb_nr64_bf16bf16f32of32_col_major( b, unpack_b_buffer, NC, KC, cs_b );
|
||||
}
|
||||
}
|
||||
#endif // BLIS_ADDON_LPGEMM
|
||||
|
||||
Reference in New Issue
Block a user