New kernels for int4 B matrix reordering following BF16 kernel schema.

-To enable Weight-only-Quantization (WOQ) workflow, new LPGEMM APIs
are required where data types are A:bf16, B:int4 and C:f32/bf16. It
is expected that the BF16 kernels will be reused within this API and
subsequently the B matrix needs to be reordered following the BF16
kernel schema, but with the reordered matrix type still being int4. To
address this, new BF16 reorder kernels enabling the same are added.

AMD-Internal: [SWLCSG-2943]
Change-Id: Ib770ecbf90a3d906deafece94b1a96e0b9412738
This commit is contained in:
mkadavil
2024-07-23 07:37:23 +05:30
committed by Nallani Bhaskar
parent eacad443e3
commit 7114376519
7 changed files with 1325 additions and 288 deletions

View File

@@ -0,0 +1,887 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include <string.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
#include "../int4_utils_avx512.h"
void packb_nr64_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t NC,
const dim_t KC,
dim_t* rs_p,
dim_t* cs_p
);
void packb_nr48_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t KC
);
void packb_nr32_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t KC
);
void packb_nr16_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t KC
);
void packb_nrlt16_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t KC,
const dim_t n0_partial_rem
);
void packb_nr64_bf16s4f32of32
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t cs_b,
const dim_t NC,
const dim_t KC,
dim_t* rs_p,
dim_t* cs_p
)
{
if (cs_b == 1)
{
packb_nr64_bf16s4f32of32_row_major(pack_b_buffer,
b, rs_b, NC, KC, rs_p, cs_p);
}
else
{
bli_print_msg("Only row major supported for int4 packing.",
__FILE__, __LINE__);
return;
}
}
void packb_nr64_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t NC,
const dim_t KC,
dim_t* rs_p,
dim_t* cs_p
)
{
dim_t NR = 64;
dim_t n_full_pieces = NC / NR;
dim_t n_full_pieces_loop_limit = n_full_pieces * NR;
dim_t n_partial_pieces = NC % NR;
dim_t k_full_pieces_blks = KC / 2;
dim_t k_full_pieces = k_full_pieces_blks * 2;
dim_t k_partial_pieces = KC % 2;
// KC when not multiple of 2 will have padding to make it multiple of 2
// in packed buffer.
dim_t KC_updated = KC;
if ( k_partial_pieces > 0 )
{
KC_updated += ( 2 - k_partial_pieces );
}
bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE;
bool signed_upscale = TRUE;
const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments.
// Used for permuting the mm512i elements for use in dpbf16_ps instruction.
__m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB );
__m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF );
// Selectors for int4 -> int8 conversion.
__m512i shift_idx_64;
MULTISHIFT_32BIT_8_INT4_IDX_64ELEM( shift_idx_64 );
__m512i sign_comp = _mm512_set1_epi8( 0x08 );
__mmask32 hmask = _cvtu32_mask32(0xFFFFFFFF); // 32 bytes or 64 int4.
__mmask32 hmask_odd = _cvtu32_mask32(0x80000000); // Last 1 int4.
CREATE_CVT_INT4_INT8_PERM_IDX_64ELEM_ODD_LD(conv_shift_arr);
__m512i conv_shift = _mm512_loadu_epi64(conv_shift_arr);
// Selectors for int8 -> int4 conversion.
CREATE_CVT_INT8_INT4_PERM_IDX_64ELEM_2_ZMM_REG(even_idx_arr)
__m512i even_perm_idx = _mm512_loadu_si512( even_idx_arr );
__m512i all_1s = _mm512_maskz_set1_epi8( _cvtu64_mask64( 0xFFFFFFFFFFFFFFFF ), 0x01 );
__m512i odd_perm_idx = _mm512_add_epi8( even_perm_idx, all_1s );
__m512i clear_hi_bits = _mm512_maskz_set1_epi8( _cvtu64_mask64( 0xFFFFFFFFFFFFFFFF ), 0x0F );
__m256i h_a0;
__m256i h_b0;
__m256i h_b0_l4bit;
__m512i a0;
__m512i b0;
__m512i r_lo;
__m512i r_hi;
__m512i s4_out;
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
{
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
{
// Int4 array has to be accessed like byte array, but with
// half the elements traversed in the byte array.
h_a0 = _mm256_maskz_loadu_epi8( hmask,
b + ( ( ( rs_b * ( kr + 0 ) ) + jc ) / incr_adj_factor ) );
CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \
sign_comp, signed_upscale);
// If the stride, i.e. rs_b is odd, then the stride increment
// (rs_b * ...)/2 will point at the byte of which the high 4
// bits is our desired starting element. However since data
// access is at byte level, the low 4 bits of this byte will
// be wrongly included, and additionally the last int4 element
// won't be included either. Extra data movement done to
// account for the same.
// Since kr is a multiple of 2, only kr+1 will have the
// aforementioned issue.
if ( is_odd_stride == FALSE )
{
h_b0 = _mm256_maskz_loadu_epi8( hmask,
b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / incr_adj_factor ) );
CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_b0, b0, shift_idx_64, \
sign_comp, signed_upscale);
}
else
{
h_b0 = _mm256_maskz_loadu_epi8( hmask,
b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / incr_adj_factor ) );
// Only load the last byte/ 32nd byte.
h_b0_l4bit = _mm256_maskz_loadu_epi8( hmask_odd,
b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / incr_adj_factor ) + 1 );
CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(h_b0, h_b0_l4bit, b0, \
shift_idx_64, conv_shift, sign_comp, signed_upscale);
}
// Restructuring at int8 level.
r_lo = _mm512_unpacklo_epi8( a0, b0 );
r_hi = _mm512_unpackhi_epi8( a0, b0 );
a0 = _mm512_permutex2var_epi64( r_lo, selector1, r_hi );
b0 = _mm512_permutex2var_epi64( r_lo, selector1_1, r_hi );
// To be converted to int4 for storing.
CVT_INT8_INT4_64ELEM_2_ZMM_REG(a0, b0, s4_out, \
even_perm_idx, odd_perm_idx, clear_hi_bits);
// Int4 array has to be accessed like byte array, but with
// half the elements traversed in the byte array.
_mm512_storeu_si512( pack_b_buffer +
( ( ( jc * KC_updated ) + ( kr * NR ) ) / incr_adj_factor ),
s4_out );
}
// Handle k remainder.
if( k_partial_pieces > 0)
{
h_a0 = _mm256_maskz_loadu_epi8( hmask,
b + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + jc ) /
incr_adj_factor ) );
CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \
sign_comp, signed_upscale);
b0 = _mm512_setzero_si512();
// Restructuring at int8 level.
r_lo = _mm512_unpacklo_epi8( a0, b0 );
r_hi = _mm512_unpackhi_epi8( a0, b0 );
a0 = _mm512_permutex2var_epi64( r_lo, selector1, r_hi );
b0 = _mm512_permutex2var_epi64( r_lo, selector1_1, r_hi );
// To be converted to int4 for storing.
CVT_INT8_INT4_64ELEM_2_ZMM_REG(a0, b0, s4_out, \
even_perm_idx, odd_perm_idx, clear_hi_bits);
_mm512_storeu_si512( pack_b_buffer +
( ( ( jc * KC_updated ) + ( k_full_pieces * NR ) ) /
incr_adj_factor ), s4_out );
}
}
if(n_partial_pieces > 0)
{
dim_t n0_partial_rem = n_partial_pieces % 16;
dim_t n0_partial_pack = 0;
// Split into multiple smaller fringe kernels, so as to maximize
// vectorization after packing. Any n0 < NR(64) can be expressed
// as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16.
dim_t n0_48 = n_partial_pieces / 48;
dim_t n0_32 = n_partial_pieces / 32;
dim_t n0_16 = n_partial_pieces / 16;
if ( n0_48 == 1 )
{
packb_nr48_bf16s4f32of32_row_major
(
( pack_b_buffer +
( ( n_full_pieces_loop_limit * KC_updated ) /
incr_adj_factor ) ),
( b + ( n_full_pieces_loop_limit / incr_adj_factor ) ),
rs_b, KC
);
n0_partial_pack = 48;
}
else if ( n0_32 == 1 )
{
packb_nr32_bf16s4f32of32_row_major
(
( pack_b_buffer +
( ( n_full_pieces_loop_limit * KC_updated ) /
incr_adj_factor ) ),
( b + ( n_full_pieces_loop_limit / incr_adj_factor ) ),
rs_b, KC
);
n0_partial_pack = 32;
}
else if ( n0_16 == 1 )
{
packb_nr16_bf16s4f32of32_row_major
(
( pack_b_buffer +
( ( n_full_pieces_loop_limit * KC_updated ) /
incr_adj_factor ) ),
( b + ( n_full_pieces_loop_limit / incr_adj_factor ) ),
rs_b, KC
);
n0_partial_pack = 16;
}
if ( n0_partial_rem > 0 )
{
packb_nrlt16_bf16s4f32of32_row_major
(
( pack_b_buffer + ( ( ( n_full_pieces_loop_limit * KC_updated ) +
( n0_partial_pack * KC_updated ) ) / incr_adj_factor ) ),
( b + ( ( n_full_pieces_loop_limit + n0_partial_pack ) /
incr_adj_factor ) ),
rs_b, KC, n0_partial_rem
);
}
}
*rs_p = NR * 2;
*cs_p = NR / 2;
}
void packb_nr48_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t KC
)
{
const dim_t NR = 48;
const dim_t NR_32x2 = 64;
dim_t k_full_pieces_blks = KC / 2;
dim_t k_full_pieces = k_full_pieces_blks * 2;
dim_t k_partial_pieces = KC % 2;
bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE;
bool signed_upscale = TRUE;
const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments.
// Used for permuting the mm512i elements for use in dpbf16_ps instruction.
__m256i selector1_32 = _mm256_setr_epi64x( 0x0, 0x1, 0x4, 0x5 );
__m256i selector1_1_32 = _mm256_setr_epi64x( 0x2, 0x3, 0x6, 0x7 );
// Selectors for int4 -> int8 conversion.
// First 32 int4 elements selectors.
__m256i shift_idx_32;
MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32);
__m256i sign_comp_32 = _mm256_set1_epi8( 0x08 );
__mmask16 hmask_32 = _cvtu32_mask16( 0x0000FFFF ); //16 bytes or 32 int4.
__mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4.
CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32);
__m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_32 );
// Next 16 int4 elements selectors.
__m128i shift_idx_16;
MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16);
__m128i sign_comp_16 = _mm_set1_epi8( 0x08 );
__mmask16 hmask_16 = _cvtu32_mask16( 0x000000FF ); //8 bytes or 16 int4.
__mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4.
CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16);
__m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_16 );
// Selectors for int8 -> int4 conversion.
// First 32 int8 elements selectors.
CREATE_CVT_INT8_INT4_PERM_IDX_32ELEM_2_YMM_REG(even_idx_arr_32);
__m256i even_perm_idx_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ),
even_idx_arr_32 );
__m256i all_1s_32 = _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ),
0x01 );
__m256i odd_perm_idx_32 = _mm256_add_epi8( even_perm_idx_32, all_1s_32 );
__m256i clear_hi_bits_32 =
_mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), 0x0F );
// Next 16 int4 elements selectors.
CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(even_idx_arr_16);
__m128i even_perm_idx_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ),
even_idx_arr_16 );
__m128i all_1s_16 = _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ),
0x01 );
__m128i odd_perm_idx_16 = _mm_add_epi8( even_perm_idx_16, all_1s_16 );
__m128i clear_hi_bits_16 =
_mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), 0x0F );
__mmask16 sel_all_mask_16 = _cvtu32_mask16( 0xFFFF );
__m128i h_a0_32;
__m128i h_b0_32;
__m128i h_b0_32_l4bit;
__m128i a0_16;
__m128i b0_16;
__m128i r_lo_16;
__m128i r_hi_16;
__m128i s4_out_16;
__m256i a0_32;
__m256i b0_32;
__m256i r_lo_32;
__m256i r_hi_32;
__m256i s4_out_32;
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
{
// First 32 columns.
h_a0_32 = _mm_maskz_loadu_epi8( hmask_32,
b + ( ( rs_b * ( kr + 0 ) ) / incr_adj_factor ) );
CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \
sign_comp_32, signed_upscale);
// Last 16 columns.
h_a0_32 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( ( rs_b * ( kr + 0 ) ) + 32 ) / 2 ) );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \
sign_comp_16, signed_upscale);
if ( is_odd_stride == FALSE )
{
// First 32 columns.
h_b0_32 = _mm_maskz_loadu_epi8( hmask_32,
b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) );
CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \
sign_comp_32, signed_upscale);
// Last 16 columns.
h_b0_32 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_32, b0_16, shift_idx_16, \
sign_comp_16, signed_upscale);
}
else
{
// First 32 columns.
h_b0_32 = _mm_maskz_loadu_epi8( hmask_32,
b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) );
// Only load the last byte/ 16th byte.
h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32,
b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) + 1 );
CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, \
b0_32, shift_idx_32, conv_shift_32, sign_comp_32, \
signed_upscale);
// Last 16 columns.
h_b0_32 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) );
// Only load the last byte/ 8th byte.
h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16,
b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) + 1 );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, \
b0_16, shift_idx_16, conv_shift_16, sign_comp_16, \
signed_upscale);
}
// Restructuring at int8 level.
// First 32 columns.
r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 );
r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 );
a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 );
b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 );
CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \
even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32);
_mm256_storeu_epi64( pack_b_buffer +
( ( kr * NR ) / incr_adj_factor ), s4_out_32 );
// Last 16 columns.
r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 );
r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 );
CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \
even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16);
_mm_storeu_epi64( pack_b_buffer +
( ( ( kr * NR ) + NR_32x2 ) / incr_adj_factor ), s4_out_16 );
}
// Handle k remainder.
if( k_partial_pieces > 0)
{
// First 32 columns.
h_a0_32 = _mm_maskz_loadu_epi8( hmask_32,
b + ( ( rs_b * ( k_full_pieces + 0 ) ) / incr_adj_factor ) );
CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \
sign_comp_32, signed_upscale);
b0_32 = _mm256_setzero_si256();
r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 );
r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 );
a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 );
b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 );
CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \
even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32);
_mm256_storeu_epi64( pack_b_buffer +
( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_32 );
// Last 16 columns.
h_a0_32 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + 32 ) / 2 ) );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \
sign_comp_16, signed_upscale);
b0_16 = _mm_setzero_si128();
r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 );
r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 );
CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \
even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16);
_mm_storeu_epi64( pack_b_buffer +
( ( ( k_full_pieces * NR ) + NR_32x2 ) / incr_adj_factor ), s4_out_16 );
}
}
void packb_nr32_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t KC
)
{
const dim_t NR = 32;
dim_t k_full_pieces_blks = KC / 2;
dim_t k_full_pieces = k_full_pieces_blks * 2;
dim_t k_partial_pieces = KC % 2;
bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE;
bool signed_upscale = TRUE;
const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments.
// Used for permuting the mm512i elements for use in dpbf16_ps instruction.
__m256i selector1_32 = _mm256_setr_epi64x( 0x0, 0x1, 0x4, 0x5 );
__m256i selector1_1_32 = _mm256_setr_epi64x( 0x2, 0x3, 0x6, 0x7 );
// Selectors for int4 -> int8 conversion.
__m256i shift_idx_32;
MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32);
__m256i sign_comp_32 = _mm256_set1_epi8( 0x08 );
__mmask16 hmask_32 = _cvtu32_mask16( 0x0000FFFF ); //16 bytes or 32 int4.
__mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4.
CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32);
__m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_32 );
// Selectors for int8 -> int4 conversion.
CREATE_CVT_INT8_INT4_PERM_IDX_32ELEM_2_YMM_REG(even_idx_arr_32);
__m256i even_perm_idx_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ),
even_idx_arr_32 );
__m256i all_1s_32 = _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ),
0x01 );
__m256i odd_perm_idx_32 = _mm256_add_epi8( even_perm_idx_32, all_1s_32 );
__m256i clear_hi_bits_32 =
_mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), 0x0F );
__m128i h_a0_32;
__m128i h_b0_32;
__m128i h_b0_32_l4bit;
__m256i a0_32;
__m256i b0_32;
__m256i r_lo_32;
__m256i r_hi_32;
__m256i s4_out_32;
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
{
h_a0_32 = _mm_maskz_loadu_epi8( hmask_32,
b + ( ( rs_b * ( kr + 0 ) ) / incr_adj_factor ) );
CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \
sign_comp_32, signed_upscale);
if ( is_odd_stride == FALSE )
{
h_b0_32 = _mm_maskz_loadu_epi8( hmask_32,
b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) );
CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \
sign_comp_32, signed_upscale);
}
else
{
h_b0_32 = _mm_maskz_loadu_epi8( hmask_32,
b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) );
// Only load the last byte/ 16th byte.
h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32,
b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) + 1 );
CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, \
b0_32, shift_idx_32, conv_shift_32, sign_comp_32, \
signed_upscale);
}
// Restructuring at int8 level.
// First 32 columns.
r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 );
r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 );
a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 );
b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 );
CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \
even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32);
_mm256_storeu_epi64( pack_b_buffer +
( ( kr * NR ) / incr_adj_factor ), s4_out_32 );
}
// Handle k remainder.
if( k_partial_pieces > 0)
{
h_a0_32 = _mm_maskz_loadu_epi8( hmask_32,
b + ( ( rs_b * ( k_full_pieces + 0 ) ) / incr_adj_factor ) );
CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \
sign_comp_32, signed_upscale);
b0_32 = _mm256_setzero_si256();
r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 );
r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 );
a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 );
b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 );
CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \
even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32);
_mm256_storeu_epi64( pack_b_buffer +
( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_32 );
}
}
void packb_nr16_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t KC
)
{
const dim_t NR = 16;
dim_t k_full_pieces_blks = KC / 2;
dim_t k_full_pieces = k_full_pieces_blks * 2;
dim_t k_partial_pieces = KC % 2;
bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE;
bool signed_upscale = TRUE;
const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments.
// Selectors for int4 -> int8 conversion.
__m128i shift_idx_16;
MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16);
__m128i sign_comp_16 = _mm_set1_epi8( 0x08 );
__mmask16 hmask_16 = _cvtu32_mask16( 0x000000FF ); //8 bytes or 16 int4.
__mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4.
CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16);
__m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_16 );
// Selectors for int8 -> int4 conversion.
CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(even_idx_arr_16);
__m128i even_perm_idx_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ),
even_idx_arr_16 );
__m128i all_1s_16 = _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ),
0x01 );
__m128i odd_perm_idx_16 = _mm_add_epi8( even_perm_idx_16, all_1s_16 );
__m128i clear_hi_bits_16 =
_mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), 0x0F );
__mmask16 sel_all_mask_16 = _cvtu32_mask16( 0xFFFF );
__m128i h_a0_16;
__m128i h_b0_16;
__m128i h_b0_16_l4bit;
__m128i a0_16;
__m128i b0_16;
__m128i r_lo_16;
__m128i r_hi_16;
__m128i s4_out_16;
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
{
h_a0_16 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( rs_b * ( kr + 0 ) ) / 2 ) );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \
sign_comp_16, signed_upscale);
if ( is_odd_stride == FALSE )
{
h_b0_16 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( rs_b * ( kr + 1 ) ) / 2 ) );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \
sign_comp_16, signed_upscale);
}
else
{
h_b0_16 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( rs_b * ( kr + 1 ) ) / 2 ) );
// Only load the last byte/ 8th byte.
h_b0_16_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16,
b + ( ( rs_b * ( kr + 1 ) ) / 2 ) + 1 );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, \
b0_16, shift_idx_16, conv_shift_16, sign_comp_16, \
signed_upscale);
}
r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 );
r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 );
CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \
even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16);
_mm_storeu_epi64( pack_b_buffer +
( ( kr * NR ) / incr_adj_factor ), s4_out_16 );
}
// Handle k remainder.
if( k_partial_pieces > 0)
{
h_a0_16 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \
sign_comp_16, signed_upscale);
b0_16 = _mm_setzero_si128();
r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 );
r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 );
CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \
even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16);
_mm_storeu_epi64( pack_b_buffer +
( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_16 );
}
}
void packb_nrlt16_bf16s4f32of32_row_major
(
int8_t* pack_b_buffer,
const int8_t* b,
const dim_t rs_b,
const dim_t KC,
const dim_t n0_partial_rem
)
{
const dim_t NR = 16;
dim_t k_full_pieces_blks = KC / 2;
dim_t k_full_pieces = k_full_pieces_blks * 2;
dim_t k_partial_pieces = KC % 2;
bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE;
bool signed_upscale = TRUE;
const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments.
// Selectors for int4 -> int8 conversion.
__m128i shift_idx_16;
MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16);
__m128i sign_comp_16 = _mm_set1_epi8( 0x08 );
// 16 int4 elems in 8 bytes, so adjusting the mask for nr < 16 by
// a factor of 2. In case of odd remainder, the last int4 element
// within the last byte (hi 4 bits) will be ingnored similar to
// padding bits.
__mmask16 hmask_16;
if ( is_odd_stride == FALSE )
{
hmask_16 = _cvtu32_mask16( 0x000000FF >>
( ( 16 - n0_partial_rem ) / 2 ) );
}
else
{
if ( ( n0_partial_rem % 2 ) == 0 )
{
// An interesting property here is that n0_partial_rem is
// guaranteed to be < 16. In that case the largest even n0
// rem would be 14, and the max number of bytes that will be
// loaded including the extra 4 bit at the beginning will
// only be 7 bytes out of 8. So in any case loading 1 more
// byte will bring the last int4 in the register, while not
// crossing the register boundaries.
hmask_16 = _cvtu32_mask16( 0x000000FF >>
( ( ( 16 - n0_partial_rem ) / 2 ) - 1 ) );
}
else
{
// If the n0 rem is odd, and if the starting position is an odd
// index, then the last odd element will also be loaded as part
// of loading the last byte (high 4 bits of last byte).
hmask_16 = _cvtu32_mask16( 0x000000FF >>
( ( 16 - n0_partial_rem ) / 2 ) );
}
}
CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16);
__m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_16 );
// Selectors for int8 -> int4 conversion.
CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(even_idx_arr_16);
__m128i even_perm_idx_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ),
even_idx_arr_16 );
__m128i all_1s_16 = _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ),
0x01 );
__m128i odd_perm_idx_16 = _mm_add_epi8( even_perm_idx_16, all_1s_16 );
__m128i clear_hi_bits_16 =
_mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), 0x0F );
__mmask16 sel_all_mask_16 = _cvtu32_mask16( 0xFFFF );
__m128i h_a0_16;
__m128i h_b0_16;
__m128i a0_16;
__m128i b0_16;
__m128i r_lo_16;
__m128i r_hi_16;
__m128i s4_out_16;
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
{
h_a0_16 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( rs_b * ( kr + 0 ) ) / 2 ) );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \
sign_comp_16, signed_upscale);
if ( is_odd_stride == FALSE )
{
h_b0_16 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( rs_b * ( kr + 1 ) ) / 2 ) );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \
sign_comp_16, signed_upscale);
}
else
{
h_b0_16 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( rs_b * ( kr + 1 ) ) / 2 ) );
// The last int4 elem is already loaded in the previous
// register. Details given in comments about hmask_16.
__m128i h_b0_16_l4bit = _mm_setzero_si128();
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, \
b0_16, shift_idx_16, conv_shift_16, sign_comp_16, \
signed_upscale);
}
r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 );
r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 );
CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \
even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16);
_mm_storeu_epi64( pack_b_buffer +
( ( kr * NR ) / incr_adj_factor ), s4_out_16 );
}
// Handle k remainder.
if( k_partial_pieces > 0)
{
h_a0_16 = _mm_maskz_loadu_epi8( hmask_16,
b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) );
CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \
sign_comp_16, signed_upscale);
b0_16 = _mm_setzero_si128();
r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 );
r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 );
CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \
even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16);
_mm_storeu_epi64( pack_b_buffer +
( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_16 );
}
}
#endif

