diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h index 1ceb83318..9acecc5eb 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -71,6 +71,17 @@ void packb_nr64_bf16bf16f32of32 dim_t* cs_p ); +void packb_nr64_bf16s4f32of32 + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ); void packa_mr16_bf16bf16f32of32 ( diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_s4_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_s4_amd512vnni.c new file mode 100644 index 000000000..94b3080a9 --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_s4_amd512vnni.c @@ -0,0 +1,887 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "../int4_utils_avx512.h" + +void packb_nr64_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ); + +void packb_nr48_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ); + +void packb_nr32_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ); + +void packb_nr16_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ); + +void packb_nrlt16_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC, + const dim_t n0_partial_rem + ); + +void packb_nr64_bf16s4f32of32 + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + const dim_t NC, + const dim_t KC, + 
dim_t* rs_p, + dim_t* cs_p + ) +{ + if (cs_b == 1) + { + packb_nr64_bf16s4f32of32_row_major(pack_b_buffer, + b, rs_b, NC, KC, rs_p, cs_p); + } + else + { + /* NOTE(review): this path only prints a message and returns; *rs_p and *cs_p are not written here, so callers should not rely on them after an unsupported (cs_b != 1) request. */ + bli_print_msg("Only row major supported for int4 packing.", + __FILE__, __LINE__); + return; + } +} + +/* Row-major int4 B packing in panels of NR = 64 columns. Two consecutive k rows (kr, kr + 1) are widened to int8, byte-interleaved, narrowed back to int4 and stored per panel; when KC is odd the missing second row is filled with zeros. Columns left over after the full panels (NC % 64) are handed to the 48 / 32 / 16 / <16 column helpers. On return *rs_p = NR * 2 and *cs_p = NR / 2. */ +void packb_nr64_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ) +{ + dim_t NR = 64; + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + // KC when not multiple of 2 will have padding to make it multiple of 2 + // in packed buffer. + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 2 - k_partial_pieces ); + } + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + // Selectors for int4 -> int8 conversion. + __m512i shift_idx_64; + MULTISHIFT_32BIT_8_INT4_IDX_64ELEM( shift_idx_64 ); + + __m512i sign_comp = _mm512_set1_epi8( 0x08 ); + __mmask32 hmask = _cvtu32_mask32(0xFFFFFFFF); // 32 bytes or 64 int4. + __mmask32 hmask_odd = _cvtu32_mask32(0x80000000); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_64ELEM_ODD_LD(conv_shift_arr); + __m512i conv_shift = _mm512_loadu_epi64(conv_shift_arr); + + // Selectors for int8 -> int4 conversion. 
+ CREATE_CVT_INT8_INT4_PERM_IDX_64ELEM_2_ZMM_REG(even_idx_arr) + __m512i even_perm_idx = _mm512_loadu_si512( even_idx_arr ); + __m512i all_1s = _mm512_maskz_set1_epi8( _cvtu64_mask64( 0xFFFFFFFFFFFFFFFF ), 0x01 ); + __m512i odd_perm_idx = _mm512_add_epi8( even_perm_idx, all_1s ); + __m512i clear_hi_bits = _mm512_maskz_set1_epi8( _cvtu64_mask64( 0xFFFFFFFFFFFFFFFF ), 0x0F ); + + __m256i h_a0; + __m256i h_b0; + __m256i h_b0_l4bit; + + __m512i a0; + __m512i b0; + __m512i r_lo; + __m512i r_hi; + __m512i s4_out; + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Int4 array has to be accessed like byte array, but with + // half the elements traversed in the byte array. + h_a0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 0 ) ) + jc ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \ + sign_comp, signed_upscale); + + // If the stride, i.e. rs_b is odd, then the stride increment + // (rs_b * ...)/2 will point at the byte of which the high 4 + // bits is our desired starting element. However since data + // access is at byte level, the low 4 bits of this byte will + // be wrongly included, and additionally the last int4 element + // won't be included either. Extra data movement done to + // account for the same. + // Since kr is a multiple of 2, only kr+1 will have the + // aforementioned issue. + if ( is_odd_stride == FALSE ) + { + h_b0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_b0, b0, shift_idx_64, \ + sign_comp, signed_upscale); + } + else + { + h_b0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / incr_adj_factor ) ); + // Only load the last byte/ 32nd byte. 
+ h_b0_l4bit = _mm256_maskz_loadu_epi8( hmask_odd, + b + ( ( ( rs_b * ( kr + 1 ) ) + jc ) / incr_adj_factor ) + 1 ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(h_b0, h_b0_l4bit, b0, \ + shift_idx_64, conv_shift, sign_comp, signed_upscale); + } + + // Restructuring at int8 level. + r_lo = _mm512_unpacklo_epi8( a0, b0 ); + r_hi = _mm512_unpackhi_epi8( a0, b0 ); + + a0 = _mm512_permutex2var_epi64( r_lo, selector1, r_hi ); + b0 = _mm512_permutex2var_epi64( r_lo, selector1_1, r_hi ); + + // To be converted to int4 for storing. + CVT_INT8_INT4_64ELEM_2_ZMM_REG(a0, b0, s4_out, \ + even_perm_idx, odd_perm_idx, clear_hi_bits); + + // Int4 array has to be accessed like byte array, but with + // half the elements traversed in the byte array. + _mm512_storeu_si512( pack_b_buffer + + ( ( ( jc * KC_updated ) + ( kr * NR ) ) / incr_adj_factor ), + s4_out ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + h_a0 = _mm256_maskz_loadu_epi8( hmask, + b + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + jc ) / + incr_adj_factor ) ); + CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(h_a0, a0, shift_idx_64, \ + sign_comp, signed_upscale); + + b0 = _mm512_setzero_si512(); + + // Restructuring at int8 level. + r_lo = _mm512_unpacklo_epi8( a0, b0 ); + r_hi = _mm512_unpackhi_epi8( a0, b0 ); + + a0 = _mm512_permutex2var_epi64( r_lo, selector1, r_hi ); + b0 = _mm512_permutex2var_epi64( r_lo, selector1_1, r_hi ); + + // To be converted to int4 for storing. + CVT_INT8_INT4_64ELEM_2_ZMM_REG(a0, b0, s4_out, \ + even_perm_idx, odd_perm_idx, clear_hi_bits); + + _mm512_storeu_si512( pack_b_buffer + + ( ( ( jc * KC_updated ) + ( k_full_pieces * NR ) ) / + incr_adj_factor ), s4_out ); + } + } + + if(n_partial_pieces > 0) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. 
+ dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packb_nr48_bf16s4f32of32_row_major + ( + ( pack_b_buffer + + ( ( n_full_pieces_loop_limit * KC_updated ) / + incr_adj_factor ) ), + ( b + ( n_full_pieces_loop_limit / incr_adj_factor ) ), + rs_b, KC + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr32_bf16s4f32of32_row_major + ( + ( pack_b_buffer + + ( ( n_full_pieces_loop_limit * KC_updated ) / + incr_adj_factor ) ), + ( b + ( n_full_pieces_loop_limit / incr_adj_factor ) ), + rs_b, KC + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr16_bf16s4f32of32_row_major + ( + ( pack_b_buffer + + ( ( n_full_pieces_loop_limit * KC_updated ) / + incr_adj_factor ) ), + ( b + ( n_full_pieces_loop_limit / incr_adj_factor ) ), + rs_b, KC + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_bf16s4f32of32_row_major + ( + ( pack_b_buffer + ( ( ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ) / incr_adj_factor ) ), + ( b + ( ( n_full_pieces_loop_limit + n0_partial_pack ) / + incr_adj_factor ) ), + rs_b, KC, n0_partial_rem + ); + } + } + *rs_p = NR * 2; + *cs_p = NR / 2; +} + +void packb_nr48_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ) +{ + const dim_t NR = 48; + const dim_t NR_32x2 = 64; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. 
+ __m256i selector1_32 = _mm256_setr_epi64x( 0x0, 0x1, 0x4, 0x5 ); + __m256i selector1_1_32 = _mm256_setr_epi64x( 0x2, 0x3, 0x6, 0x7 ); + + // Selectors for int4 -> int8 conversion. + // First 32 int4 elements selectors. + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + __mmask16 hmask_32 = _cvtu32_mask16( 0x0000FFFF ); //16 bytes or 32 int4. + __mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32); + __m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_32 ); + + // Next 16 int4 elements selectors. + __m128i shift_idx_16; + MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); + + __m128i sign_comp_16 = _mm_set1_epi8( 0x08 ); + __mmask16 hmask_16 = _cvtu32_mask16( 0x000000FF ); //8 bytes or 16 int4. + __mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); + __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_16 ); + + // Selectors for int8 -> int4 conversion. + // First 32 int8 elements selectors. + CREATE_CVT_INT8_INT4_PERM_IDX_32ELEM_2_YMM_REG(even_idx_arr_32); + __m256i even_perm_idx_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_32 ); + __m256i all_1s_32 = _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), + 0x01 ); + __m256i odd_perm_idx_32 = _mm256_add_epi8( even_perm_idx_32, all_1s_32 ); + __m256i clear_hi_bits_32 = + _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), 0x0F ); + + // Next 16 int8 elements selectors. 
+ CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(even_idx_arr_16); + __m128i even_perm_idx_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_16 ); + __m128i all_1s_16 = _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), + 0x01 ); + __m128i odd_perm_idx_16 = _mm_add_epi8( even_perm_idx_16, all_1s_16 ); + __m128i clear_hi_bits_16 = + _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), 0x0F ); + + __mmask16 sel_all_mask_16 = _cvtu32_mask16( 0xFFFF ); + + __m128i h_a0_32; + __m128i h_b0_32; + __m128i h_b0_32_l4bit; + __m128i a0_16; + __m128i b0_16; + __m128i r_lo_16; + __m128i r_hi_16; + __m128i s4_out_16; + __m256i a0_32; + __m256i b0_32; + __m256i r_lo_32; + __m256i r_hi_32; + __m256i s4_out_32; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // First 32 columns. + h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 0 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + // Last 16 columns. + h_a0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 0 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + // First 32 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + // Last 16 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_32, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + // First 32 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) ); + // Only load the last byte/ 16th byte. 
+ h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, \ + b0_32, shift_idx_32, conv_shift_32, sign_comp_32, \ + signed_upscale); + + // Last 16 columns. + h_b0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) ); + // Only load the last byte/ 8th byte. + h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( ( rs_b * ( kr + 1 ) ) + 32 ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, \ + b0_16, shift_idx_16, conv_shift_16, sign_comp_16, \ + signed_upscale); + } + + // Restructuring at int8 level. + // First 32 columns. + r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 ); + b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 ); + + CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \ + even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32); + + _mm256_storeu_epi64( pack_b_buffer + + ( ( kr * NR ) / incr_adj_factor ), s4_out_32 ); + + // Last 16 columns. + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( ( kr * NR ) + NR_32x2 ) / incr_adj_factor ), s4_out_16 ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + // First 32 columns. 
+ h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + b0_32 = _mm256_setzero_si256(); + + r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 ); + b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 ); + + CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \ + even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32); + + _mm256_storeu_epi64( pack_b_buffer + + ( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_32 ); + + // Last 16 columns. + h_a0_32 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( ( rs_b * ( k_full_pieces + 0 ) ) + 32 ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_32, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + b0_16 = _mm_setzero_si128(); + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( ( k_full_pieces * NR ) + NR_32x2 ) / incr_adj_factor ), s4_out_16 ); + } +} + +void packb_nr32_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ) +{ + const dim_t NR = 32; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. 
+ __m256i selector1_32 = _mm256_setr_epi64x( 0x0, 0x1, 0x4, 0x5 ); + __m256i selector1_1_32 = _mm256_setr_epi64x( 0x2, 0x3, 0x6, 0x7 ); + + // Selectors for int4 -> int8 conversion. + __m256i shift_idx_32; + MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); + + __m256i sign_comp_32 = _mm256_set1_epi8( 0x08 ); + __mmask16 hmask_32 = _cvtu32_mask16( 0x0000FFFF ); //16 bytes or 32 int4. + __mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32); + __m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_32 ); + + // Selectors for int8 -> int4 conversion. + CREATE_CVT_INT8_INT4_PERM_IDX_32ELEM_2_YMM_REG(even_idx_arr_32); + __m256i even_perm_idx_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_32 ); + __m256i all_1s_32 = _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), + 0x01 ); + __m256i odd_perm_idx_32 = _mm256_add_epi8( even_perm_idx_32, all_1s_32 ); + __m256i clear_hi_bits_32 = + _mm256_maskz_set1_epi8( _cvtu32_mask32( 0xFFFFFFFF ), 0x0F ); + + __m128i h_a0_32; + __m128i h_b0_32; + __m128i h_b0_32_l4bit; + __m256i a0_32; + __m256i b0_32; + __m256i r_lo_32; + __m256i r_hi_32; + __m256i s4_out_32; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 0 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_b0_32, b0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + } + else + { + h_b0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) ); + // Only load the last byte/ 16th byte. 
+ h_b0_32_l4bit = _mm_maskz_loadu_epi8( hmask_odd_32, + b + ( ( rs_b * ( kr + 1 ) ) / incr_adj_factor ) + 1 ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(h_b0_32, h_b0_32_l4bit, \ + b0_32, shift_idx_32, conv_shift_32, sign_comp_32, \ + signed_upscale); + } + + // Restructuring at int8 level. + // First 32 columns. + r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 ); + b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 ); + + CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \ + even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32); + + _mm256_storeu_epi64( pack_b_buffer + + ( ( kr * NR ) / incr_adj_factor ), s4_out_32 ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + h_a0_32 = _mm_maskz_loadu_epi8( hmask_32, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / incr_adj_factor ) ); + CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(h_a0_32, a0_32, shift_idx_32, \ + sign_comp_32, signed_upscale); + b0_32 = _mm256_setzero_si256(); + + r_lo_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + r_hi_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + a0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_32, r_hi_32 ); + b0_32 = _mm256_permutex2var_epi64( r_lo_32, selector1_1_32, r_hi_32 ); + + CVT_INT8_INT4_32ELEM_2_YMM_REG(a0_32, b0_32, s4_out_32, \ + even_perm_idx_32, odd_perm_idx_32, clear_hi_bits_32); + + _mm256_storeu_epi64( pack_b_buffer + + ( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_32 ); + } +} + +void packb_nr16_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC + ) +{ + const dim_t NR = 16; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. 
+ + // Selectors for int4 -> int8 conversion. + __m128i shift_idx_16; + MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); + + __m128i sign_comp_16 = _mm_set1_epi8( 0x08 ); + __mmask16 hmask_16 = _cvtu32_mask16( 0x000000FF ); //8 bytes or 16 int4. + __mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4. + + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); + __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_16 ); + + // Selectors for int8 -> int4 conversion. + CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(even_idx_arr_16); + __m128i even_perm_idx_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_16 ); + __m128i all_1s_16 = _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), + 0x01 ); + __m128i odd_perm_idx_16 = _mm_add_epi8( even_perm_idx_16, all_1s_16 ); + __m128i clear_hi_bits_16 = + _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), 0x0F ); + + __mmask16 sel_all_mask_16 = _cvtu32_mask16( 0xFFFF ); + + __m128i h_a0_16; + __m128i h_b0_16; + __m128i h_b0_16_l4bit; + __m128i a0_16; + __m128i b0_16; + __m128i r_lo_16; + __m128i r_hi_16; + __m128i s4_out_16; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + // Only load the last byte/ 8th byte. 
+ h_b0_16_l4bit = _mm_maskz_loadu_epi8( hmask_odd_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) + 1 ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, \ + b0_16, shift_idx_16, conv_shift_16, sign_comp_16, \ + signed_upscale); + } + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( kr * NR ) / incr_adj_factor ), s4_out_16 ); + } + // Handle k remainder. + if( k_partial_pieces > 0) + { + h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + b0_16 = _mm_setzero_si128(); + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_16 ); + } +} + +void packb_nrlt16_bf16s4f32of32_row_major + ( + int8_t* pack_b_buffer, + const int8_t* b, + const dim_t rs_b, + const dim_t KC, + const dim_t n0_partial_rem + ) +{ + const dim_t NR = 16; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + bool is_odd_stride = ( ( rs_b % 2 ) == 0 ) ? FALSE : TRUE; + bool signed_upscale = TRUE; + const dim_t incr_adj_factor = 2; // (Byte / 2) for int4 increments. + + // Selectors for int4 -> int8 conversion. + __m128i shift_idx_16; + MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); + + __m128i sign_comp_16 = _mm_set1_epi8( 0x08 ); + // 16 int4 elems in 8 bytes, so adjusting the mask for nr < 16 by + // a factor of 2. 
In case of odd remainder, the last int4 element + // within the last byte (hi 4 bits) will be ignored similar to + // padding bits. + __mmask16 hmask_16; + if ( is_odd_stride == FALSE ) + { + hmask_16 = _cvtu32_mask16( 0x000000FF >> + ( ( 16 - n0_partial_rem ) / 2 ) ); + } + else + { + if ( ( n0_partial_rem % 2 ) == 0 ) + { + // An interesting property here is that n0_partial_rem is + // guaranteed to be < 16. In that case the largest even n0 + // rem would be 14, and the max number of bytes that will be + // loaded including the extra 4 bits at the beginning will + // only be 7 bytes out of 8. So in any case loading 1 more + // byte will bring the last int4 in the register, while not + // crossing the register boundaries. + hmask_16 = _cvtu32_mask16( 0x000000FF >> + ( ( ( 16 - n0_partial_rem ) / 2 ) - 1 ) ); + } + else + { + // If the n0 rem is odd, and if the starting position is an odd + // index, then the last odd element will also be loaded as part + // of loading the last byte (high 4 bits of last byte). + hmask_16 = _cvtu32_mask16( 0x000000FF >> + ( ( 16 - n0_partial_rem ) / 2 ) ); + } + } + + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); + __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), + conv_shift_arr_16 ); + + // Selectors for int8 -> int4 conversion. 
+ CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(even_idx_arr_16); + __m128i even_perm_idx_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0xFF ), + even_idx_arr_16 ); + __m128i all_1s_16 = _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), + 0x01 ); + __m128i odd_perm_idx_16 = _mm_add_epi8( even_perm_idx_16, all_1s_16 ); + __m128i clear_hi_bits_16 = + _mm_maskz_set1_epi8( _cvtu32_mask16( 0xFFFF ), 0x0F ); + + __mmask16 sel_all_mask_16 = _cvtu32_mask16( 0xFFFF ); + + __m128i h_a0_16; + __m128i h_b0_16; + __m128i a0_16; + __m128i b0_16; + __m128i r_lo_16; + __m128i r_hi_16; + __m128i s4_out_16; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + + if ( is_odd_stride == FALSE ) + { + h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_b0_16, b0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + } + else + { + h_b0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( kr + 1 ) ) / 2 ) ); + // The last int4 elem is already loaded in the previous + // register. Details given in comments about hmask_16. + __m128i h_b0_16_l4bit = _mm_setzero_si128(); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(h_b0_16, h_b0_16_l4bit, \ + b0_16, shift_idx_16, conv_shift_16, sign_comp_16, \ + signed_upscale); + } + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( kr * NR ) / incr_adj_factor ), s4_out_16 ); + } + // Handle k remainder. 
+ if( k_partial_pieces > 0) + { + h_a0_16 = _mm_maskz_loadu_epi8( hmask_16, + b + ( ( rs_b * ( k_full_pieces + 0 ) ) / 2 ) ); + CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(h_a0_16, a0_16, shift_idx_16, \ + sign_comp_16, signed_upscale); + b0_16 = _mm_setzero_si128(); + + r_lo_16 = _mm_maskz_unpacklo_epi8( sel_all_mask_16, a0_16, b0_16 ); + r_hi_16 = _mm_maskz_unpackhi_epi8( sel_all_mask_16, a0_16, b0_16 ); + + CVT_INT8_INT4_16ELEM_2_XMM_REG(r_lo_16, r_hi_16, s4_out_16, \ + even_perm_idx_16, odd_perm_idx_16, clear_hi_bits_16); + + _mm_storeu_epi64( pack_b_buffer + + ( ( k_full_pieces * NR ) / incr_adj_factor ), s4_out_16 ); + } +} + +#endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_m_kernel_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_m_kernel_bf16_amd512vnni.c index d6d2185e7..44adf9e96 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_m_kernel_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_m_kernel_bf16_amd512vnni.c @@ -479,17 +479,17 @@ LPGEMV_M_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32) // bf16 zero point value (scalar or vector). 
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) { - zero_point0 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_set1_epi16( zp_mask, + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); - zero_point1 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_set1_epi16( zp_mask, + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); - zero_point2 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_set1_epi16( zp_mask, + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); - zero_point3 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_set1_epi16( zp_mask, + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); } @@ -518,20 +518,20 @@ LPGEMV_M_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32) if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) { - zero_point0 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_loadu_epi16( k1, + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k1, ( ( bfloat16* )post_ops_list_temp->op_args1 ) + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); - zero_point1 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_loadu_epi16( k2, + zero_point1 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, ( ( bfloat16* )post_ops_list_temp->op_args1 ) + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); - zero_point2 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_loadu_epi16( k3, + zero_point2 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k3, ( ( bfloat16* )post_ops_list_temp->op_args1 ) + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); - zero_point3 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_loadu_epi16( k4, + zero_point3 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k4, ( ( bfloat16* )post_ops_list_temp->op_args1 ) + post_ops_attr.post_op_c_j + ( 3 * 
16 ) ) ); } diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_n_kernel_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_n_kernel_bf16_amd512vnni.c index 081f957d1..4179eb181 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_n_kernel_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemv_n_kernel_bf16_amd512vnni.c @@ -672,8 +672,8 @@ LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32) // bf16 zero point value (scalar or vector). if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 ) { - zero_point0 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_set1_epi16( zp_mask, + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_set1_epi16( zp_mask, *( ( bfloat16* )post_ops_list_temp->op_args1 ) ) ); } @@ -701,8 +701,8 @@ LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32) if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 ) { - zero_point0 = _mm512_cvtpbh_ps( - ( __m256bh )_mm256_maskz_loadu_epi16( k2, + zero_point0 = CVT_BF16_F32_INT_SHIFT( + _mm256_maskz_loadu_epi16( k2, ( ( bfloat16* )post_ops_list_temp->op_args1 ) + post_ops_attr.post_op_c_i + 0 ) ); } diff --git a/kernels/zen4/lpgemm/int4_utils_avx512.h b/kernels/zen4/lpgemm/int4_utils_avx512.h new file mode 100644 index 000000000..5de056b8d --- /dev/null +++ b/kernels/zen4/lpgemm/int4_utils_avx512.h @@ -0,0 +1,397 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_INT4_CVT_UTILS_H +#define LPGEMM_INT4_CVT_UTILS_H + +/* shift_idx:__m512i*/ +#define MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx) \ + /* Multi shift uses indices that corresponds to the bit starting positions + * of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12, + * 16, 20, 24, 28. */ \ + shift_idx = _mm512_set1_epi64( 0x1C1814100C080400lu ); + +/* shift_idx:__m256i*/ +#define MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx) \ + /* Multi shift uses indices that corresponds to the bit starting positions + * of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12, + * 16, 20, 24, 28. */ \ + shift_idx = _mm256_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \ + 0x1C1814100C080400lu ); + +/* shift_idx:__m128i*/ +#define MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx) \ + /* Multi shift uses indices that corresponds to the bit starting positions + * of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12, + * 16, 20, 24, 28. 
*/ \ + shift_idx = _mm_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \ + 0x1C1814100C080400lu ); + +/* input:__m256i, output: __m512i*/ +#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx) \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). Unsigned conversion is + * used so as to ensure the signed bit in int4 at MSB position of 4 + * byte group is not modified. */ \ + output = _mm512_multishift_epi64_epi8( shift_idx, \ + _mm512_cvtepu32_epi64( input ) ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm512_set1_epi8( 0x0F ) ); + +/* input:__m256i, output: __m512i*/ +#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, \ + output, odd_shift_idx, conv_shift) \ + /* Unsigned conversion is used so as to ensure the signed bit. + * in int4 at MSB position of 4 byte group is not modified. */ \ + __m512i upscale_input = _mm512_cvtepu32_epi64( input_0 ); \ + __m512i shift_input = _mm512_cvtepu32_epi64( input_1 ); \ + \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). */ \ + output = _mm512_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \ + \ + /* Combine both the input registers, starting from elem[1] till elem[n-1] + * in output(without elem[0]), and first non zero element in shift_input. + * It is at this point that the first 4bit and last 4bit elements, the 2 + * that were loaded extra due to byte level access are discarded. */ \ + output = _mm512_permutex2var_epi8( output, conv_shift, shift_input ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. 
*/ \ + output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm512_set1_epi8( 0x0F ) ); + +/* input:__m128i, output: __m256i*/ +#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx) \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). Unsigned conversion is + * used so as to ensure the signed bit in int4 at MSB position of 4 + * byte group is not modified. */ \ + output = _mm256_multishift_epi64_epi8( shift_idx, \ + _mm256_cvtepu32_epi64( input ) ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm256_set1_epi8( 0x0F ) ); + +/* input:__m128i, output: __m256i*/ +#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, \ + output, odd_shift_idx, conv_shift) \ + /* Unsigned conversion is used so as to ensure the signed bit. + * in int4 at MSB position of 4 byte group is not modified. */ \ + __m256i upscale_input = _mm256_cvtepu32_epi64( input_0 ); \ + __m256i shift_input = _mm256_cvtepu32_epi64( input_1 ); \ + \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). */ \ + output = _mm256_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \ + \ + /* Combine both the input registers, starting from elem[1] till elem[n-1] + * in output(without elem[0]), and first non zero element in shift_input. + * It is at this point that the first 4bit and last 4bit elements, the 2 + * that were loaded extra due to byte level access are discarded. */ \ + output = _mm256_permutex2var_epi8( output, conv_shift, shift_input ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. 
*/ \ + output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm256_set1_epi8( 0x0F ) ); + +/* input:int64_t, output: __m128i*/ +#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx) \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). Unsigned conversion is + * used so as to ensure the signed bit in int4 at MSB position of 4 + * byte group is not modified. */ \ + output = _mm_multishift_epi64_epi8( shift_idx, \ + _mm_cvtepu32_epi64( input ) ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm_set1_epi8( 0x0F ) ); + +/* input:int64_t, output:__m128i*/ +#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, \ + output, odd_shift_idx, conv_shift) \ + /* Unsigned conversion is used so as to ensure the signed bit. + * in int4 at MSB position of 4 byte group is not modified. */ \ + input_0 = _mm_cvtepu32_epi64( input_0 ); \ + input_1 = _mm_cvtepu32_epi64( input_1 ); \ + \ + /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit + * /8 bytes (containing 8 int8 elements). */ \ + output = _mm_multishift_epi64_epi8( odd_shift_idx, input_0 ); \ + \ + /* Combine both the input registers, starting from elem[1] till elem[n-1] + * in output(without elem[0]), and first non zero element in shift_input. + * It is at this point that the first 4bit and last 4bit elements, the 2 + * that were loaded extra due to byte level access are discarded. */ \ + output = _mm_permutex2var_epi8( output, conv_shift, input_1 ); \ + \ + /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ + output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ + _mm_set1_epi8( 0x0F ) ); + +#define SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp) \ + /* Comparison of signed bit in int4 and appending sign bits. 
*/ \ + /* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit + * is 1) to 1 and rest every other bits to 0. */ \ + __m512i hi_bits_512 = _mm512_and_epi32( output, sign_comp ); \ + \ + /* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit + * is 0) to 1 and rest every other bits to 0. */ \ + hi_bits_512 = _mm512_xor_epi32( hi_bits_512, sign_comp ); \ + \ + /* Set the sign extension bits on an int8_t size basis, this will then be + * OR with output to get the signed outputs. */ \ + hi_bits_512 = _mm512_add_epi8( hi_bits_512, _mm512_set1_epi8( 0xF8 ) ); \ + \ + output = _mm512_or_epi32( output, hi_bits_512 ); + +#define SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp) \ + /* Comparison of signed bit in int4 and appending sign bits. */ \ + /* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit + * is 1) to 1 and rest every other bits to 0. */ \ + __m256i hi_bits_256 = _mm256_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\ + output, sign_comp ); \ + \ + /* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit + * is 0) to 1 and rest every other bits to 0. */ \ + hi_bits_256 = _mm256_xor_epi32( hi_bits_256, sign_comp ); \ + \ + /* Set the sign extension bits on an int8_t size basis, this will then be + * OR with output to get the signed outputs. */ \ + hi_bits_256 = _mm256_add_epi8( hi_bits_256, _mm256_set1_epi8( 0xF8 ) ); \ + \ + output = _mm256_or_epi32( output, hi_bits_256 ); + +#define SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp) \ + /* Comparison of signed bit in int4 and appending sign bits. */ \ + /* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit + * is 1) to 1 and rest every other bits to 0. */ \ + __m128i hi_bits_128 = _mm_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\ + output, sign_comp ); \ + \ + /* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit + * is 0) to 1 and rest every other bits to 0. 
*/ \ + hi_bits_128 = _mm_xor_epi32( hi_bits_128, sign_comp ); \ + \ + /* Set the sign extension bits on an int8_t size basis, this will then be + * OR with output to get the signed outputs. */ \ + hi_bits_128 = _mm_add_epi8( hi_bits_128, _mm_set1_epi8( 0xF8 ) ); \ + \ + output = _mm_or_epi32( output, hi_bits_128 ); + +/* input:__m256i, output: __m512i*/ +#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \ + } \ +} while (0); + +#define CREATE_CVT_INT4_INT8_PERM_IDX_64ELEM_ODD_LD(var_name) \ + const int64_t var_name[8] = { \ + 0x0807060504030201, 0x100F0E0D0C0B0A09, \ + 0X1817161514131211, 0X201F1E1D1C1B1A19, \ + 0X2827262524232221, 0X302F2E2D2C2B2A29, \ + 0X3837363534333231, 0X7B3F3E3D3C3B3A39 }; + +/* input:__m256i, output: __m512i*/ +#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \ + } \ +} while (0); + +/* input:__m128i, output: __m256i*/ +#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \ + } \ +} while (0); + +#define CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(var_name) \ + const int64_t var_name[4] = { \ + 0x0807060504030201, 0x100F0E0D0C0B0A09, \ + 0X1817161514131211, 0X3B1F1E1D1C1B1A19 }; + +/* input:__m128i, output: __m256i*/ +#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift, 
sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \ + } \ +} while (0); + +/* input:int64_t, output: __m128i*/ +#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \ + } \ +} while (0); + +#define CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(var_name) \ + const int64_t var_name[2] = { \ + 0x0807060504030201, 0x1B0F0E0D0C0B0A09 }; + +/* input:int64_t, output: __m128i*/ +#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift, sign_comp, signed_scale) \ +do { \ + UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ + odd_shift_idx, conv_shift); \ + \ + if ( signed_scale == TRUE ) \ + { \ + SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \ + } \ +} while (0); + +#define CREATE_CVT_INT8_INT4_PERM_IDX_64ELEM_2_ZMM_REG(var_name) \ + int8_t var_name[64] __attribute__((aligned(64))) = \ + {0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, \ + 0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E, \ + 0x20, 0x22, 0x24, 0x26, 0x28, 0x2A, 0x2C, 0x2E, \ + 0x30, 0x32, 0x34, 0x36, 0x38, 0x3A, 0x3C, 0x3E, \ + 0x40, 0x42, 0x44, 0x46, 0x48, 0x4A, 0x4C, 0x4E, \ + 0x50, 0x52, 0x54, 0x56, 0x58, 0x5A, 0x5C, 0x5E, \ + 0x60, 0x62, 0x64, 0x66, 0x68, 0x6A, 0x6C, 0x6E, \ + 0x70, 0x72, 0x74, 0x76, 0x78, 0x7A, 0x7C, 0x7E}; + +/* Conversion from int8 to int4. First split the elements in __m512i + * register at even indices and odd indices into two separate __m256i + * even and odd registers. Then shift the elements in odd by 4 to the + * left and OR with even register. 
*/ +/* input_*:__m512i, output: __m512i */ +#define CVT_INT8_INT4_64ELEM_2_ZMM_REG(input_0, input_1, output, \ + even_perm_idx, odd_perm_idx, clear_hi_bits) \ +do { \ + output = _mm512_permutex2var_epi8( input_0, even_perm_idx, input_1 ); \ + __m512i odd_out = _mm512_permutex2var_epi8( input_0, \ + odd_perm_idx, input_1 ); \ + \ + /* Ensure the hi 4 bits are cleared. */ \ + output = _mm512_and_epi32( output, clear_hi_bits ); \ + \ + __m256i odd1_256 = _mm512_extracti64x4_epi64( odd_out, 0x0 ); \ + __m256i odd2_256 = _mm512_extracti64x4_epi64( odd_out, 0x1 ); \ + \ + /* Shift the elemts in odd register by 4 to the left. */ \ + odd1_256 = _mm512_cvtepi16_epi8( \ + _mm512_slli_epi16( _mm512_cvtepu8_epi16( odd1_256 ), 0x4 ) ); \ + odd2_256 = _mm512_cvtepi16_epi8( \ + _mm512_slli_epi16( _mm512_cvtepu8_epi16( odd2_256 ), 0x4 ) ); \ + \ + odd_out = _mm512_castsi256_si512( odd1_256 ); \ + odd_out = _mm512_inserti64x4( odd_out, odd2_256, 0x01 ); \ + \ + output = _mm512_or_epi32( output, odd_out ); \ +} while (0); + +#define CREATE_CVT_INT8_INT4_PERM_IDX_32ELEM_2_YMM_REG(var_name) \ + int8_t var_name[32] __attribute__((aligned(64))) = \ + {0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, \ + 0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E, \ + 0x20, 0x22, 0x24, 0x26, 0x28, 0x2A, 0x2C, 0x2E, \ + 0x30, 0x32, 0x34, 0x36, 0x38, 0x3A, 0x3C, 0x3E}; + +/* input_*:__m256i, output: __m256i */ +#define CVT_INT8_INT4_32ELEM_2_YMM_REG(input_0, input_1, output, \ + even_perm_idx, odd_perm_idx, clear_hi_bits) \ +do { \ + output = _mm256_permutex2var_epi8( input_0, even_perm_idx, input_1 ); \ + __m256i odd_out = _mm256_permutex2var_epi8( input_0, \ + odd_perm_idx, input_1 ); \ + \ + /* Ensure the hi 4 bits are cleared. */ \ + output = _mm256_maskz_and_epi32( _cvtu32_mask8( 0xFF ), \ + output, clear_hi_bits ); \ + \ + /* Shift the elemts in odd register by 4 to the left. 
*/ \ + odd_out = _mm512_cvtepi16_epi8( \ + _mm512_slli_epi16( _mm512_cvtepu8_epi16( odd_out ), 0x4 ) ); \ + \ + output = _mm256_or_epi32( output, odd_out ); \ +} while (0); + +#define CREATE_CVT_INT8_INT4_PERM_IDX_16ELEM_2_XMM_REG(var_name) \ + int8_t var_name[16] __attribute__((aligned(64))) = \ + {0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, \ + 0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E}; + +/* input_*:__m128i, output: __m128i */ +#define CVT_INT8_INT4_16ELEM_2_XMM_REG(input_0, input_1, output, \ + even_perm_idx, odd_perm_idx, clear_hi_bits) \ +do { \ + output = _mm_permutex2var_epi8( input_0, even_perm_idx, input_1 ); \ + __m128i odd_out = _mm_permutex2var_epi8( input_0, \ + odd_perm_idx, input_1 ); \ + \ + /* Ensure the hi 4 bits are cleared. */ \ + output = _mm_maskz_and_epi32( _cvtu32_mask8( 0xFF ), \ + output, clear_hi_bits ); \ + \ + /* Shift the elemts in odd register by 4 to the left. */ \ + __mmask16 sel_all_mask = _cvtu32_mask16( 0xFFFF ); \ + odd_out = _mm256_maskz_cvtepi16_epi8( sel_all_mask, \ + _mm256_maskz_slli_epi16( sel_all_mask, \ + _mm256_maskz_cvtepu8_epi16( sel_all_mask, odd_out ), 0x4 ) ); \ + \ + output = _mm_or_epi32( output, odd_out ); \ +} while (0); + +#endif //LPGEMM_INT4_CVT_UTILS_H diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c index aa6a10926..0a87245c9 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c @@ -231,11 +231,7 @@ void packb_nr64_u8s8s32o32_row_major __mmask32 hmask = _cvtu32_mask32(0xFFFFFFFF); // 32 bytes or 64 int4. __mmask32 hmask_odd = _cvtu32_mask32(0x80000000); // Last 1 int4. 
- const int64_t conv_shift_arr[8] = { - 0x0807060504030201, 0x100F0E0D0C0B0A09, \ - 0X1817161514131211, 0X201F1E1D1C1B1A19, \ - 0X2827262524232221, 0X302F2E2D2C2B2A29, \ - 0X3837363534333231, 0X7B3F3E3D3C3B3A39 }; + CREATE_CVT_INT4_INT8_PERM_IDX_64ELEM_ODD_LD(conv_shift_arr); __m512i conv_shift = _mm512_loadu_epi64(conv_shift_arr); for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) @@ -569,6 +565,7 @@ void packb_nr48_u8s8s32o32_row_major __m128i a01_16; __m128i c01_16; + // First 32 int4 elements selectors. __m256i shift_idx_32; MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx_32); @@ -577,12 +574,11 @@ void packb_nr48_u8s8s32o32_row_major __mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4. - const int64_t conv_shift_arr_32[4] = { - 0x0807060504030201, 0x100F0E0D0C0B0A09, \ - 0X1817161514131211, 0X3B1F1E1D1C1B1A19 }; + CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32); __m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), conv_shift_arr_32 ); + // Next 16 int4 elements selectors. __m128i shift_idx_16; MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx_16); @@ -591,8 +587,7 @@ void packb_nr48_u8s8s32o32_row_major __mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4. - const int64_t conv_shift_arr_16[2] = { - 0x0807060504030201, 0x1B0F0E0D0C0B0A09 }; + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), conv_shift_arr_16 ); @@ -1027,9 +1022,7 @@ void packb_nr32_u8s8s32o32_row_major __mmask16 hmask_odd_32 = _cvtu32_mask16( 0x00008000 ); // Last 1 int4. 
- const int64_t conv_shift_arr_32[4] = { - 0x0807060504030201, 0x100F0E0D0C0B0A09, \ - 0X1817161514131211, 0X3B1F1E1D1C1B1A19 }; + CREATE_CVT_INT4_INT8_PERM_IDX_32ELEM_ODD_LD(conv_shift_arr_32); __m256i conv_shift_32 = _mm256_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), conv_shift_arr_32 ); @@ -1293,8 +1286,7 @@ void packb_nr16_u8s8s32o32_row_major __mmask16 hmask_odd_16 = _cvtu32_mask16( 0x00000080 ); // Last 1 int4. - const int64_t conv_shift_arr_16[2] = { - 0x0807060504030201, 0x1B0F0E0D0C0B0A09 }; + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), conv_shift_arr_16 ); @@ -1547,7 +1539,6 @@ void packb_nrlt16_u8s8s32o32_row_major } else { - if ( ( n0_partial_rem % 2 ) == 0 ) { // An interesting property here is that n0_partial_rem is @@ -1570,8 +1561,7 @@ void packb_nrlt16_u8s8s32o32_row_major } } - const int64_t conv_shift_arr_16[2] = { - 0x0807060504030201, 0x1B0F0E0D0C0B0A09 }; + CREATE_CVT_INT4_INT8_PERM_IDX_16ELEM_ODD_LD(conv_shift_arr_16); __m128i conv_shift_16 = _mm_maskz_loadu_epi64( _cvtu32_mask8( 0X000000FF ), conv_shift_arr_16 ); diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_pack_macros.h b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_pack_macros.h index f4d2ca61f..1849a8cca 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_pack_macros.h +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_pack_macros.h @@ -35,255 +35,7 @@ #ifndef LPGEMM_S32_PACK_MACROS_H #define LPGEMM_S32_PACK_MACROS_H -/* shift_idx:__m512i*/ -#define MULTISHIFT_32BIT_8_INT4_IDX_64ELEM(shift_idx) \ - /* Multi shift uses indices that corresponds to the bit starting positions - * of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12, - * 16, 20, 24, 28. 
*/ \ - shift_idx = _mm512_set1_epi64( 0x1C1814100C080400lu ); - -/* shift_idx:__m256i*/ -#define MULTISHIFT_32BIT_8_INT4_IDX_32ELEM(shift_idx) \ - /* Multi shift uses indices that corresponds to the bit starting positions - * of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12, - * 16, 20, 24, 28. */ \ - shift_idx = _mm256_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \ - 0x1C1814100C080400lu ); - -/* shift_idx:__m128i*/ -#define MULTISHIFT_32BIT_8_INT4_IDX_16ELEM(shift_idx) \ - /* Multi shift uses indices that corresponds to the bit starting positions - * of each of the 8 int4 elements in a given 32 bits, which is 0, 4, 8, 12, - * 16, 20, 24, 28. */ \ - shift_idx = _mm_maskz_set1_epi64( _cvtu32_mask8( 0xFF ), \ - 0x1C1814100C080400lu ); - -/* input:__m256i, output: __m512i*/ -#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx) \ - /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit - * /8 bytes (containing 8 int8 elements). Unsigned conversion is - * used so as to ensure the signed bit in int4 at MSB position of 4 - * byte group is not modified. */ \ - output = _mm512_multishift_epi64_epi8( shift_idx, \ - _mm512_cvtepu32_epi64( input ) ); \ - \ - /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ - output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ - _mm512_set1_epi8( 0x0F ) ); - -/* input:__m256i, output: __m512i*/ -#define UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, \ - output, odd_shift_idx, conv_shift) \ - /* Unsigned conversion is used so as to ensure the signed bit. - * in int4 at MSB position of 4 byte group is not modified. */ \ - __m512i upscale_input = _mm512_cvtepu32_epi64( input_0 ); \ - __m512i shift_input = _mm512_cvtepu32_epi64( input_1 ); \ - \ - /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit - * /8 bytes (containing 8 int8 elements). 
*/ \ - output = _mm512_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \ - \ - /* Combine both the input registers, starting from elem[1] till elem[n-1] - * in output(without elem[0]), and first non zero element in shift_input. - * It is at this point that the first 4bit and last 4bit elements, the 2 - * that were loaded extra due to byte level access are discarded. */ \ - output = _mm512_permutex2var_epi8( output, conv_shift, shift_input ); \ - \ - /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ - output = _mm512_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ - _mm512_set1_epi8( 0x0F ) ); - -/* input:__m128i, output: __m256i*/ -#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx) \ - /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit - * /8 bytes (containing 8 int8 elements). Unsigned conversion is - * used so as to ensure the signed bit in int4 at MSB position of 4 - * byte group is not modified. */ \ - output = _mm256_multishift_epi64_epi8( shift_idx, \ - _mm256_cvtepu32_epi64( input ) ); \ - \ - /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ - output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ - _mm256_set1_epi8( 0x0F ) ); - -/* input:__m128i, output: __m256i*/ -#define UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, \ - output, odd_shift_idx, conv_shift) \ - /* Unsigned conversion is used so as to ensure the signed bit. - * in int4 at MSB position of 4 byte group is not modified. */ \ - __m256i upscale_input = _mm256_cvtepu32_epi64( input_0 ); \ - __m256i shift_input = _mm256_cvtepu32_epi64( input_1 ); \ - \ - /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit - * /8 bytes (containing 8 int8 elements). 
*/ \ - output = _mm256_multishift_epi64_epi8( odd_shift_idx, upscale_input ); \ - \ - /* Combine both the input registers, starting from elem[1] till elem[n-1] - * in output(without elem[0]), and first non zero element in shift_input. - * It is at this point that the first 4bit and last 4bit elements, the 2 - * that were loaded extra due to byte level access are discarded. */ \ - output = _mm256_permutex2var_epi8( output, conv_shift, shift_input ); \ - \ - /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ - output = _mm256_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ - _mm256_set1_epi8( 0x0F ) ); - -/* input:int64_t, output: __m128i*/ -#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx) \ - /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit - * /8 bytes (containing 8 int8 elements). Unsigned conversion is - * used so as to ensure the signed bit in int4 at MSB position of 4 - * byte group is not modified. */ \ - output = _mm_multishift_epi64_epi8( shift_idx, \ - _mm_cvtepu32_epi64( input ) ); \ - \ - /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ - output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ - _mm_set1_epi8( 0x0F ) ); - -/* input:int64_t, output:__m128i*/ -#define UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, \ - output, odd_shift_idx, conv_shift) \ - /* Unsigned conversion is used so as to ensure the signed bit. - * in int4 at MSB position of 4 byte group is not modified. */ \ - input_0 = _mm_cvtepu32_epi64( input_0 ); \ - input_1 = _mm_cvtepu32_epi64( input_1 ); \ - \ - /* Upscale 32 bits/4 bytes (containing 8 int4 elements) into 64 bit - * /8 bytes (containing 8 int8 elements). */ \ - output = _mm_multishift_epi64_epi8( odd_shift_idx, input_0 ); \ - \ - /* Combine both the input registers, starting from elem[1] till elem[n-1] - * in output(without elem[0]), and first non zero element in shift_input. 
- * It is at this point that the first 4bit and last 4bit elements, the 2 - * that were loaded extra due to byte level access are discarded. */ \ - output = _mm_permutex2var_epi8( output, conv_shift, input_1 ); \ - \ - /* The upper 4 bits of each converted int8 element is junk, zeroing it. */ \ - output = _mm_maskz_and_epi64( _cvtu32_mask8( 0xFF ), output, \ - _mm_set1_epi8( 0x0F ) ); - -#define SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp) \ - /* Comparison of signed bit in int4 and appending sign bits. */ \ - /* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit - * is 1) to 1 and rest every other bits to 0. */ \ - __m512i hi_bits_512 = _mm512_and_epi32( output, sign_comp ); \ - \ - /* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit - * is 0) to 1 and rest every other bits to 0. */ \ - hi_bits_512 = _mm512_xor_epi32( hi_bits_512, sign_comp ); \ - \ - /* Set the sign extension bits on an int8_t size basis, this will then be - * OR with output to get the signed outputs. */ \ - hi_bits_512 = _mm512_add_epi8( hi_bits_512, _mm512_set1_epi8( 0xF8 ) ); \ - \ - output = _mm512_or_epi32( output, hi_bits_512 ); - -#define SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp) \ - /* Comparison of signed bit in int4 and appending sign bits. */ \ - /* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit - * is 1) to 1 and rest every other bits to 0. */ \ - __m256i hi_bits_256 = _mm256_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\ - output, sign_comp ); \ - \ - /* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit - * is 0) to 1 and rest every other bits to 0. */ \ - hi_bits_256 = _mm256_xor_epi32( hi_bits_256, sign_comp ); \ - \ - /* Set the sign extension bits on an int8_t size basis, this will then be - * OR with output to get the signed outputs. 
*/ \ - hi_bits_256 = _mm256_add_epi8( hi_bits_256, _mm256_set1_epi8( 0xF8 ) ); \ - \ - output = _mm256_or_epi32( output, hi_bits_256 ); - -#define SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp) \ - /* Comparison of signed bit in int4 and appending sign bits. */ \ - /* Set 4th bit (bit[3]/MSB/sign bit) of negative int4 values (signed bit - * is 1) to 1 and rest every other bits to 0. */ \ - __m128i hi_bits_128 = _mm_maskz_and_epi32( _cvtu32_mask8( 0xFF ),\ - output, sign_comp ); \ - \ - /* Set 4th bit (bit[3]/MSB/sign bit) of positive int4 values (signed bit - * is 0) to 1 and rest every other bits to 0. */ \ - hi_bits_128 = _mm_xor_epi32( hi_bits_128, sign_comp ); \ - \ - /* Set the sign extension bits on an int8_t size basis, this will then be - * OR with output to get the signed outputs. */ \ - hi_bits_128 = _mm_add_epi8( hi_bits_128, _mm_set1_epi8( 0xF8 ) ); \ - \ - output = _mm_or_epi32( output, hi_bits_128 ); - -/* input:__m256i, output: __m512i*/ -#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \ -do { \ - UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT(input, output, shift_idx); \ - \ - if ( signed_scale == TRUE ) \ - { \ - SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \ - } \ -} while (0); - -/* input:__m256i, output: __m512i*/ -#define CVT_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ - odd_shift_idx, conv_shift, sign_comp, signed_scale) \ -do { \ - UPSCALE_INT4_TO_INT8_64ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ - odd_shift_idx, conv_shift); \ - \ - if ( signed_scale == TRUE ) \ - { \ - SIGN_EXTEND_BITWISE_OPS_64ELEM(output, sign_comp); \ - } \ -} while (0); - -/* input:__m128i, output: __m256i*/ -#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \ -do { \ - UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT(input, output, shift_idx); \ - \ - if ( signed_scale == TRUE ) \ - { \ - SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \ - } \ -} while (0); - -/* 
input:__m128i, output: __m256i*/ -#define CVT_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ - odd_shift_idx, conv_shift, sign_comp, signed_scale) \ -do { \ - UPSCALE_INT4_TO_INT8_32ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ - odd_shift_idx, conv_shift); \ - \ - if ( signed_scale == TRUE ) \ - { \ - SIGN_EXTEND_BITWISE_OPS_32ELEM(output, sign_comp); \ - } \ -} while (0); - -/* input:int64_t, output: __m128i*/ -#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx, sign_comp, signed_scale) \ -do { \ - UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT(input, output, shift_idx); \ - \ - if ( signed_scale == TRUE ) \ - { \ - SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \ - } \ -} while (0); - -/* input:int64_t, output: __m128i*/ -#define CVT_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ - odd_shift_idx, conv_shift, sign_comp, signed_scale) \ -do { \ - UPSCALE_INT4_TO_INT8_16ELEM_MULTISHIFT_ODD(input_0, input_1, output, \ - odd_shift_idx, conv_shift); \ - \ - if ( signed_scale == TRUE ) \ - { \ - SIGN_EXTEND_BITWISE_OPS_16ELEM(output, sign_comp); \ - } \ -} while (0); +#include "../int4_utils_avx512.h" #define LOAD_16_COLS_AVX512 \ a_reg[0] = _mm512_loadu_si512(b + (ldb * (jr + 0)) + kr); \