From 2ce47e6f5e28ca89bff7a853ade3cba55af8bc71 Mon Sep 17 00:00:00 2001 From: Bhaskar Nallani Date: Fri, 2 Feb 2024 07:05:20 +0530 Subject: [PATCH] Implemented optimal AVX512-variant of f32 LPGEMV 1. The 5 LOOP LPGEMM path is in-efficient when A or B is a vector (i.e, m == 1 or n == 1). 2. An efficient implementation of lpgemv_rowvar_f32 is developed considering the b matrix reorder in case of m=1 and post-ops fusion. 3. When m = 1 the algorithm divide the GEMM workload in n dimension intelligently at a granularity of NR. Each thread work on A:1xk B:kx(>=NR) and produce C=1x(>NR). K is unrolled by 4 along with remainder loop. 4. When n = 1 the algorithm divide the GEMM workload in m dimension intelligently at a granularity of MR. Each thread work on A:(>=MR)xk B:kx1 and produce C = (>=MR)x1. When n=1 reordering of B is avoided to efficiently process in n one kernel. 5. Fixed few warnings while loading 2 f32 bias elements using _mm_load_sd using float pointer. Typecasted to (const double *) AMD-Internal: [SWLCSG-2391, SWLCSG-2353] Change-Id: If1d0b8d59e0278f5f16b499de1d629e63da5b599 --- .../aocl_gemm/aocl_gemm_f32f32f32of32_utils.c | 29 +- .../frame/f32f32f32/lpgemm_f32f32f32.c | 139 ++++- .../frame/lpgemm_5loop_interface_apis.h | 33 +- addon/aocl_gemm/kernels/lpgemm_kernels.h | 50 +- bench/bench_aocl_gemm/bench_lpgemm.c | 8 +- .../lpgemm/f32f32f32/lpgemm_fringe_f32_avx2.c | 25 +- .../f32f32f32/lpgemm_kernel_macros_f32_avx2.h | 4 +- .../f32f32f32/lpgemm_m_kernel_f32_avx2.c | 5 +- .../f32f32f32/lpgemv_m_kernel_f32_avx2.c | 72 +++ .../f32f32f32/lpgemv_n_kernel_f32_avx2.c | 75 +++ .../f32f32f32/lpgemm_kernel_macros_f32.h | 7 + .../f32f32f32/lpgemv_m_kernel_f32_avx512.c | 424 ++++++++++++++ .../f32f32f32/lpgemv_n_kernel_f32_avx512.c | 517 ++++++++++++++++++ 13 files changed, 1360 insertions(+), 28 deletions(-) create mode 100644 kernels/zen/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx2.c create mode 100644 kernels/zen/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx2.c create mode 100644 kernels/zen4/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx512.c create mode 100644 kernels/zen4/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx512.c diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c index 3b801ce0d..644e28dc7 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -74,7 +74,15 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32) const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); // Extra space since packing does width in multiples of NR. - const dim_t n_reorder = ( ( n + NR - 1 ) / NR ) * NR; + dim_t n_reorder; + if(n == 1) + { + //When n == 1, LPGEMV doesn't expect B to be reordered. + n_reorder = 1; + }else + { + n_reorder = ( ( n + NR - 1 ) / NR ) * NR; + } siz_t size_req = sizeof( float ) * k * n_reorder; @@ -144,6 +152,23 @@ AOCL_GEMM_REORDER(float,f32f32f32of32) dim_t n_threads = bli_rntm_num_threads( &rntm_g ); n_threads = ( n_threads > 0 ) ? n_threads : 1; + //When n == 1, B marix becomes a vector. + //Reordering is avoided so that LPGEMV can process it efficiently. 
+ if(n == 1) + { + if(ldb == 1) + { + memcpy(reorder_buf_addr, input_buf_addr, (k * sizeof(BLIS_FLOAT))); + }else + { + for(dim_t k0 = 0; k0 < k; k0++) + { + reorder_buf_addr[k0] = input_buf_addr[k0*ldb]; + } + } + return; + } + #ifdef BLIS_ENABLE_OPENMP _Pragma( "omp parallel num_threads(n_threads)" ) { diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c index 61e8cf865..11a83204f 100644 --- a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -87,8 +87,139 @@ void lpgemm_pack_b_f32f32f32of32 cntx_t* cntx ); -LPGEMM_5LOOP(float,float,float,f32f32f32of32) +#ifdef BLIS_KERNELS_ZEN4 +LPGEMV(float, float, float, f32f32f32of32) { + cntx_t *cntx = bli_gks_query_cntx(); + num_t dt = BLIS_FLOAT; + + // Query the context for various blocksizes. + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_NR, cntx); + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_NC, cntx); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_KC, cntx); + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_KC, cntx); + + // Strides are updated based on matrix packing/reordering. + float *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr; + post_ops_attr.c_stor_type = c_downscale; + if (c_downscale < F32) post_ops_attr.buf_downscale = c; + else post_ops_attr.buf_downscale = NULL; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic); + + if(n == 1) + { + //TODO: AVX2 support need to be added + // Increased MR from 6 to 16 to make use of 32 ZMM registers + dim_t MR = 16; + + // Compute the IC loop thread range for the current thread. + dim_t ic_start, ic_end; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + const float *a_use = a + ic * rs_a; + c_use = c + ic * rs_c; + post_ops_attr.post_op_c_i = ic; + + // Call lpgemv_n_one kernel + lpgemv_n_one_kernel_f32_ker_ft + ( + mc0, k, + a_use, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + MR, KC, + post_op_list, + &post_ops_attr + ); + } + } + else + { + // Compute the JC loop thread range for the current thread. 
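Both branches partition the long dimension with bli_thread_range_sub(): the n == 1 path just above splits m at MR granularity for the ic loop, and the m == 1 path below splits n at NR granularity for the jc loop (commit points 3 and 4). The helper here is only a simplified, hypothetical illustration of such a granularity-aware split; thread_range_at_granularity and its arguments are invented names, not the actual BLIS routine.

// Hypothetical sketch: split `work` units across `n_threads` threads at a
// granularity of `gran` (MR for the n == 1 path, NR for the m == 1 path).
static void thread_range_at_granularity
     (
       dim_t work, dim_t gran, dim_t n_threads, dim_t tid,
       dim_t* start, dim_t* end
     )
{
    // Number of granularity-sized blocks; the last one may be partial.
    dim_t n_blks   = ( work + gran - 1 ) / gran;
    dim_t blks_per = n_blks / n_threads;
    dim_t rem      = n_blks % n_threads;

    // The first `rem` threads get one extra block each.
    dim_t my_blks = blks_per + ( ( tid < rem ) ? 1 : 0 );
    dim_t my_off  = ( tid * blks_per ) + ( ( tid < rem ) ? tid : rem );

    *start = bli_min( my_off * gran, work );
    *end   = bli_min( *start + ( my_blks * gran ), work );
}

For example, m = 100, MR = 16 and 4 threads gives 7 blocks split 2/2/2/1 across the threads, i.e. row ranges of 32, 32, 32 and 4.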
+ dim_t jc_start, jc_end; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + c_use = c + jc; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + const float *b_use = NULL; + + if (mtag_b == REORDERED) + { + get_B_panel_reordered_start_offset_width( + jc, n, NC, NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated); + + b_use = b + (jc_cur_loop * k); + }else + { + b_use = b + jc; + } + + //update post-op pointer + post_ops_attr.post_op_c_j = jc; + + // Call kernel + lpgemv_m_one_kernel_f32_ker_ft + ( + nc0, k, + a, rs_a, cs_a, mtag_a, + b_use, rs_b, cs_b, mtag_b, + c_use, rs_c, cs_c, + alpha, beta, + NR, KC, + n_sub_updated, + jc_cur_loop_rem, + post_op_list, + &post_ops_attr + ); + + if (mtag_b == REORDERED) + { + adjust_B_panel_reordered_jc(&jc, jc_cur_loop); + } + } // jc loop + } +} +#endif + +LPGEMM_5LOOP(float, float, float, f32f32f32of32) +{ +#ifdef BLIS_KERNELS_ZEN4 + // Handle using LPGEMV when m or/and n equal to 1 + // The avx512 check will be removed when avx2 kernels added in future + if ((m == 1 || n == 1) && (bli_cpuid_is_avx512_supported() == TRUE)) + { + lpgemv_rowvar_f32f32f32of32(m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, + beta, + rntm, + thread, + lcntx, + post_op_list, + c_downscale); + return; + } +#endif // Query the global cntx. cntx_t* cntx = bli_gks_query_cntx(); @@ -101,8 +232,6 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); - /*ToDo: Based on context kernel 6x64m or 6x16m will be picked here */ - // Strides are updated based on matrix packing/reordering. const float* a_use = NULL; dim_t rs_a_use = rs_a; @@ -150,7 +279,7 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) bool is_first_k = FALSE; lpgemm_post_op_attr post_ops_attr; - post_ops_attr.c_stor_type = c_downscale; + post_ops_attr.c_stor_type = c_downscale; if ( c_downscale < F32 ) { post_ops_attr.buf_downscale = c; diff --git a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h index a0920edaf..915d13a52 100644 --- a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h +++ b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. 
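With the dispatch above in place, any call with m == 1 or n == 1 on an AVX512 machine is answered by lpgemv_rowvar_f32f32f32of32 instead of the 5-loop path, but the arithmetic it must produce is unchanged: C = alpha*A*B + beta*C. A naive reference such as the sketch below (illustrative only; no packing, threading or post-ops) is handy as an oracle when validating the new kernels.

// Naive reference for C = alpha*A*B + beta*C with row strides rs_* and unit
// column strides; works for any m/n but is meant for the m == 1 / n == 1 cases.
static void ref_gemv_f32
     (
       dim_t m, dim_t n, dim_t k,
       float alpha, const float* a, dim_t rs_a,
       const float* b, dim_t rs_b,
       float beta, float* c, dim_t rs_c
     )
{
    for ( dim_t i = 0; i < m; ++i )
    {
        for ( dim_t j = 0; j < n; ++j )
        {
            float acc = 0.0f;
            for ( dim_t p = 0; p < k; ++p )
                acc += a[ i * rs_a + p ] * b[ p * rs_b + j ];

            // Match the kernels: when beta == 0, C is not read.
            float cin = ( beta == 0.0f ) ? 0.0f : beta * c[ i * rs_c + j ];
            c[ i * rs_c + j ] = ( alpha * acc ) + cin;
        }
    }
}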
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -71,4 +71,33 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32); LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32); LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32); LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16); -#endif // LPGEMM_5LOOP_INTF_H + +#define LPGEMV(A_type, B_type, C_type, LP_SFX) \ +void lpgemv_rowvar_ ## LP_SFX \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type *a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type *b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type *c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + rntm_t *rntm, \ + lpgemm_thrinfo_t *thread, \ + lpgemm_cntx_t *lcntx, \ + lpgemm_post_op *post_op_list, \ + AOCL_STORAGE_TYPE c_downscale \ + ) \ + +LPGEMV(float, float, float, f32f32f32of32); + +#endif // LPGEMM_5LOOP_INTF_H \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/lpgemm_kernels.h b/addon/aocl_gemm/kernels/lpgemm_kernels.h index 83132e8fb..06e4c3989 100644 --- a/addon/aocl_gemm/kernels/lpgemm_kernels.h +++ b/addon/aocl_gemm/kernels/lpgemm_kernels.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -366,4 +366,52 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16); +void lpgemv_m_one_kernel_f32_ker_ft +( + const dim_t n0, + const dim_t k, + const float *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float *c, + const dim_t rs_c, + const dim_t cs_c, + const float alpha, + const float beta, + const dim_t NC, + const dim_t KC, + const dim_t n_sub_updated, + const dim_t jc_cur_loop_rem, + lpgemm_post_op *post_op, + lpgemm_post_op_attr *post_op_attr +); + +void lpgemv_n_one_kernel_f32_ker_ft +( + const dim_t m0, + const dim_t k, + const float *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float *c, + const dim_t rs_c, + const dim_t cs_c, + const float alpha, + const float beta, + const dim_t MR, + const dim_t KC, + lpgemm_post_op *post_op, + lpgemm_post_op_attr *post_op_attr +); + #endif //BLIS_LPGEMM_KERN_H diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 1f555d28b..85d846032 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -204,7 +204,7 @@ void fill_array_ ## ctype ( void* arr, dim_t size ) \ ctype* temp_arr = ( ctype* ) arr; \ for ( dim_t i = 0; i < size; ++i ) \ { \ - temp_arr[i] = ( ctype )( i % 5 ); \ + temp_arr[i] = ( ctype )( rand() % 5 ); \ } \ } \ @@ -221,7 +221,7 @@ void fill_array_bfloat16( void* arr, dim_t size ) float* c_float = ( float* ) bli_malloc_user( sizeof( float ) * size, &bli_errors ); for ( dim_t i 
= 0; i < size; ++i ) { - c_float[i] = i % 5; + c_float[i] = (rand() % 5 ); } convert_float_arr_to_bf16( c_float, arr, size ); if ( c_float != NULL ) @@ -236,7 +236,7 @@ void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \ ctype* temp_arr = ( ctype* ) arr; \ for ( dim_t i = 0; i < size; ++i ) \ { \ - temp_arr[i] = ( ctype )( i % 20 ); \ + temp_arr[i] = ( ctype )( rand() % 20 ); \ } \ } \ @@ -1595,7 +1595,7 @@ int main( int argc, char** argv ) int32_t stride_a, stride_b, stride_c; const dim_t len_list_omp_cores_for_testing = 2; - const dim_t list_omp_cores_for_testing[2] = { 80, 1 }; + const dim_t list_omp_cores_for_testing[2] = { 64, 1 }; dim_t core_index = 0; bool can_run = TRUE; diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemm_fringe_f32_avx2.c b/kernels/zen/lpgemm/f32f32f32/lpgemm_fringe_f32_avx2.c index 0339af90c..61462b392 100644 --- a/kernels/zen/lpgemm/f32f32f32/lpgemm_fringe_f32_avx2.c +++ b/kernels/zen/lpgemm/f32f32f32/lpgemm_fringe_f32_avx2.c @@ -3831,8 +3831,9 @@ POST_OPS_BIAS_5x2F: if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = ( __m128 )_mm_load_sd( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd((const double *) + ((float * )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -4114,8 +4115,9 @@ POST_OPS_BIAS_4x2F: if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = ( __m128 )_mm_load_sd( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd((const double *) + ((float *)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -4360,8 +4362,9 @@ POST_OPS_BIAS_3x2F: if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = ( __m128 )_mm_load_sd( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd( (const double *) + ((float *) post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -4575,8 +4578,9 @@ POST_OPS_BIAS_2x2F: if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = ( __m128 )_mm_load_sd( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd((const double *) + ((float *)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); @@ -4750,8 +4754,9 @@ POST_OPS_BIAS_1x2F: if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = ( __m128 )_mm_load_sd( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = (__m128)_mm_load_sd((const double *) + ((float*)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + (0 * 8))); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemm_kernel_macros_f32_avx2.h b/kernels/zen/lpgemm/f32f32f32/lpgemm_kernel_macros_f32_avx2.h index 9cede8b48..727b83a95 100644 --- a/kernels/zen/lpgemm/f32f32f32/lpgemm_kernel_macros_f32_avx2.h +++ b/kernels/zen/lpgemm/f32f32f32/lpgemm_kernel_macros_f32_avx2.h @@ -153,8 +153,8 @@ #define 
F32_F32_MATRIX_ADD_LOAD_XMM_2ELE(scr,m_ind,n_ind) \ scr = ( __m128 )_mm_load_sd \ ( \ - matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ - post_ops_attr.post_op_c_j + ( n_ind * 2 ) \ + (double*)(matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 2 )) \ ); \ #define F32_F32_MATRIX_ADD_1COL_XMM_2ELE(scr0,m_ind,r_ind0) \ diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx2.c b/kernels/zen/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx2.c index d4a0208ec..e9d478b61 100644 --- a/kernels/zen/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx2.c +++ b/kernels/zen/lpgemm/f32f32f32/lpgemm_m_kernel_f32_avx2.c @@ -1491,8 +1491,9 @@ POST_OPS_BIAS_6x2F: if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) { - xmm0 = ( __m128 )_mm_load_sd( ( float* )post_ops_list_temp->op_args1 + - post_ops_attr.post_op_c_j + ( 0 * 8 ) ); + xmm0 = ( __m128 )_mm_load_sd( (const double*) + (( float* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 8 ) )); // c[0,0-3] xmm4 = _mm_add_ps( xmm4, xmm0 ); diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx2.c b/kernels/zen/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx2.c new file mode 100644 index 000000000..b39e32fd0 --- /dev/null +++ b/kernels/zen/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx2.c @@ -0,0 +1,72 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
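All of the _mm_load_sd() changes above and below are the same fix (commit point 5): two adjacent f32 bias values are fetched as one 64-bit lane, and the intrinsic is declared to take a const double*, so passing a float* raised incompatible-pointer warnings. A minimal stand-alone version of the corrected pattern (the helper name is illustrative):

#include <immintrin.h>

// Load two consecutive floats into the low 64 bits of an XMM register.
static inline __m128 load_2xf32( const float* src )
{
    // _mm_load_sd() wants a const double*; the cast only changes the
    // pointer type, the 64 bits loaded are still the two f32 values.
    __m128d tmp = _mm_load_sd( ( const double* )src );
    return _mm_castpd_ps( tmp );
}

_mm_castpd_ps() is the portable spelling of the ( __m128 ) vector cast the kernels use; either way no data is converted, only reinterpreted.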
+ +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32_avx2.h" + +void lpgemv_m_one_kernel_f32_avx2_ker_ft +( + const dim_t n0, + const dim_t k, + const float *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float *c, + const dim_t rs_c, + const dim_t cs_c, + const float alpha, + const float beta, + const dim_t NR, + const dim_t KC, + const dim_t n_sub_updated, + const dim_t jc_cur_loop_rem, + lpgemm_post_op *post_op_list, + lpgemm_post_op_attr *post_op_attr +) +{ + // TODO: Created dummy function as place holder. + // AVX2 varient wil be implemented in next commits. + // Code will take LPGEMM path for LPGEMV in AVX2 env +} + +#endif // BLIS_ADDON_LPGEMM \ No newline at end of file diff --git a/kernels/zen/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx2.c b/kernels/zen/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx2.c new file mode 100644 index 000000000..cfcd94363 --- /dev/null +++ b/kernels/zen/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx2.c @@ -0,0 +1,75 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32_avx2.h" + +// When n=1 is load 16x1 from B and load MRx16 from A and perform dot product +// to produce C output of MRX1. 
The vectorization is done in k loop and +// the horizontal reduction done to produce one output from each +// accumulator register +void lpgemv_n_one_kernel_f32_avx2_ker_ft +( + const dim_t m0, + const dim_t k, + const float *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float *c, + const dim_t rs_c, + const dim_t cs_c, + const float alpha, + const float beta, + const dim_t MR, + const dim_t KC, + lpgemm_post_op *post_op_list, + lpgemm_post_op_attr *post_op_attr +) +{ +//TODO: Created dummy function as place holder to get +//rid of linking issues in other zen configurations. +//AVX2 varient wil be implemented in next commits. +//Code will take LPGEMM path for LPGEMV in AVX2 env. +} + +#endif // BLIS_ADDON_LPGEMM \ No newline at end of file diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemm_kernel_macros_f32.h b/kernels/zen4/lpgemm/f32f32f32/lpgemm_kernel_macros_f32.h index 5d1019ea7..44fd7e4da 100644 --- a/kernels/zen4/lpgemm/f32f32f32/lpgemm_kernel_macros_f32.h +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemm_kernel_macros_f32.h @@ -67,6 +67,13 @@ zmm2 = _mm512_setzero_ps(); \ zmm3 = _mm512_setzero_ps(); +// Zero-out the given ZMM accumulator registers +#define ZERO_ACC_XMM_4_REG(xmm0, xmm1, xmm2, xmm3) \ + xmm0 = _mm_setzero_ps(); \ + xmm1 = _mm_setzero_ps(); \ + xmm2 = _mm_setzero_ps(); \ + xmm3 = _mm_setzero_ps(); + /*Multiply alpha with accumulator registers and store back*/ #define ALPHA_MUL_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3,alpha) \ zmm0 = _mm512_mul_ps(zmm0,alpha); \ diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx512.c b/kernels/zen4/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx512.c new file mode 100644 index 000000000..aeec517b4 --- /dev/null +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemv_m_kernel_f32_avx512.c @@ -0,0 +1,424 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32.h" + +void lpgemv_m_one_kernel_f32_ker_ft +( + const dim_t n0, + const dim_t k, + const float *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float *c, + const dim_t rs_c, + const dim_t cs_c, + const float alpha, + const float beta, + const dim_t NR, + const dim_t KC, + const dim_t n_sub_updated, + const dim_t jc_cur_loop_rem, + lpgemm_post_op *post_op_list, + lpgemm_post_op_attr *post_op_attr +) +{ + static void *post_ops_labels[] = + { + &&POST_OPS_6x64F_DISABLE, + &&POST_OPS_BIAS_6x64F, + &&POST_OPS_RELU_6x64F, + &&POST_OPS_RELU_SCALE_6x64F, + &&POST_OPS_GELU_TANH_6x64F, + &&POST_OPS_GELU_ERF_6x64F, + &&POST_OPS_CLIP_6x64F, + NULL, // Virtual node for downscale, else segfault + && POST_OPS_MATRIX_ADD_6x64F + }; + + // Strides are updated based on matrix packing/reordering. + const float *a_use = NULL; + const float *b_use = NULL; + float *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + for (dim_t jr = 0; jr < n0; jr += NR) + { + dim_t nr0 = bli_min((n0 - jr), NR); + c_use = c + jr; + __mmask16 k1 = 0xFFFF, k2 = 0xFFFF, k3 = 0xFFFF, k4 = 0xFFFF; + + if (nr0 < NR) + { + __mmask16 k = (0xFFFF >> (16 - (nr0 & 0x0F))); + if (nr0 >= 48) + { + k4 = k; + } + else if (nr0 >= 32) + { + k3 = k; + k4 = 0; + } + else if (nr0 >= 16) + { + k2 = k; + k3 = k4 = 0; + } + else + { + k1 = k; + k2 = k3 = k4 = 0; + } + } + + __m512 zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + __m512 zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14; + __m512 zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512 zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28; + __m512 zmm29, zmm30, zmm31; + + // zero the accumulator registers + ZERO_ACC_ZMM_4_REG(zmm8, zmm9, zmm10, zmm11); + ZERO_ACC_ZMM_4_REG(zmm12, zmm13, zmm14, zmm15); + ZERO_ACC_ZMM_4_REG(zmm16, zmm17, zmm18, zmm19); + ZERO_ACC_ZMM_4_REG(zmm20, zmm21, zmm22, zmm23); + + //Zero out registers used for mask load to avoid warnings + ZERO_ACC_ZMM_4_REG(zmm0, zmm1, zmm2, zmm3); + ZERO_ACC_ZMM_4_REG(zmm24, zmm25, zmm26, zmm27); + ZERO_ACC_ZMM_4_REG(zmm28, zmm29, zmm30, zmm31); + + _mm256_zeroupper(); + + //_mm_prefetch( (MR X NR) from C + _mm_prefetch((c_use + 0 * rs_c), _MM_HINT_T0); + _mm_prefetch((c_use + 16 * rs_c), _MM_HINT_T0); + _mm_prefetch((c_use + 32 * rs_c), _MM_HINT_T0); + _mm_prefetch((c_use + 64 * rs_c), _MM_HINT_T0); + + for (dim_t pc = 0; pc < k; pc += KC) + { + dim_t kc0 = bli_min((k - pc), KC); + uint64_t k_iter = kc0 / 4; + uint64_t k_rem = kc0 % 4; + dim_t ps_b_use = 0; + dim_t rs_b_use = NR; + // No parallelization in k dim, k always starts at 0. + if (mtag_b == REORDERED) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. 
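A hedged reading of the offset computed just below, inferred from the arithmetic rather than stated anywhere in the patch: the reordered B panel appears to be stored KC-block by KC-block, each block holding n_sub_updated columns as NR-wide micro-panels of kc0 rows laid out contiguously (which would also be why rs_b_use = NR and ps_b_use = kc0 in this branch). The helper below merely restates that address computation with illustrative names.

// Assumed packed-B layout: per-KC blocks of n_sub_updated columns, with
// NR-wide micro-panels of kc0 rows stored back to back inside each block.
static inline const float* packed_b_at
     (
       const float* b_panel,   // start of the current NC panel
       dim_t n_sub_updated,    // columns held by this panel
       dim_t pc,               // k offset of the current KC block
       dim_t kc0,              // rows in the current KC block
       dim_t col_off           // column offset in the panel (jc_cur_loop_rem + jr)
     )
{
    return b_panel + ( n_sub_updated * pc )   // skip earlier KC blocks
                   + ( col_off * kc0 );       // skip earlier micro-panels
}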
+ b_use = b + (n_sub_updated * pc) + (jc_cur_loop_rem * kc0); + ps_b_use = kc0; + } + else + { + b_use = b + (pc * rs_b); + ps_b_use = 1; + rs_b_use = rs_b; + } + + a_use = a + pc; + b_use = b_use + jr * ps_b_use; + + for (dim_t k = 0; k < k_iter; k++) + { + _mm_prefetch((b_use + 4 * rs_b_use), _MM_HINT_T0); + //Using mask loads to avoid writing fringe kernels + + //Load first 4x16 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_ps(k1, b_use); + zmm1 = _mm512_maskz_loadu_ps(k1, b_use + rs_b_use); + zmm2 = _mm512_maskz_loadu_ps(k1, b_use + 2 * rs_b_use); + zmm3 = _mm512_maskz_loadu_ps(k1, b_use + 3 * rs_b_use); + b_use += 16; + + //Broadcast col0 - col3 element of A + zmm4 = _mm512_set1_ps(*(a_use)); // broadcast c0 + zmm5 = _mm512_set1_ps(*(a_use + 1)); // broadcast c1 + zmm6 = _mm512_set1_ps(*(a_use + 2)); // broadcast c2 + zmm7 = _mm512_set1_ps(*(a_use + 3)); // broadcast c3 + + //Load second 4x16 tile from row 0-3 + zmm24 = _mm512_maskz_loadu_ps(k2, b_use); + zmm25 = _mm512_maskz_loadu_ps(k2, b_use + rs_b_use); + zmm26 = _mm512_maskz_loadu_ps(k2, b_use + 2 * rs_b_use); + zmm27 = _mm512_maskz_loadu_ps(k2, b_use + 3 * rs_b_use); + b_use += 16; + + zmm8 = _mm512_fmadd_ps(zmm0, zmm4, zmm8); + zmm9 = _mm512_fmadd_ps(zmm1, zmm5, zmm9); + zmm10 = _mm512_fmadd_ps(zmm2, zmm6, zmm10); + zmm11 = _mm512_fmadd_ps(zmm3, zmm7, zmm11); + + //Load third 4x16 tile from row 0-3 + zmm0 = _mm512_maskz_loadu_ps(k3, b_use); + zmm1 = _mm512_maskz_loadu_ps(k3, b_use + rs_b_use); + zmm2 = _mm512_maskz_loadu_ps(k3, b_use + 2 * rs_b_use); + zmm3 = _mm512_maskz_loadu_ps(k3, b_use + 3 * rs_b_use); + b_use += 16; + + zmm12 = _mm512_fmadd_ps(zmm24, zmm4, zmm12); + zmm13 = _mm512_fmadd_ps(zmm25, zmm5, zmm13); + zmm14 = _mm512_fmadd_ps(zmm26, zmm6, zmm14); + zmm15 = _mm512_fmadd_ps(zmm27, zmm7, zmm15); + + //Load fourth 4x16 tile from row 0-3 + zmm28 = _mm512_maskz_loadu_ps(k4, b_use); + zmm29 = _mm512_maskz_loadu_ps(k4, b_use + rs_b_use); + zmm30 = _mm512_maskz_loadu_ps(k4, b_use + 2 * rs_b_use); + zmm31 = _mm512_maskz_loadu_ps(k4, b_use + 3 * rs_b_use); + + zmm16 = _mm512_fmadd_ps(zmm0, zmm4, zmm16); + zmm17 = _mm512_fmadd_ps(zmm1, zmm5, zmm17); + zmm18 = _mm512_fmadd_ps(zmm2, zmm6, zmm18); + zmm19 = _mm512_fmadd_ps(zmm3, zmm7, zmm19); + + zmm20 = _mm512_fmadd_ps(zmm28, zmm4, zmm20); + zmm21 = _mm512_fmadd_ps(zmm29, zmm5, zmm21); + zmm22 = _mm512_fmadd_ps(zmm30, zmm6, zmm22); + zmm23 = _mm512_fmadd_ps(zmm31, zmm7, zmm23); + + b_use -= 48; // move b point back to start of KCXNR + b_use += (4 * rs_b_use); + a_use += 4; // move a pointer to next col + } // kloop + + for (dim_t kr = 0; kr < k_rem; kr++) + { + //Load 64 elements from a row of B + zmm0 = _mm512_maskz_loadu_ps(k1, b_use); + zmm1 = _mm512_maskz_loadu_ps(k2, b_use + 16); + zmm2 = _mm512_maskz_loadu_ps(k3, b_use + 32); + zmm3 = _mm512_maskz_loadu_ps(k4, b_use + 48); + + //Broadcast col0 elements of 12 rows of A + zmm4 = _mm512_set1_ps(*(a_use)); // broadcast c0r0 + + zmm8 = _mm512_fmadd_ps(zmm0, zmm4, zmm8); + zmm12 = _mm512_fmadd_ps(zmm1, zmm4, zmm12); + zmm16 = _mm512_fmadd_ps(zmm2, zmm4, zmm16); + zmm20 = _mm512_fmadd_ps(zmm3, zmm4, zmm20); + + b_use += rs_b_use; // move b pointer to next row + a_use++; // move a pointer to next col + } // kloop + } // kc loop + + //SUMUP K untoll output + zmm8 = _mm512_add_ps(zmm9, zmm8); + zmm10 = _mm512_add_ps(zmm11, zmm10); + zmm8 = _mm512_add_ps(zmm10, zmm8); // 16 outputs + + zmm12 = _mm512_add_ps(zmm13, zmm12); + zmm14 = _mm512_add_ps(zmm15, zmm14); + zmm12 = _mm512_add_ps(zmm14, zmm12); // 16 outputs + + zmm16 = 
_mm512_add_ps(zmm17, zmm16); + zmm18 = _mm512_add_ps(zmm19, zmm18); + zmm16 = _mm512_add_ps(zmm18, zmm16); // 16 outputs + + zmm20 = _mm512_add_ps(zmm21, zmm20); + zmm22 = _mm512_add_ps(zmm23, zmm22); + zmm20 = _mm512_add_ps(zmm22, zmm20); // 16 outputs + + //Mulitply A*B output with alpha + zmm0 = _mm512_set1_ps(alpha); + zmm8 = _mm512_mul_ps(zmm0, zmm8); + zmm12 = _mm512_mul_ps(zmm0, zmm12); + zmm16 = _mm512_mul_ps(zmm0, zmm16); + zmm20 = _mm512_mul_ps(zmm0, zmm20); + + if (beta != 0) + { + const float *_cbuf = c_use; + // load c and multiply with beta and + // add to accumulator and store back + zmm3 = _mm512_set1_ps(beta); + zmm0 = _mm512_maskz_loadu_ps(k1, _cbuf); + zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); + + zmm1 = _mm512_maskz_loadu_ps(k2, (_cbuf + 16)); + zmm12 = _mm512_fmadd_ps(zmm1, zmm3, zmm12); + + zmm2 = _mm512_maskz_loadu_ps(k3, (_cbuf + 32)); + zmm16 = _mm512_fmadd_ps(zmm2, zmm3, zmm16); + + zmm4 = _mm512_maskz_loadu_ps(k4, (_cbuf + 48)); + zmm20 = _mm512_fmadd_ps(zmm4, zmm3, zmm20); + } + + // Post Ops + post_ops_attr.is_last_k = TRUE; + lpgemm_post_op *post_ops_list_temp = post_op_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64F: + { + if ((*(char *)post_ops_list_temp->op_args2 == 'r') || + (*(char *)post_ops_list_temp->op_args2 == 'R')) + { + float* bias_ptr = (float *)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j; + zmm9 = _mm512_maskz_loadu_ps(k1, bias_ptr + (0 * 16)); + + zmm10 = _mm512_maskz_loadu_ps(k2, bias_ptr + (1 * 16)); + + zmm13 = _mm512_maskz_loadu_ps(k3, bias_ptr + (2 * 16)); + + zmm14 = _mm512_maskz_loadu_ps(k4, bias_ptr + (3 * 16)); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. 
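The pairwise additions above fold the four unroll accumulators of each 16-column group back into one register before alpha is applied; this is the scalar shape of commit point 3's "K is unrolled by 4 along with remainder loop". A per-column sketch (the function and its names are illustrative, not part of the patch):

// Scalar shape of the unroll-by-4 k loop used in this kernel: four partial
// sums per output column, a remainder loop for k % 4, then one final fold.
static float dot_unroll4_f32( const float* a, const float* bcol, dim_t k, dim_t ldb )
{
    float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f;
    dim_t kk = 0;

    for ( ; ( kk + 3 ) < k; kk += 4 )
    {
        acc0 += a[ kk + 0 ] * bcol[ ( kk + 0 ) * ldb ];
        acc1 += a[ kk + 1 ] * bcol[ ( kk + 1 ) * ldb ];
        acc2 += a[ kk + 2 ] * bcol[ ( kk + 2 ) * ldb ];
        acc3 += a[ kk + 3 ] * bcol[ ( kk + 3 ) * ldb ];
    }
    for ( ; kk < k; ++kk )              // remainder loop
        acc0 += a[ kk ] * bcol[ kk * ldb ];

    return ( acc0 + acc1 ) + ( acc2 + acc3 );
}

In the AVX512 kernel each accumulator above corresponds to a group of four ZMM registers (e.g. zmm8-zmm11 for columns 0-15), so the fold is done with _mm512_add_ps before the alpha multiply.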
+ float bias = (*((float *)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_i + 0)); + + zmm9 = _mm512_set1_ps(bias); + zmm10 = zmm13 = zmm14 = zmm9; + } + // c[0,0-15] + zmm8 = _mm512_add_ps(zmm9, zmm8); + zmm12 = _mm512_add_ps(zmm10, zmm12); + zmm16 = _mm512_add_ps(zmm13, zmm16); + zmm20 = _mm512_add_ps(zmm14, zmm20); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64F: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps(zmm1, zmm8); + zmm12 = _mm512_max_ps(zmm1, zmm12); + zmm16 = _mm512_max_ps(zmm1, zmm16); + zmm20 = _mm512_max_ps(zmm1, zmm20); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE_6x64F: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps(*((float *)post_ops_list_temp->op_args2)); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + RELU_SCALE_OP_F32S_AVX512(zmm12) + RELU_SCALE_OP_F32S_AVX512(zmm16) + RELU_SCALE_OP_F32S_AVX512(zmm20) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH_6x64F: + { + __m512i zmm6; + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + GELU_TANH_F32S_AVX512(zmm12, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + GELU_TANH_F32S_AVX512(zmm16, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + GELU_TANH_F32S_AVX512(zmm20, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF_6x64F: + { + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, zmm0, zmm1, zmm2) + GELU_ERF_F32S_AVX512(zmm12, zmm0, zmm1, zmm2) + GELU_ERF_F32S_AVX512(zmm16, zmm0, zmm1, zmm2) + GELU_ERF_F32S_AVX512(zmm20, zmm0, zmm1, zmm2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_CLIP_6x64F: + { + zmm0 = _mm512_set1_ps(*(float *)post_ops_list_temp->op_args2); + zmm1 = _mm512_set1_ps(*(float *)post_ops_list_temp->op_args3); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, zmm0, zmm1) + CLIP_F32S_AVX512(zmm12, zmm0, zmm1) + CLIP_F32S_AVX512(zmm16, zmm0, zmm1) + CLIP_F32S_AVX512(zmm20, zmm0, zmm1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_ADD_6x64F: + { + float *matptr = (float *)post_ops_list_temp->op_args1; + zmm0 = _mm512_maskz_loadu_ps(k1, (matptr + post_ops_attr.post_op_c_j)); + zmm8 = _mm512_add_ps(zmm8, zmm0); + zmm0 = _mm512_maskz_loadu_ps(k2, (matptr + post_ops_attr.post_op_c_j + 16)); + zmm12 = _mm512_add_ps(zmm12, zmm0); + zmm0 = _mm512_maskz_loadu_ps(k3, (matptr + post_ops_attr.post_op_c_j + 32)); + zmm16 = _mm512_add_ps(zmm16, zmm0); + zmm0 = _mm512_maskz_loadu_ps(k4, (matptr + post_ops_attr.post_op_c_j + 48)); + zmm20 = _mm512_add_ps(zmm20, zmm0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_6x64F_DISABLE: + { + _mm512_mask_storeu_ps(c_use, k1, zmm8); + _mm512_mask_storeu_ps((c_use + 16), k2, zmm12); + _mm512_mask_storeu_ps((c_use + 32), k3, zmm16); + _mm512_mask_storeu_ps((c_use + 48), k4, zmm20); + post_ops_attr.post_op_c_j += NR; + } + } // jr loop +} + +#endif // BLIS_ADDON_LPGEMM \ No newline at end of file diff --git a/kernels/zen4/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx512.c b/kernels/zen4/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx512.c new file mode 100644 index 000000000..eab499946 --- /dev/null +++ b/kernels/zen4/lpgemm/f32f32f32/lpgemv_n_kernel_f32_avx512.c @@ -0,0 +1,517 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
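Both new kernels sequence their post-ops with the same labels-as-values pattern the existing LPGEMM kernels use: post_ops_labels[] holds &&label addresses, and the POST_OP_LABEL_* macros goto the entry selected by the next node in post_op_list until the *_DISABLE label finally stores C. The fragment below is only a minimal, self-contained sketch of that GCC/Clang computed-goto technique; the op codes and names are invented for illustration and are not the kernels' actual macro expansion.

#include <stdio.h>

enum { OP_DONE = 0, OP_ADD1 = 1, OP_DOUBLE = 2 };

static float apply_ops( float x, const int* ops )
{
    // Labels-as-values (GCC/Clang extension), same idea as post_ops_labels[].
    static void* labels[] = { &&done, &&add1, &&dbl };

    goto *labels[ *ops ];            // jump to the first op in the chain

add1:
    x += 1.0f;
    goto *labels[ *( ++ops ) ];      // jump to the next op
dbl:
    x *= 2.0f;
    goto *labels[ *( ++ops ) ];
done:
    return x;
}

int main( void )
{
    const int chain[] = { OP_ADD1, OP_DOUBLE, OP_DONE };
    printf( "%f\n", apply_ops( 3.0f, chain ) );   // prints 8.000000
    return 0;
}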
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + +#include "lpgemm_kernel_macros_f32.h" + +#define LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, paddr, stride) \ + zmm0 = _mm512_loadu_ps(paddr); \ + zmm1 = _mm512_loadu_ps(paddr + stride); \ + zmm2 = _mm512_loadu_ps(paddr + 2 * stride); \ + zmm3 = _mm512_loadu_ps(paddr + 3 * stride); + +#define LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, paddr, stride) \ + zmm0 = _mm512_mask_loadu_ps(zmm7, k1, paddr); \ + zmm1 = _mm512_mask_loadu_ps(zmm7, k1, paddr + stride); \ + zmm2 = _mm512_mask_loadu_ps(zmm7, k1, paddr + 2 * stride); \ + zmm3 = _mm512_mask_loadu_ps(zmm7, k1, paddr + 3 * stride); + +#define LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) \ + zmm8 = _mm512_fmadd_ps(zmm0, zmm6, zmm8); \ + zmm9 = _mm512_fmadd_ps(zmm1, zmm6, zmm9); \ + zmm10 = _mm512_fmadd_ps(zmm2, zmm6, zmm10); \ + zmm11 = _mm512_fmadd_ps(zmm3, zmm6, zmm11); + +#define LPGEMV_ZMM2XMM(zmm0, zmm1, zmm2, zmm3, ymm0, ymm1, ymm2, ymm3, xmm0) \ + ymm0 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm0, 0x0), \ + _mm512_extractf32x8_ps(zmm0, 0x1)); \ + ymm1 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm1, 0x0), \ + _mm512_extractf32x8_ps(zmm1, 0x1)); \ + ymm0 = _mm256_hadd_ps(ymm0, ymm1); \ + ymm2 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm2, 0x0), \ + _mm512_extractf32x8_ps(zmm2, 0x1)); \ + ymm3 = _mm256_add_ps(_mm512_extractf32x8_ps(zmm3, 0x0), \ + _mm512_extractf32x8_ps(zmm3, 0x1)); \ + ymm1 = _mm256_hadd_ps(ymm2, ymm3); \ + ymm0 = _mm256_hadd_ps(ymm0, ymm1); \ + xmm0 = _mm_add_ps(_mm256_extractf128_ps(ymm0, 0), _mm256_extractf128_ps(ymm0,1)); + +// When n=1 is load 16x1 from B and load MRx16 from A and perform dot product +// to produce C output of MRX1. 
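LPGEMV_ZMM2XMM above is the horizontal-reduction step: it collapses four 16-float accumulators (one per row of A) into a single __m128 whose lane i holds the complete sum of accumulator i, using extract/add followed by hadd. In scalar terms it computes (illustrative helper, not part of the patch):

// Scalar equivalent of LPGEMV_ZMM2XMM: out[i] = sum of the 16 lanes of acc[i].
static void reduce_4x16_to_4( const float acc[4][16], float out[4] )
{
    for ( int i = 0; i < 4; ++i )
    {
        float sum = 0.0f;
        for ( int j = 0; j < 16; ++j )
            sum += acc[ i ][ j ];
        out[ i ] = sum;
    }
}

Four such reductions fill the four 128-bit lanes of zmm8 via _mm512_insertf32x4(), giving up to 16 dot products per MR block that the post-ops then treat as a single vector.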
The vectorization is done in k loop and +// the horizontal reduction done to produce one output from each +// accumulator register +void lpgemv_n_one_kernel_f32_ker_ft +( + const dim_t m0, + const dim_t k, + const float *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float *c, + const dim_t rs_c, + const dim_t cs_c, + const float alpha, + const float beta, + const dim_t MR, + const dim_t KC, + lpgemm_post_op *post_op_list, + lpgemm_post_op_attr *post_op_attr +) +{ + static void *post_ops_labels[] = + { + &&POST_OPS_6x64F_DISABLE, + &&POST_OPS_BIAS_6x64F, + &&POST_OPS_RELU_6x64F, + &&POST_OPS_RELU_SCALE_6x64F, + &&POST_OPS_GELU_TANH_6x64F, + &&POST_OPS_GELU_ERF_6x64F, + &&POST_OPS_CLIP_6x64F, + NULL, // Virtual node for downscale, else segfault + &&POST_OPS_MATRIX_ADD_6x64F + }; + + // Strides are updated based on matrix packing/reordering. + const float *a_use = NULL; + const float *b_use = NULL; + float *c_use = NULL; + + lpgemm_post_op_attr post_ops_attr = *(post_op_attr); + + for (dim_t mr = 0; mr < m0; mr += MR) + { + dim_t mr0 = bli_min((m0 - mr), MR); + dim_t k_iter = k/16; + dim_t k_rem = k & 0xF; + + //Create load mask for k fringe + __mmask16 k1 = 0xFFFF; + if (k_rem) + { + k1 = (0xFFFF >> (16 - k_rem)); + } + + // Create store mask for C for mr fringe + __mmask16 k2 = 0xFFFF; + if (mr0 < MR) + { + k2 = (0xFFFF >> (MR - mr0)); + } + + __m512 zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + __m512 zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14; + __m512 zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512 zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28; + __m512 zmm29, zmm30, zmm31; + + __m256 ymm0,ymm1,ymm2,ymm3,ymm4,ymm5,ymm6; + __m128 xmm0, xmm1, xmm2, xmm3; + + ZERO_ACC_ZMM_4_REG(zmm0, zmm1, zmm2, zmm3); + ZERO_ACC_ZMM_4_REG(zmm4, zmm5, zmm6, zmm7); + /* zero the accumulator registers */ + ZERO_ACC_ZMM_4_REG(zmm8, zmm9, zmm10, zmm11); + ZERO_ACC_ZMM_4_REG(zmm12, zmm13, zmm14, zmm15); + ZERO_ACC_ZMM_4_REG(zmm16, zmm17, zmm18, zmm19); + ZERO_ACC_ZMM_4_REG(zmm20, zmm21, zmm22, zmm23); + ZERO_ACC_ZMM_4_REG(zmm24, zmm25, zmm26, zmm27); + ZERO_ACC_ZMM_4_REG(zmm28, zmm29, zmm30, zmm31); + ZERO_ACC_XMM_4_REG (xmm0,xmm1,xmm2,xmm3) + + _mm256_zeroupper(); + + //update pointers + a_use = a + mr * rs_a; + b_use = b; + c_use = c + mr * rs_c; + + //prefetch C + _mm_prefetch(c_use, _MM_HINT_T0); + _mm_prefetch(b_use, _MM_HINT_T0); + + //Check for MR whether to process main kernel or mfringe kernel + if (mr0 == MR) + { + //Dot product kernel + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + b_use += 16; // move b pointer to next 16 elements + + //Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, a_use, rs_a) + a_use += (4 * rs_a); + + // Load 4x16 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS(zmm24, zmm25, zmm26, zmm27, a_use, rs_a) + a_use += (4 * rs_a); + + LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) + + // Load 4x16 elements from row8-row11 of A + LPGEMV_N_KERNEL_4_LOADS(zmm28, zmm29, zmm30, zmm31, a_use, rs_a) + a_use += (4 * rs_a); + + // Load 4x16 elements from row12-row15 of A + LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, a_use, rs_a) + a_use -= (12 * rs_a); //Update aptr back to move horizontally + + LPGEMV_N_KERNEL_4_FMA(zmm12, zmm13, zmm14, zmm15, zmm6, zmm24, zmm25, zmm26, zmm27) + LPGEMV_N_KERNEL_4_FMA(zmm16, zmm17, 
zmm18, zmm19, zmm6, zmm28, zmm29, zmm30, zmm31) + LPGEMV_N_KERNEL_4_FMA(zmm20, zmm21, zmm22, zmm23, zmm6, zmm0, zmm1, zmm2, zmm3) + a_use += 16; + }// kloop + + if(k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); // Load 0-15 in b[k+0 - k+15] + + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, a_use, rs_a) + a_use += (4 * rs_a); + + LPGEMV_N_KERNEL_4_MASKLOADS(zmm24, zmm25, zmm26, zmm27, zmm7, k1, a_use, rs_a) + a_use += (4 * rs_a); + + LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) + + LPGEMV_N_KERNEL_4_MASKLOADS(zmm28, zmm29, zmm30, zmm31, zmm7, k1, a_use, rs_a) + a_use += (4 * rs_a); + + LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, a_use, rs_a) + + LPGEMV_N_KERNEL_4_FMA(zmm12, zmm13, zmm14, zmm15, zmm6, zmm24, zmm25, zmm26, zmm27) + LPGEMV_N_KERNEL_4_FMA(zmm16, zmm17, zmm18, zmm19, zmm6, zmm28, zmm29, zmm30, zmm31) + LPGEMV_N_KERNEL_4_FMA(zmm20, zmm21, zmm22, zmm23, zmm6, zmm0, zmm1, zmm2, zmm3) + }// kloop + + //Add the registers horizantally to get one + LPGEMV_ZMM2XMM(zmm8, zmm9, zmm10, zmm11, ymm0, ymm1, ymm2, ymm3, xmm0) + LPGEMV_ZMM2XMM(zmm12, zmm13, zmm14, zmm15, ymm4, ymm1, ymm2, ymm3, xmm1) + LPGEMV_ZMM2XMM(zmm16, zmm17, zmm18, zmm19, ymm5, ymm1, ymm2, ymm3, xmm2) + LPGEMV_ZMM2XMM(zmm20, zmm21, zmm22, zmm23, ymm6, ymm1, ymm2, ymm3, xmm3) + + //compose outputs into one zmm to perform post-ops + zmm8 = _mm512_insertf32x4(zmm8, xmm0, 0); + zmm8 = _mm512_insertf32x4(zmm8, xmm1, 1); + zmm8 = _mm512_insertf32x4(zmm8, xmm2, 2); + zmm8 = _mm512_insertf32x4(zmm8, xmm3, 3); + }else + { + //Handle fringe cases when mr0 < MR + const float *a_use_fringe = a_use; + dim_t mr0_use = mr0; + dim_t regidx = 0; + + // Dot product for mfringe 8 + if (mr0_use >= 8) + { + // Dot product kernel for mr0 == 8 + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + b_use += 16; // move b pointer to next 16 elements + + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, a_use, rs_a) + a_use += (4 * rs_a); + + // Load 4x16 elements from row3-row7 of A + LPGEMV_N_KERNEL_4_LOADS(zmm24, zmm25, zmm26, zmm27, a_use, rs_a) + a_use -= (4 * rs_a); + + //Perform FMA on two 4x16 block of A with 16x1 + LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) + LPGEMV_N_KERNEL_4_FMA(zmm12, zmm13, zmm14, zmm15, zmm6, zmm24, zmm25, zmm26, zmm27) + a_use += 16; + } + + if (k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); // Load 0-15 in b[k+0 - k+15] + + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, a_use, rs_a) + a_use += (4 * rs_a); + LPGEMV_N_KERNEL_4_MASKLOADS(zmm24, zmm25, zmm26, zmm27, zmm7, k1, a_use, rs_a) + LPGEMV_N_KERNEL_4_FMA(zmm8, zmm9, zmm10, zmm11, zmm6, zmm0, zmm1, zmm2, zmm3) + LPGEMV_N_KERNEL_4_FMA(zmm12, zmm13, zmm14, zmm15, zmm6, zmm24, zmm25, zmm26, zmm27) + } + + //update pointers + mr0_use -= 8; + a_use = a_use_fringe + 8 * rs_a; + a_use_fringe = a_use; + b_use = b; + + //Horizontal add 8 zmm registers and get output into 2 xmm registers + LPGEMV_ZMM2XMM(zmm8, zmm9, zmm10, zmm11, ymm0, ymm1, ymm2, ymm3, xmm0) + LPGEMV_ZMM2XMM(zmm12, zmm13, zmm14, zmm15, ymm4, ymm1, ymm2, ymm3, xmm1) + + //insert xmm outputs into final output zmm8 reg + zmm8 = _mm512_insertf32x4(zmm8, xmm0, 0); + zmm8 = _mm512_insertf32x4(zmm8, xmm1, 1); + regidx = 2; + } + + // Dot product for mfringe 4 + if (mr0_use >= 4) + { + // Dot 
product kernel for mr0 == 8 + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + b_use += 16; // move b pointer to next 16 elements + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_LOADS(zmm0, zmm1, zmm2, zmm3, a_use, rs_a) + // Perform FMA on 4x16 block of A with 16x1 + LPGEMV_N_KERNEL_4_FMA(zmm16, zmm17, zmm18, zmm19, zmm6, zmm0, zmm1, zmm2, zmm3) + a_use += 16; + } + + if (k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); // Load 0-15 in b[k+0 - k+15] + // Load 4x16 elements from row0-row3 of A + LPGEMV_N_KERNEL_4_MASKLOADS(zmm0, zmm1, zmm2, zmm3, zmm7, k1, a_use, rs_a) + LPGEMV_N_KERNEL_4_FMA(zmm16, zmm17, zmm18, zmm19, zmm6, zmm0, zmm1, zmm2, zmm3) + } + + //update pointers + mr0_use -= 4; + a_use = a_use_fringe + 4 * rs_a; + a_use_fringe = a_use; + b_use = b; + + //Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM(zmm16, zmm17, zmm18, zmm19, ymm5, ymm1, ymm2, ymm3, xmm2) + + //insert xmm outputs into final output zmm8 reg based on regidx + if(regidx == 0) zmm8 = _mm512_insertf32x4(zmm8, xmm2, 0); + else zmm8 = _mm512_insertf32x4(zmm8, xmm2, 2); + regidx++; + } + + // Dot product for <= 3 + if (mr0_use) + { + // Dot product for m = 2 + if (mr0_use >= 2) + { + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + // Load 2x16 elements from row0-row1 of A + zmm0 = _mm512_loadu_ps(a_use); + zmm1 = _mm512_loadu_ps(a_use + rs_a); + zmm20 = _mm512_fmadd_ps(zmm0, zmm6, zmm20); + zmm21 = _mm512_fmadd_ps(zmm1, zmm6, zmm21); + b_use += 16; // move b pointer to next 16 elements + a_use += 16; + } + if (k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); // Load 0-15 in b[k+0 - k+15] + zmm0 = _mm512_mask_loadu_ps(zmm7, k1, a_use); // Load 0-15 in b[k+0 - k+15] + zmm1 = _mm512_mask_loadu_ps(zmm7, k1, a_use + rs_a); // Load 0-15 in b[k+0 - k+15] + zmm20 = _mm512_fmadd_ps(zmm0, zmm6, zmm20); + zmm21 = _mm512_fmadd_ps(zmm1, zmm6, zmm21); + } + mr0_use -= 2; + a_use = a_use_fringe + 2 * rs_a; + a_use_fringe = a_use; + b_use = b; + } + + // Dot product for m = 2 + if (mr0_use == 1) + { + for (dim_t k = 0; k < k_iter; k++) + { + zmm6 = _mm512_loadu_ps(b_use); // Load 0-15 in b[k+0 - k+15] + zmm0 = _mm512_loadu_ps(a_use); + zmm22 = _mm512_fmadd_ps(zmm0, zmm6, zmm22); + b_use += 16; // move b pointer to next 16 elements + a_use += 16; + } + + if (k_rem) + { + zmm6 = _mm512_mask_loadu_ps(zmm7, k1, b_use); + zmm0 = _mm512_mask_loadu_ps(zmm7, k1, a_use); + zmm22 = _mm512_fmadd_ps(zmm22, zmm6, zmm0); + } + // When only fringe 1, update the registers to store in order + if (!(mr0 & 0x2)) zmm20 = zmm22; + } + + // Horizontal add 4 zmm reg and get the output into one xmm + LPGEMV_ZMM2XMM(zmm20, zmm21, zmm22, zmm23, ymm6, ymm1, ymm2, ymm3, xmm3) + + // insert xmm outputs into final output zmm8 reg based on regidx + if (regidx == 0) zmm8 = _mm512_insertf32x4(zmm8, xmm3, 0); + else if(regidx == 1) zmm8 = _mm512_insertf32x4(zmm8, xmm3, 1); + else if (regidx == 2) zmm8 = _mm512_insertf32x4(zmm8, xmm3, 2); + else zmm8 = _mm512_insertf32x4(zmm8, xmm3, 3); + } + } + + //Scale accumulated output with alpha + zmm0 = _mm512_set1_ps(alpha); + zmm8 = _mm512_mul_ps(zmm0, zmm8); + + if (beta != 0) + { + const float *_cbuf = c_use; + + //C = beta*C + alpha*A*B + zmm3 = _mm512_set1_ps(beta); + if (rs_c == 1) + { + zmm0 = _mm512_maskz_loadu_ps(k2, _cbuf); + }else + { + //load C into zmm0 + float ctemp[16]; + for(dim_t i = 0; i < mr0; i++) + { + ctemp[i] = _cbuf[i * rs_c]; + 
} + zmm0 = _mm512_maskz_loadu_ps(k2, ctemp); + } + zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8); + } + + // Post Ops + post_ops_attr.is_last_k = TRUE; + lpgemm_post_op *post_ops_list_temp = post_op_list; + POST_OP_LABEL_LASTK_SAFE_JUMP + + POST_OPS_BIAS_6x64F: + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + zmm9 = _mm512_set1_ps(*((float *)post_ops_list_temp->op_args1)); + zmm8 = _mm512_add_ps(zmm9, zmm8); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_6x64F: + { + zmm1 = _mm512_setzero_ps(); + + // c[0,0-15] + zmm8 = _mm512_max_ps(zmm1, zmm8); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_RELU_SCALE_6x64F: + { + zmm1 = _mm512_setzero_ps(); + zmm2 = + _mm512_set1_ps(*((float *)post_ops_list_temp->op_args2)); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32S_AVX512(zmm8) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_TANH_6x64F: + { + __m512i zmm6; + // c[0, 0-15] + GELU_TANH_F32S_AVX512(zmm8, zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_GELU_ERF_6x64F: + { + // c[0, 0-15] + GELU_ERF_F32S_AVX512(zmm8, zmm0, zmm1, zmm2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_CLIP_6x64F: + { + zmm0 = _mm512_set1_ps(*(float *)post_ops_list_temp->op_args2); + zmm1 = _mm512_set1_ps(*(float *)post_ops_list_temp->op_args3); + + // c[0, 0-15] + CLIP_F32S_AVX512(zmm8, zmm0, zmm1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_MATRIX_ADD_6x64F: + { + float *matptr = (float *)post_ops_list_temp->op_args1; + zmm0 = _mm512_maskz_loadu_ps(k2, (matptr + post_ops_attr.post_op_c_i)); + zmm8 = _mm512_add_ps(zmm8, zmm0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_6x64F_DISABLE: + { + if (rs_c == 1) + { + _mm512_mask_storeu_ps(c_use, k2, zmm8); + } + else + { + // Store ZMM8 into ctemp buffer and store back + // element by element into output buffer at strides + float ctemp[16]; + _mm512_mask_storeu_ps(ctemp, k2, zmm8); + for (dim_t i = 0; i < mr0; i++) + { + c_use[i * rs_c] = ctemp[i]; + } + } + post_ops_attr.post_op_c_i += MR; + } + } // mr loop +} + +#endif // BLIS_ADDON_LPGEMM \ No newline at end of file
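When C has a non-unit row stride here (rs_c != 1, e.g. a column-major output seen by this row-vector kernel), the code above stages values through a small ctemp buffer so the masked AVX512 load/store can still be used. The pair of helpers below is a stand-alone sketch of that staging pattern (names are illustrative):

#include <immintrin.h>

// Gather up to 16 strided floats into a contiguous temp, then masked-load.
static inline __m512 load_c_strided_f32
     ( const float* c, long long rs_c, int mr0, __mmask16 kmask )
{
    if ( rs_c == 1 ) return _mm512_maskz_loadu_ps( kmask, c );

    float ctemp[16];
    for ( int i = 0; i < mr0; ++i ) ctemp[ i ] = c[ i * rs_c ];
    return _mm512_maskz_loadu_ps( kmask, ctemp );
}

// Masked-store to a temp, then scatter element by element to the strided C.
static inline void store_c_strided_f32
     ( float* c, long long rs_c, int mr0, __mmask16 kmask, __m512 v )
{
    if ( rs_c == 1 ) { _mm512_mask_storeu_ps( c, kmask, v ); return; }

    float ctemp[16];
    _mm512_mask_storeu_ps( ctemp, kmask, v );
    for ( int i = 0; i < mr0; ++i ) c[ i * rs_c ] = ctemp[ i ];
}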