View File

@@ -479,17 +479,17 @@ LPGEMV_M_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32)
// bf16 zero point value (scalar or vector).
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_set1_epi16( zp_mask,
zero_point0 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_set1_epi16( zp_mask,
*( ( bfloat16* )post_ops_list_temp->op_args1 ) ) );
zero_point1 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_set1_epi16( zp_mask,
zero_point1 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_set1_epi16( zp_mask,
*( ( bfloat16* )post_ops_list_temp->op_args1 ) ) );
zero_point2 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_set1_epi16( zp_mask,
zero_point2 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_set1_epi16( zp_mask,
*( ( bfloat16* )post_ops_list_temp->op_args1 ) ) );
zero_point3 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_set1_epi16( zp_mask,
zero_point3 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_set1_epi16( zp_mask,
*( ( bfloat16* )post_ops_list_temp->op_args1 ) ) );
}
@@ -518,20 +518,20 @@ LPGEMV_M_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32)
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_loadu_epi16( k1,
zero_point0 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_loadu_epi16( k1,
( ( bfloat16* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_j + ( 0 * 16 ) ) );
zero_point1 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_loadu_epi16( k2,
zero_point1 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_loadu_epi16( k2,
( ( bfloat16* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_j + ( 1 * 16 ) ) );
zero_point2 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_loadu_epi16( k3,
zero_point2 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_loadu_epi16( k3,
( ( bfloat16* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_j + ( 2 * 16 ) ) );
zero_point3 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_loadu_epi16( k4,
zero_point3 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_loadu_epi16( k4,
( ( bfloat16* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_j + ( 3 * 16 ) ) );
}

View File

@@ -672,8 +672,8 @@ LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32)
// bf16 zero point value (scalar or vector).
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_set1_epi16( zp_mask,
zero_point0 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_set1_epi16( zp_mask,
*( ( bfloat16* )post_ops_list_temp->op_args1 ) ) );
}
@@ -701,8 +701,8 @@ LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32)
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_cvtpbh_ps(
( __m256bh )_mm256_maskz_loadu_epi16( k2,
zero_point0 = CVT_BF16_F32_INT_SHIFT(
_mm256_maskz_loadu_epi16( k2,
( ( bfloat16* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 ) );
}

View File

@@ -0,0 +1,397 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef LPGEMM_INT4_CVT_UTILS_H
#define LPGEMM_INT4_CVT_UTILS_H
/* shift_idx:__m512i*/
#define MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx) \
/* Multi shift uses indices that corresponds to the bit starting positions
* of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12,
* 16, 20, 24, 28. */ \
shift_idx = _mm512_set1_epi64( 0x1C1814100C080400lu );
/* shift_idx:__m256i*/
#define MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx) \
/* Multi shift uses indices that corresponds to the bit starting positions
* of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12,
* 16, 20, 24, 28. */ \
shift_idx = _mm256_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \
0x1C1814100C080400lu );
/* shift_idx:__m128i*/
#define MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx) \
/* Multi shift uses indices that corresponds to the bit starting positions
* of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12,
* 16, 20, 24, 28. */ \
shift_idx = _mm_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \
0x1C1814100C080400lu );
/* input:__m256i, output: __m512i*/
#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx) \
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). Unsigned conversion is
* used so as to ensure the signed bit in int4 at MSB position of 4
* byte group is not modified. */ \
output = _mm512_multishift_epi64_epi8( shift_idx, \
_mm512_cvtepu32_epi64( input ) ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm512_set1_epi8( 0x0F ) );
/* input:__m256i, output: __m512i*/
#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, \
output, odd_shift_idx, conv_shift) \
/* Unsigned conversion is used so as to ensure the signed bit.
* in int4 at MSB position of 4 byte group is not modified. */ \
__m512i upscale_input = _mm512_cvtepu32_epi64( input_0 ); \
__m512i shift_input = _mm512_cvtepu32_epi64( input_1 ); \
\
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). */ \
output = _mm512_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \
\
/* Combine both the input registers, starting from elem[1] till elem[n-1]
* in output(without elem[0]), and first non zero element in shift_input.
* It is at this point that the first 4bit and last 4bit elements, the 2
* that were loaded extra due to byte level access are discarded. */ \
output = _mm512_permutex2var_epi8( output, conv_shift, shift_input ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm512_set1_epi8( 0x0F ) );
/* input:__m128i, output: __m256i*/
#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx) \
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). Unsigned conversion is
* used so as to ensure the signed bit in int4 at MSB position of 4
* byte group is not modified. */ \
output = _mm256_multishift_epi64_epi8( shift_idx, \
_mm256_cvtepu32_epi64( input ) ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm256_set1_epi8( 0x0F ) );
/* input:__m128i, output: __m256i*/
#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, \
output, odd_shift_idx, conv_shift) \
/* Unsigned conversion is used so as to ensure the signed bit.
* in int4 at MSB position of 4 byte group is not modified. */ \
__m256i upscale_input = _mm256_cvtepu32_epi64( input_0 ); \
__m256i shift_input = _mm256_cvtepu32_epi64( input_1 ); \
\
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). */ \
output = _mm256_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \
\
/* Combine both the input registers, starting from elem[1] till elem[n-1]
* in output(without elem[0]), and first non zero element in shift_input.
* It is at this point that the first 4bit and last 4bit elements, the 2
* that were loaded extra due to byte level access are discarded. */ \
output = _mm256_permutex2var_epi8( output, conv_shift, shift_input ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm256_set1_epi8( 0x0F ) );
/* input:int64_t, output: __m128i*/
#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx) \
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). Unsigned conversion is
* used so as to ensure the signed bit in int4 at MSB position of 4
* byte group is not modified. */ \
output = _mm_multishift_epi64_epi8( shift_idx, \
_mm_cvtepu32_epi64( input ) ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm_set1_epi8( 0x0F ) );
/* input:int64_t, output:__m128i*/
#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, \
output, odd_shift_idx, conv_shift) \
/* Unsigned conversion is used so as to ensure the signed bit.
* in int4 at MSB position of 4 byte group is not modified. */ \
input_0 = _mm_cvtepu32_epi64( input_0 ); \
input_1 = _mm_cvtepu32_epi64( input_1 ); \
\
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). */ \
output = _mm_multishift_epi64_epi8( odd_shift_idx, input_0 ); \
\
/* Combine both the input registers, starting from elem[1] till elem[n-1]
* in output(without elem[0]), and first non zero element in shift_input.
* It is at this point that the first 4bit and last 4bit elements, the 2
* that were loaded extra due to byte level access are discarded. */ \
output = _mm_permutex2var_epi8( output, conv_shift, input_1 ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm_set1_epi8( 0x0F ) );
#define SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp) \
/* Comparison of signed bit in int4 and appending sign bits. */ \
/* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit
* is 1) to 1 and rest every other bits to 0. */ \
__m512i hi_bits_512 = _mm512_and_epi32( output, sign_comp ); \
\
/* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit
* is 0) to 1 and rest every other bits to 0. */ \
hi_bits_512 = _mm512_xor_epi32( hi_bits_512, sign_comp ); \
\
/* Set the sign extension bits on an int8_t size basis, this will then be
* OR with output to get the signed outputs. */ \
hi_bits_512 = _mm512_add_epi8( hi_bits_512, _mm512_set1_epi8( 0xF8 ) ); \
\
output = _mm512_or_epi32( output, hi_bits_512 );
#define SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp) \
/* Comparison of signed bit in int4 and appending sign bits. */ \
/* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit
* is 1) to 1 and rest every other bits to 0. */ \
__m256i hi_bits_256 = _mm256_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\
output, sign_comp ); \
\
/* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit
* is 0) to 1 and rest every other bits to 0. */ \
hi_bits_256 = _mm256_xor_epi32( hi_bits_256, sign_comp ); \
\
/* Set the sign extension bits on an int8_t size basis, this will then be
* OR with output to get the signed outputs. */ \
hi_bits_256 = _mm256_add_epi8( hi_bits_256, _mm256_set1_epi8( 0xF8 ) ); \
\
output = _mm256_or_epi32( output, hi_bits_256 );
#define SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp) \
/* Comparison of signed bit in int4 and appending sign bits. */ \
/* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit
* is 1) to 1 and rest every other bits to 0. */ \
__m128i hi_bits_128 = _mm_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\
output, sign_comp ); \
\
/* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit
* is 0) to 1 and rest every other bits to 0. */ \
hi_bits_128 = _mm_xor_epi32( hi_bits_128, sign_comp ); \
\
/* Set the sign extension bits on an int8_t size basis, this will then be
* OR with output to get the signed outputs. */ \
hi_bits_128 = _mm_add_epi8( hi_bits_128, _mm_set1_epi8( 0xF8 ) ); \
\
output = _mm_or_epi32( output, hi_bits_128 );
/* input:__m256i, output: __m512i*/
#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \
} \
} while (0);
#define CREATE_CVT_INT4_INT8_PERM_IDX_64ELEM_ODD_LD(var_name) \
const int64_t var_name[8] = { \
0x0807060504030201, 0x100F0E0D0C0B0A09, \
0X1817161514131211, 0X201F1E1D1C1B1A19, \
0X2827262524232221, 0X302F2E2D2C2B2A29, \
0X3837363534333231, 0X7B3F3E3D3C3B3A39 };
/* input:__m256i, output: __m512i*/
#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \
} \
} while (0);
/* input:__m128i, output: __m256i*/
#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \
} \
} while (0);
#define CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(var_name) \
const int64_t var_name[4] = { \
0x0807060504030201, 0x100F0E0D0C0B0A09, \
0X1817161514131211, 0X3B1F1E1D1C1B1A19 };
/* input:__m128i, output: __m256i*/
#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \
} \
} while (0);
/* input:int64_t, output: __m128i*/
#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \
} \
} while (0);
#define CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(var_name) \
const int64_t var_name[2] = { \
0x0807060504030201, 0x1B0F0E0D0C0B0A09 };
/* input:int64_t, output: __m128i*/
#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \
} \
} while (0);
#define CREATE_CVT_INT8_INT4_PERM_IDX_64ELEM_2_ZMM_REG(var_name) \
int8_t var_name[64] __attribute__((aligned(64))) = \
{0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, \
0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E, \
0x20, 0x22, 0x24, 0x26, 0x28, 0x2A, 0x2C, 0x2E, \
0x30, 0x32, 0x34, 0x36, 0x38, 0x3A, 0x3C, 0x3E, \
0x40, 0x42, 0x44, 0x46, 0x48, 0x4A, 0x4C, 0x4E, \
0x50, 0x52, 0x54, 0x56, 0x58, 0x5A, 0x5C, 0x5E, \
0x60, 0x62, 0x64, 0x66, 0x68, 0x6A, 0x6C, 0x6E, \
0x70, 0x72, 0x74, 0x76, 0x78, 0x7A, 0x7C, 0x7E};
/* Conversion from int8 to int4. First split the elements in __m512i
* register at even indices and odd indices into two separate __m256i
* even and odd registers. Then shift the elements in odd by 4 to the
* left and OR with even register. */
/* input_*:__m512i, output: __m512i */
#define CVT_INT8_INT4_64ELEM_2_ZMM_REG(input_0, input_1, output, \
even_perm_idx, odd_perm_idx, clear_hi_bits) \
do { \
output = _mm512_permutex2var_epi8( input_0, even_perm_idx, input_1 ); \
__m512i odd_out = _mm512_permutex2var_epi8( input_0, \
odd_perm_idx, input_1 ); \
\
/* Ensure the hi 4 bits are cleared. */ \
output = _mm512_and_epi32( output, clear_hi_bits ); \
\
__m256i odd1_256 = _mm512_extracti64x4_epi64( odd_out, 0x0 ); \
__m256i odd2_256 = _mm512_extracti64x4_epi64( odd_out, 0x1 ); \
\
/* Shift the elemts in odd register by 4 to the left. */ \
odd1_256 = _mm512_cvtepi16_epi8( \
_mm512_slli_epi16( _mm512_cvtepu8_epi16( odd1_256 ), 0x4 ) ); \
odd2_256 = _mm512_cvtepi16_epi8( \
_mm512_slli_epi16( _mm512_cvtepu8_epi16( odd2_256 ), 0x4 ) ); \
\
odd_out = _mm512_castsi256_si512( odd1_256 ); \
odd_out = _mm512_inserti64x4( odd_out, odd2_256, 0x01 ); \
\
output = _mm512_or_epi32( output, odd_out ); \
} while (0);
#define CREATE_CVT_INT8_INT4_PERM_IDX_32ELEM_2_YMM_REG(var_name) \
int8_t var_name[32] __attribute__((aligned(64))) = \
{0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, \
0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E, \
0x20, 0x22, 0x24, 0x26, 0x28, 0x2A, 0x2C, 0x2E, \
0x30, 0x32, 0x34, 0x36, 0x38, 0x3A, 0x3C, 0x3E};
/* input_*:__m256i, output: __m256i */
#define CVT_INT8_INT4_32ELEM_2_YMM_REG(input_0, input_1, output, \
even_perm_idx, odd_perm_idx, clear_hi_bits) \
do { \
output = _mm256_permutex2var_epi8( input_0, even_perm_idx, input_1 ); \
__m256i odd_out = _mm256_permutex2var_epi8( input_0, \
odd_perm_idx, input_1 ); \
\
/* Ensure the hi 4 bits are cleared. */ \
output = _mm256_maskz_and_epi32( _cvtu32_mask8( 0xFF ), \
output, clear_hi_bits ); \
\
/* Shift the elemts in odd register by 4 to the left. */ \
odd_out = _mm512_cvtepi16_epi8( \
_mm512_slli_epi16( _mm512_cvtepu8_epi16( odd_out ), 0x4 ) ); \
\
output = _mm256_or_epi32( output, odd_out ); \
} while (0);
#define CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(var_name) \
int8_t var_name[16] __attribute__((aligned(64))) = \
{0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, \
0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E};
/* input_*:__m128i, output: __m128i */
#define CVT_INT8_INT4_16ELEM_2_XMM_REG(input_0, input_1, output, \
even_perm_idx, odd_perm_idx, clear_hi_bits) \
do { \
output = _mm_permutex2var_epi8( input_0, even_perm_idx, input_1 ); \
__m128i odd_out = _mm_permutex2var_epi8( input_0, \
odd_perm_idx, input_1 ); \
\
/* Ensure the hi 4 bits are cleared. */ \
output = _mm_maskz_and_epi32( _cvtu32_mask8( 0xFF ), \
output, clear_hi_bits ); \
\
/* Shift the elemts in odd register by 4 to the left. */ \
__mmask16 sel_all_mask = _cvtu32_mask16( 0xFFFF ); \
odd_out = _mm256_maskz_cvtepi16_epi8( sel_all_mask, \
_mm256_maskz_slli_epi16( sel_all_mask, \
_mm256_maskz_cvtepu8_epi16( sel_all_mask, odd_out ), 0x4 ) ); \
\
output = _mm_or_epi32( output, odd_out ); \
} while (0);
#endif //LPGEMM_INT4_CVT_UTILS_H

View File

@@ -231,11 +231,7 @@ void packb_nr64_u8s8s32o32_row_major
__mmask32 hmask = _cvtu32_mask32(0xFFFFFFFF); // 32 bytes or 64 int4.
__mmask32 hmask_odd = _cvtu32_mask32(0x80000000); // Last 1 int4.
const int64_t conv_shift_arr[8] = {
0x0807060504030201, 0x100F0E0D0C0B0A09, \
0X1817161514131211, 0X201F1E1D1C1B1A19, \
0X2827262524232221, 0X302F2E2D2C2B2A29, \
0X3837363534333231, 0X7B3F3E3D3C3B3A39 };
CREATE_CVT_INT4_INT8_PERM_IDX_64ELEM_ODD_LD(conv_shift_arr);
__m512i conv_shift = _mm512_loadu_epi64(conv_shift_arr);
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
@@ -569,6 +565,7 @@ void packb_nr48_u8s8s32o32_row_major
__m128i a01_16;
__m128i c01_16;
// First 32 int4 elements selectors.
__m256i shift_idx_32;
MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32);
@@ -577,12 +574,11 @@ void packb_nr48_u8s8s32o32_row_major
__mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4.
const int64_t conv_shift_arr_32[4] = {
0x0807060504030201, 0x100F0E0D0C0B0A09, \
0X1817161514131211, 0X3B1F1E1D1C1B1A19 };
CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32);
__m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_32 );
// Next 16 int4 elements selectors.
__m128i shift_idx_16;
MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16);
@@ -591,8 +587,7 @@ void packb_nr48_u8s8s32o32_row_major
__mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4.
const int64_t conv_shift_arr_16[2] = {
0x0807060504030201, 0x1B0F0E0D0C0B0A09 };
CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16);
__m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_16 );
@@ -1027,9 +1022,7 @@ void packb_nr32_u8s8s32o32_row_major
__mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4.
const int64_t conv_shift_arr_32[4] = {
0x0807060504030201, 0x100F0E0D0C0B0A09, \
0X1817161514131211, 0X3B1F1E1D1C1B1A19 };
CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32);
__m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_32 );
@@ -1293,8 +1286,7 @@ void packb_nr16_u8s8s32o32_row_major
__mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4.
const int64_t conv_shift_arr_16[2] = {
0x0807060504030201, 0x1B0F0E0D0C0B0A09 };
CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16);
__m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_16 );
@@ -1547,7 +1539,6 @@ void packb_nrlt16_u8s8s32o32_row_major
}
else
{
if ( ( n0_partial_rem % 2 ) == 0 )
{
// An interesting property here is that n0_partial_rem is
@@ -1570,8 +1561,7 @@ void packb_nrlt16_u8s8s32o32_row_major
}
}
const int64_t conv_shift_arr_16[2] = {
0x0807060504030201, 0x1B0F0E0D0C0B0A09 };
CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16);
__m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ),
conv_shift_arr_16 );

View File

@@ -35,255 +35,7 @@
#ifndef LPGEMM_S32_PACK_MACROS_H
#define LPGEMM_S32_PACK_MACROS_H
/* shift_idx:__m512i*/
#define MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx) \
/* Multi shift uses indices that corresponds to the bit starting positions
* of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12,
* 16, 20, 24, 28. */ \
shift_idx = _mm512_set1_epi64( 0x1C1814100C080400lu );
/* shift_idx:__m256i*/
#define MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx) \
/* Multi shift uses indices that corresponds to the bit starting positions
* of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12,
* 16, 20, 24, 28. */ \
shift_idx = _mm256_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \
0x1C1814100C080400lu );
/* shift_idx:__m128i*/
#define MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx) \
/* Multi shift uses indices that corresponds to the bit starting positions
* of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12,
* 16, 20, 24, 28. */ \
shift_idx = _mm_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \
0x1C1814100C080400lu );
/* input:__m256i, output: __m512i*/
#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx) \
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). Unsigned conversion is
* used so as to ensure the signed bit in int4 at MSB position of 4
* byte group is not modified. */ \
output = _mm512_multishift_epi64_epi8( shift_idx, \
_mm512_cvtepu32_epi64( input ) ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm512_set1_epi8( 0x0F ) );
/* input:__m256i, output: __m512i*/
#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, \
output, odd_shift_idx, conv_shift) \
/* Unsigned conversion is used so as to ensure the signed bit.
* in int4 at MSB position of 4 byte group is not modified. */ \
__m512i upscale_input = _mm512_cvtepu32_epi64( input_0 ); \
__m512i shift_input = _mm512_cvtepu32_epi64( input_1 ); \
\
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). */ \
output = _mm512_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \
\
/* Combine both the input registers, starting from elem[1] till elem[n-1]
* in output(without elem[0]), and first non zero element in shift_input.
* It is at this point that the first 4bit and last 4bit elements, the 2
* that were loaded extra due to byte level access are discarded. */ \
output = _mm512_permutex2var_epi8( output, conv_shift, shift_input ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm512_set1_epi8( 0x0F ) );
/* input:__m128i, output: __m256i*/
#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx) \
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). Unsigned conversion is
* used so as to ensure the signed bit in int4 at MSB position of 4
* byte group is not modified. */ \
output = _mm256_multishift_epi64_epi8( shift_idx, \
_mm256_cvtepu32_epi64( input ) ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm256_set1_epi8( 0x0F ) );
/* input:__m128i, output: __m256i*/
#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, \
output, odd_shift_idx, conv_shift) \
/* Unsigned conversion is used so as to ensure the signed bit.
* in int4 at MSB position of 4 byte group is not modified. */ \
__m256i upscale_input = _mm256_cvtepu32_epi64( input_0 ); \
__m256i shift_input = _mm256_cvtepu32_epi64( input_1 ); \
\
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). */ \
output = _mm256_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \
\
/* Combine both the input registers, starting from elem[1] till elem[n-1]
* in output(without elem[0]), and first non zero element in shift_input.
* It is at this point that the first 4bit and last 4bit elements, the 2
* that were loaded extra due to byte level access are discarded. */ \
output = _mm256_permutex2var_epi8( output, conv_shift, shift_input ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm256_set1_epi8( 0x0F ) );
/* input:int64_t, output: __m128i*/
#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx) \
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). Unsigned conversion is
* used so as to ensure the signed bit in int4 at MSB position of 4
* byte group is not modified. */ \
output = _mm_multishift_epi64_epi8( shift_idx, \
_mm_cvtepu32_epi64( input ) ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm_set1_epi8( 0x0F ) );
/* input:int64_t, output:__m128i*/
#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, \
output, odd_shift_idx, conv_shift) \
/* Unsigned conversion is used so as to ensure the signed bit.
* in int4 at MSB position of 4 byte group is not modified. */ \
input_0 = _mm_cvtepu32_epi64( input_0 ); \
input_1 = _mm_cvtepu32_epi64( input_1 ); \
\
/* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit
* /8 bytes (containing 8 int8 elements). */ \
output = _mm_multishift_epi64_epi8( odd_shift_idx, input_0 ); \
\
/* Combine both the input registers, starting from elem[1] till elem[n-1]
* in output(without elem[0]), and first non zero element in shift_input.
* It is at this point that the first 4bit and last 4bit elements, the 2
* that were loaded extra due to byte level access are discarded. */ \
output = _mm_permutex2var_epi8( output, conv_shift, input_1 ); \
\
/* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \
output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \
_mm_set1_epi8( 0x0F ) );
#define SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp) \
/* Comparison of signed bit in int4 and appending sign bits. */ \
/* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit
* is 1) to 1 and rest every other bits to 0. */ \
__m512i hi_bits_512 = _mm512_and_epi32( output, sign_comp ); \
\
/* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit
* is 0) to 1 and rest every other bits to 0. */ \
hi_bits_512 = _mm512_xor_epi32( hi_bits_512, sign_comp ); \
\
/* Set the sign extension bits on an int8_t size basis, this will then be
* OR with output to get the signed outputs. */ \
hi_bits_512 = _mm512_add_epi8( hi_bits_512, _mm512_set1_epi8( 0xF8 ) ); \
\
output = _mm512_or_epi32( output, hi_bits_512 );
#define SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp) \
/* Comparison of signed bit in int4 and appending sign bits. */ \
/* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit
* is 1) to 1 and rest every other bits to 0. */ \
__m256i hi_bits_256 = _mm256_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\
output, sign_comp ); \
\
/* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit
* is 0) to 1 and rest every other bits to 0. */ \
hi_bits_256 = _mm256_xor_epi32( hi_bits_256, sign_comp ); \
\
/* Set the sign extension bits on an int8_t size basis, this will then be
* OR with output to get the signed outputs. */ \
hi_bits_256 = _mm256_add_epi8( hi_bits_256, _mm256_set1_epi8( 0xF8 ) ); \
\
output = _mm256_or_epi32( output, hi_bits_256 );
#define SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp) \
/* Comparison of signed bit in int4 and appending sign bits. */ \
/* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit
* is 1) to 1 and rest every other bits to 0. */ \
__m128i hi_bits_128 = _mm_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\
output, sign_comp ); \
\
/* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit
* is 0) to 1 and rest every other bits to 0. */ \
hi_bits_128 = _mm_xor_epi32( hi_bits_128, sign_comp ); \
\
/* Set the sign extension bits on an int8_t size basis, this will then be
* OR with output to get the signed outputs. */ \
hi_bits_128 = _mm_add_epi8( hi_bits_128, _mm_set1_epi8( 0xF8 ) ); \
\
output = _mm_or_epi32( output, hi_bits_128 );
/* input:__m256i, output: __m512i*/
#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \
} \
} while (0);
/* input:__m256i, output: __m512i*/
#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \
} \
} while (0);
/* input:__m128i, output: __m256i*/
#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \
} \
} while (0);
/* input:__m128i, output: __m256i*/
#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \
} \
} while (0);
/* input:int64_t, output: __m128i*/
#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \
} \
} while (0);
/* input:int64_t, output: __m128i*/
#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift, sign_comp, signed_scale) \
do { \
UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \
odd_shift_idx, conv_shift); \
\
if ( signed_scale == TRUE ) \
{ \
SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \
} \
} while (0);
#include "../int4_utils_avx512.h"
#define LOAD_16_COLS_AVX512 \
a_reg[0] = _mm512_loadu_si512(b + (ldb * (jr + 0)) + kr); \