From 24f228df3bf7e9f4cf1d25ec4a2f6bbee62e0480 Mon Sep 17 00:00:00 2001
From: Illia Silin <98187287+illsilin@users.noreply.github.com>
Date: Mon, 28 Jul 2025 11:34:07 -0700
Subject: [PATCH] upgrade from clang-format-12 to clang-format-18 (#2568)

* upgrade to clang-format-18

* update to clang-format-18 in pre-commit-config

[ROCm/composable_kernel commit: 504b101da33bd1ae2b39e13342c961eb0ddb4458]
---
 .pre-commit-config.yaml | 2 +- Dockerfile | 1 + Jenkinsfile | 4 +- .../grouped_conv2d_fwd_ngchw.cpp | 6 +- .../grouped_conv2d_bwd_data.cpp | 6 +- .../grouped_conv2d_bwd_data_ngchw.cpp | 6 +- .../grouped_conv3d_bwd_data.cpp | 6 +- ..._conv3d_bwd_data_input_fp16_comp_bf8f8.cpp | 6 +- .../elementwise_layernorm2d.cpp | 2 +- client_example/15_reduce/reduce_nhwc_c.cpp | 18 +- ...d_conv_bwd_data_bilinear_residual_fp16.cpp | 6 +- .../grouped_conv_bwd_data_scale_fp16.cpp | 6 +- ...rouped_conv_fwd_bilinear_residual_fp16.cpp | 6 +- .../common.hpp | 32 +-- .../grouped_conv_fwd_scale_fp16.cpp | 6 +- .../grouped_conv_fwd_scaleadd_ab.inc | 4 +- client_example/25_wrapper/wrapper_img2col.cpp | 6 +- codegen/include/ck/host/stringutils.hpp | 5 +- ...wd_multiple_abd_operation_xdl_cshuffle.cpp | 11 +- codegen/test/batched_gemm_softmax_gemm.cpp | 12 +- codegen/test/gemm_multiple_d.cpp | 10 +- codegen/test/rtc/include/rtc/tmp_dir.hpp | 2 +- .../Composable-Kernel-prerequisites.rst | 2 +- example/01_gemm/gemm_xdl_fp64.cpp | 11 +- example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp | 6 +- example/12_reduce/reduce_blockwise_impl.hpp | 2 +- .../gemm_reduce_xdl_common.hpp | 6 +- .../batched_gemm_reduce_xdl_fp16.cpp | 6 +- .../run_layernorm_example.inc | 4 +- ...rouped_gemm_scale_softmax_gemm_permute.inc | 8 +- .../sparse_embedding3_forward_layernorm.cpp | 8 +- example/39_permute/common.hpp | 13 +- .../run_groupnorm_fwd_example.inc | 4 +- ...entwise_scale_permute_amax_2D_fp16_fp8.cpp | 6 +- .../contraction_multi_ABD_xdl_fp16.cpp | 2 +- .../contraction_multi_ABD_xdl_fp8.cpp | 4 +- .../convnd_fwd_convscale_reduce_common.hpp | 8 +- .../run_layernorm4d_fwd_example.inc | 4 +- .../moe_gemm1_xdl_pk_i4.cpp | 2 +- .../02_layernorm2d/layernorm2d_fwd.cpp | 3 +- .../matrix_core_swizzle_kernel.hpp | 14 +- .../10_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 3 +- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 3 +- .../add_rmsnorm2d_rdquant_fwd.cpp | 4 +- .../example_add_rmsnorm2d_rdquant_fwd.cpp | 3 +- .../12_smoothquant/example_smoothquant.cpp | 7 +- .../ck_tile/12_smoothquant/smoothquant.cpp | 5 +- .../ck_tile/13_moe_sorting/moe_sorting.cpp | 30 +-- .../13_moe_sorting/moe_sorting_api.cpp | 60 +++--- .../14_moe_smoothquant/moe_smoothquant.cpp | 6 +- .../15_fused_moe/instances/fused_moe_api.cpp | 38 ++-- .../instances/fused_moegemm_api_internal.hpp | 10 +- .../instances/fused_moesorting_api.cpp | 60 +++--- example/ck_tile/15_fused_moe/main.cpp | 3 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 7 +- .../17_grouped_gemm/grouped_gemm_tileloop.cpp | 7 +- .../run_grouped_gemm_example.inc | 6 +- .../19_gemm_multi_d/gemm_multi_d_fp16.cpp | 7 +- ...grouped_convolution_bwd_weight_example.inc | 3 +- .../38_block_scale_gemm/gemm_aquant_basic.cpp | 48 +++-- example/ck_tile/remod.py | 2 +- include/ck/host_utility/hip_check_error.hpp | 5 +- include/ck/library/utility/algorithm.hpp | 8 +- include/ck/library/utility/fill.hpp | 7 +- include/ck/library/utility/host_tensor.hpp | 4 +- .../ck/tensor_description/tensor_adaptor.hpp | 24 +-- .../tensor_description/tensor_descriptor.hpp | 12 +- .../tensor_space_filling_curve.hpp | 6 +-
...blockwise_gemm_mx_pipeline_xdlops_base.hpp | 2 +- .../blockwise_gemm_pipeline_wmmaops_base.hpp | 4 +- .../block/blockwise_gemm_pipeline_xdlops.hpp | 8 +- .../blockwise_gemm_pipeline_xdlops_base.hpp | 4 +- .../block/blockwise_gemm_smfmac_xdlops.hpp | 4 +- .../gpu/block/blockwise_gemm_xdlops.hpp | 12 +- .../blockwise_gemm_xdlops_skip_b_lds.hpp | 2 +- ...roup_tensor_slice_transfer_direct_load.hpp | 6 +- ...nsor_slice_transfer_gather_direct_load.hpp | 12 +- .../gpu/device/device_base.hpp | 12 +- .../gpu/device/device_grouped_gemm.hpp | 12 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 36 ++-- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 36 ++-- .../device_batched_gemm_e_permute_xdl.hpp | 28 +-- .../device_batched_gemm_gemm_xdl_cshuffle.hpp | 37 ++-- .../impl/device_batched_gemm_multi_d_xdl.hpp | 34 ++-- .../device_batched_gemm_multiple_d_dl.hpp | 32 +-- ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 56 +++--- ...atched_gemm_multiple_d_xdl_cshuffle_v3.hpp | 8 +- ...evice_batched_gemm_reduce_xdl_cshuffle.hpp | 38 ++-- ...emm_softmax_gemm_permute_wmma_cshuffle.hpp | 64 +++--- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 46 ++--- ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 40 ++-- .../device_batched_gemm_wmma_cshuffle_v3.hpp | 13 +- .../device/impl/device_batched_gemm_xdl.hpp | 4 +- ...evice_batched_gemm_xdl_fpAintB_b_scale.hpp | 8 +- .../impl/device_cgemm_4gemm_xdl_cshuffle.hpp | 4 +- .../impl/device_column_to_image_impl.hpp | 12 +- ..._contraction_multiple_abd_xdl_cshuffle.hpp | 32 +-- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 32 +-- .../device/impl/device_contraction_utils.hpp | 10 +- ...fle_bias_activation_add_nhwc_kyxc_nhwk.hpp | 5 +- ...shuffle_bias_activation_nhwc_kyxc_nhwk.hpp | 5 +- ...onv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 5 +- ...evice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 32 +-- .../device/impl/device_gemm_multiple_d_dl.hpp | 28 +-- ...gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 85 ++++---- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 40 ++-- .../device_gemm_multiple_d_xdl_cshuffle.hpp | 30 +-- .../device_gemm_xdl_waveletmodel_cshuffle.hpp | 25 ++- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 14 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 28 +-- .../device_grouped_conv_bwd_weight_dl.hpp | 22 +- ...e_grouped_conv_bwd_weight_explicit_xdl.hpp | 30 +-- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 29 ++- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 36 ++-- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 29 ++- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 36 ++-- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 32 +-- ...ice_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp | 22 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 66 +++--- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 88 ++++---- ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp | 44 ++-- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 18 +- ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 14 +- .../device_grouped_gemm_multiple_d_dl.hpp | 12 +- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 20 +- ...gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 20 +- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 18 +- .../device/impl/device_grouped_gemm_xdl.hpp | 12 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 20 +- ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 16 +- ...e_grouped_query_attention_forward_wmma.hpp | 28 +-- .../impl/device_moe_gemm_blockscale.hpp | 116 +++++------ .../impl/device_moe_mx_gemm_bpreshuffle.hpp | 112 +++++------ ...ice_multi_query_attention_forward_wmma.hpp | 28 +-- 
...tk_contraction_multiple_d_xdl_cshuffle.hpp | 36 ++-- .../gpu/device/masking_specialization.hpp | 2 +- .../element/unary_element_wise_operation.hpp | 2 +- ...iple_d_welford_first_half_xdl_cshuffle.hpp | 17 +- ...idwise_2d_reduction_threadwise_multi_d.hpp | 5 +- ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 34 ++-- .../gpu/grid/gridwise_elementwise_2d.hpp | 99 +++++---- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 28 +-- ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 48 ++--- .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 16 +- .../gpu/grid/gridwise_gemm_dpp.hpp | 21 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 24 +-- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 2 +- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 103 +++++----- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 24 +-- ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 70 ++++--- ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 24 +-- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 34 ++-- ...e_gemm_split_k_multiple_d_xdl_cshuffle.hpp | 34 ++-- ...emm_split_k_multiple_d_xdl_cshuffle_v2.hpp | 24 +-- .../gpu/grid/gridwise_gemm_wmma.hpp | 24 +-- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 19 +- ...gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp | 22 +- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 4 +- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 19 +- .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 33 +-- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 43 ++-- .../grid/gridwise_gemm_xdl_cshuffle_v2.hpp | 41 ++-- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 23 +-- ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 33 ++- .../gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp | 26 +-- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 49 ++--- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 51 ++--- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 35 ++-- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 61 +++--- ...fle_v3_multi_d_blockscale_b_preshuffle.hpp | 51 ++--- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 32 +-- ...se_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 28 +-- ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 38 ++-- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 54 ++--- .../gridwise_gemm_xdlops_skip_b_lds_v1.hpp | 24 +-- ...ise_gemm_xdlops_splitk_lds_direct_load.hpp | 27 +-- .../gpu/grid/gridwise_gemm_xdlops_streamk.hpp | 34 ++-- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 50 ++--- .../gpu/grid/gridwise_gemm_xdlops_v2r4.hpp | 22 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 27 +-- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 26 +-- .../gpu/grid/gridwise_gemm_xdlops_v3r2.hpp | 32 +-- .../gpu/grid/gridwise_gemm_xdlops_v3r3.hpp | 38 ++-- .../gpu/grid/gridwise_moe_gemm.hpp | 190 ++++++++---------- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 190 ++++++++---------- .../gpu/grid/gridwise_moe_mx_gemm.hpp | 113 +++++------ .../gpu/grid/gridwise_moe_mx_gemm_bns.hpp | 123 +++++------- .../grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 115 +++++------ .../gpu/grid/gridwise_permute.hpp | 2 +- .../gpu/grid/gridwise_tensor_rearrange.hpp | 16 +- .../gridwise_normalization_bwd_data.hpp | 2 +- .../threadwise_tensor_slice_transfer.hpp | 12 +- .../threadwise_tensor_slice_transfer_v3r1.hpp | 44 ++-- ...ise_tensor_slice_transfer_v3r1_dequant.hpp | 18 +- ...wise_tensor_slice_transfer_v3r1_gather.hpp | 12 +- .../threadwise_tensor_slice_transfer_v3r2.hpp | 12 +- .../threadwise_tensor_slice_transfer_v5r1.hpp | 12 +- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 12 +- include/ck/utility/amd_ck_fp8.hpp | 10 +- include/ck/utility/container_helper.hpp 
| 2 +- include/ck/utility/dynamic_buffer.hpp | 2 +- include/ck/utility/is_detected.hpp | 4 +- include/ck/utility/magic_division.hpp | 12 +- include/ck/utility/sequence.hpp | 4 +- include/ck/utility/type_convert.hpp | 14 +- include/ck/wrapper/tensor.hpp | 22 +- .../core/algorithm/coordinate_transform.hpp | 2 +- .../core/algorithm/space_filling_curve.hpp | 6 +- .../core/arch/amd_buffer_addressing.hpp | 28 +-- include/ck_tile/core/arch/arch.hpp | 2 +- .../core/container/container_helper.hpp | 2 +- include/ck_tile/core/container/sequence.hpp | 5 +- include/ck_tile/core/numeric/float8.hpp | 2 +- include/ck_tile/core/numeric/math.hpp | 66 +++--- .../core/tensor/load_tile_transpose.hpp | 10 +- include/ck_tile/core/tensor/sweep_tile.hpp | 2 +- .../ck_tile/core/tensor/tensor_adaptor.hpp | 32 ++- .../ck_tile/core/tensor/tile_distribution.hpp | 10 +- .../ck_tile/core/tensor/tile_elementwise.hpp | 5 +- .../core/tensor/tile_window_linear.hpp | 15 +- include/ck_tile/core/utility/debug.hpp | 6 +- include/ck_tile/core/utility/type_traits.hpp | 4 +- .../core/utility/unary_element_function.hpp | 6 +- include/ck_tile/host/concat.hpp | 19 +- include/ck_tile/host/fill.hpp | 25 ++- include/ck_tile/host/host_tensor.hpp | 2 +- include/ck_tile/host/joinable_thread.hpp | 2 +- .../host/reference/reference_moe_sorting.hpp | 2 +- .../ops/epilogue/cshuffle_epilogue.hpp | 8 +- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 2 +- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 12 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 2 +- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 4 +- .../fused_moe/kernel/fused_moegemm_kernel.hpp | 2 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 2 +- .../fused_moegemm_pipeline_flatmm_ex.hpp | 50 +++-- .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 17 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 8 +- ...peline_ag_bg_cr_comp_v4_default_policy.hpp | 20 +- ...peline_ag_bg_cr_comp_v5_default_policy.hpp | 12 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 4 +- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 12 +- .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 4 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 26 +-- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 2 +- ...wp_pipeline_agmem_bgmem_creg_v1_policy.hpp | 12 +- .../block_universal_gemm_as_aquant_bs_cr.hpp | 25 +-- .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 12 +- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 3 +- ...ped_convolution_backward_weight_kernel.hpp | 71 +++---- .../grouped_convolution_forward_kernel.hpp | 84 ++++---- .../utils/grouped_convolution_utils.hpp | 10 +- .../ck_tile/ops/reduce/block/block_reduce.hpp | 2 +- include/ck_tile/ref/naive_attention.hpp | 24 +-- include/ck_tile/remod.py | 16 +- .../cpu/reference_moe_gemm.hpp | 2 +- .../cpu/reference_moe_gemm1_blockscale.hpp | 2 +- .../gpu/reference_gemm.hpp | 20 +- .../device_column_to_image_instance.hpp | 36 ++-- .../device_image_to_column_instance.hpp | 36 ++-- ...p_gemm_xdl_universal_km_kn_mn_instance.hpp | 9 +- ...ce_grouped_conv_bwd_weight_dl_instance.hpp | 27 +-- ..._grouped_conv_bwd_weight_wmma_instance.hpp | 18 +- ...al_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp | 9 +- ...al_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp | 9 +- ...al_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp | 9 +- ...al_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp | 9 +- ...ersal_f16_f16_f16_gkm_gkn_gmn_instance.cpp | 9 +- ...ersal_f16_f16_f16_gkm_gnk_gmn_instance.cpp | 9 +- ...ersal_f16_f16_f16_gmk_gkn_gmn_instance.cpp | 9 +- ...ersal_f16_f16_f16_gmk_gnk_gmn_instance.cpp | 9 +- 
...dl_int8_int8_int8_gkm_gnk_gmn_instance.cpp | 9 +- ..._data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp | 9 +- ..._shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp | 9 +- ..._shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp | 9 +- ..._shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp | 9 +- ...l_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp | 9 +- ...l_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp | 9 +- ...l_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp | 9 +- ...l_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp | 18 +- .../km_kn_mn_default_pipeline_v1_instance.cpp | 9 +- .../km_kn_mn_default_pipeline_v2_instance.cpp | 7 +- ...kn_mn_default_pipeline_v2_opt_instance.cpp | 7 +- ...m_kn_mn_interwave_pipeline_v1_instance.cpp | 7 +- .../km_nk_mn_default_pipeline_v1_instance.cpp | 9 +- .../km_nk_mn_default_pipeline_v2_instance.cpp | 7 +- ...nk_mn_default_pipeline_v2_opt_instance.cpp | 7 +- ...m_nk_mn_interwave_pipeline_v1_instance.cpp | 7 +- .../mk_kn_mn_default_pipeline_v1_instance.cpp | 9 +- .../mk_kn_mn_default_pipeline_v2_instance.cpp | 7 +- ...kn_mn_default_pipeline_v2_opt_instance.cpp | 7 +- ...k_kn_mn_interwave_pipeline_v1_instance.cpp | 7 +- ...gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp | 9 +- ...gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp | 9 +- ...gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp | 9 +- ...gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp | 9 +- ...gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp | 9 +- ...gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp | 9 +- ...gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp | 9 +- ...gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp | 9 +- ...le_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp | 9 +- ...wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp | 9 +- ...wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp | 9 +- ...wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 9 +- ...wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp | 9 +- ...mm_wmma_universal_f16_f16_f16_km_kn_mn.hpp | 9 +- ...mm_wmma_universal_f16_f16_f16_km_nk_mn.hpp | 9 +- ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 9 +- ...mm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp | 9 +- ...emm_wmma_universal_f16_f8_f16_km_kn_mn.hpp | 9 +- ...emm_wmma_universal_f16_f8_f16_km_nk_mn.hpp | 9 +- ...emm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp | 9 +- ...emm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp | 9 +- ...emm_wmma_universal_f8_f16_f16_km_kn_mn.hpp | 9 +- ...emm_wmma_universal_f8_f16_f16_km_nk_mn.hpp | 9 +- ...emm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp | 9 +- ...emm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp | 9 +- ..._xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp | 9 +- ..._xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp | 9 +- ..._xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 9 +- ..._xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp | 9 +- ...mm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp | 9 +- ...emm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp | 9 +- ...emm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp | 9 +- ...gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp | 9 +- ...gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp | 9 +- ...gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp | 9 +- ...gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp | 9 +- ...gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp | 9 +- ...versal_streamk_bf16_bf16_bf16_km_kn_mn.hpp | 9 +- ...versal_streamk_bf16_bf16_bf16_km_nk_mn.hpp | 9 +- ...versal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp | 9 +- ...versal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp | 9 +- ...universal_streamk_f16_f16_f16_mk_kn_mn.hpp | 9 +- ...universal_streamk_f16_f16_f16_mk_nk_mn.hpp | 9 +- ..._universal_streamk_f16_f8_f16_mk_kn_mn.hpp | 7 +- ..._universal_streamk_f16_f8_f16_mk_nk_mn.hpp | 7 +- ..._universal_streamk_f8_f16_f16_mk_kn_mn.hpp | 7 +- 
..._universal_streamk_f8_f16_f16_mk_nk_mn.hpp | 7 +- ...le_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp | 9 +- library/src/utility/convolution_parameter.cpp | 5 +- .../profiler/profile_conv_bwd_data_impl.hpp | 6 +- .../profiler/profile_conv_fwd_impl.hpp | 6 +- .../profile_conv_tensor_rearrange_impl.hpp | 5 +- .../profile_grouped_conv_bwd_data_impl.hpp | 7 +- .../profile_grouped_conv_bwd_weight_impl.hpp | 19 +- ...ofile_grouped_conv_fwd_bias_clamp_impl.hpp | 10 +- .../profile_grouped_conv_fwd_impl.hpp | 6 +- ...ile_grouped_conv_fwd_outelementop_impl.hpp | 6 +- .../include/profiler/profile_softmax_impl.hpp | 23 +-- profiler/src/profile_contraction_bilinear.cpp | 3 +- profiler/src/profile_contraction_scale.cpp | 3 +- script/clang-format-overwrite.sh | 4 +- .../add_rmsnorm2d_rdquant_fwd.inc | 4 +- test/ck_tile/data_type/test_pk_int4.cpp | 8 +- .../elementwise/test_elementwise_1d.cpp | 18 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 7 +- .../test_run_gemm_aquant_example.inc | 48 +++-- .../test_gemm_pipeline_util.hpp | 7 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 29 ++- test/ck_tile/layernorm2d/layernorm2d_fwd.inc | 3 +- .../moe_smoothquant/moe_smoothquant.inc | 6 +- test/ck_tile/moe_sorting/moe_sorting_api.cpp | 60 +++--- test/ck_tile/moe_sorting/moe_sorting_fp32.cpp | 30 +-- .../matrix_core_swizzle_kernel.hpp | 14 +- test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc | 3 +- test/ck_tile/smoothquant/smoothquant.inc | 5 +- test/data_type/test_pk_i4.cpp | 8 +- test/mx_mfma_op/mx_mfma_op.cpp | 180 ++++++++--------- test/pool/test_max_pool2d_fwd.cpp | 4 +- .../reference_conv_fwd/reference_conv_fwd.cpp | 12 +- tile_engine/ops/gemm/benchmark_gemm.hpp | 6 +- tile_engine/ops/gemm/gemm_profiler.hpp | 8 +- 373 files changed, 3351 insertions(+), 3760 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e4e85651f6..664c5219e2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: hooks: - id: clang-format name: clang-format - entry: clang-format-12 -i --style=file + entry: clang-format-18 -i --style=file language: system types_or: [c++, inc] - id: copyright-year-checker diff --git a/Dockerfile b/Dockerfile index 0219f99238..6f5cd0115d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -62,6 +62,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libzstd-dev \ openssh-server \ clang-format-12 \ + clang-format-18 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ diff --git a/Jenkinsfile b/Jenkinsfile index 7a8452f25e..b34e366f1b 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -994,7 +994,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ @@ -1023,7 +1023,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", 
build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp index 480abf23d2..13f1a3acc1 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp @@ -107,14 +107,14 @@ int execute_conv_fwd() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp index ae5f1b6f6e..f31ffe302a 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp @@ -130,14 +130,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp index 2309d757f0..a9918f6ab3 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp @@ -105,14 +105,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp index 93709a7901..baa2b02bce 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -109,14 +109,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp index a62a1d911b..ac7eb3cf41 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -111,14 +111,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp 
b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index 69d7c8936c..37cafc190e 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -59,7 +59,7 @@ int main() SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size); std::array<const void*, 2> ab_input = {a_dev_buf.GetDeviceBuffer(), - b_dev_buf.GetDeviceBuffer()}; + b_dev_buf.GetDeviceBuffer()}; std::vector<ck::index_t> abStride = {Stride, 1}; std::array<std::vector<ck::index_t>, 2> abStrides = {abStride, abStride}; diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index e2b1fbcb54..12aa31dec3 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -68,15 +68,15 @@ int main(int argc, char* argv[]) SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements); using DeviceOp = ck::tensor_operation::device::DeviceReduce<InDataType, - AccDataType, - OutDataType, - Rank, - NumReduceDim, - ReduceAdd, - PassThrough, - UnaryDivide, - PropagateNan, - OutputIndex>; + AccDataType, + OutDataType, + Rank, + NumReduceDim, + ReduceAdd, + PassThrough, + UnaryDivide, + PropagateNan, + OutputIndex>; const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp index bb106e8d8e..e8e33a3de2 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp @@ -117,14 +117,14 @@ int execute_conv_bwd_data_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {in.GetDeviceBuffer()}, + {in.GetDeviceBuffer()}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {in_lengths}, - {in_strides}, + {in_lengths}, + {in_strides}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp index e53ecc6c99..d81b5fd03e 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp @@ -116,14 +116,14 @@ int execute_conv_bwd_data_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp index 32ab481319..2ec70b8b9b 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp @@ -121,14 +121,14 @@ int
execute_conv_fwd_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {out.GetDeviceBuffer()}, + {out.GetDeviceBuffer()}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {out_lengths}, - {out_strides}, + {out_lengths}, + {out_strides}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp index c78cacf266..98f41dc7fb 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp @@ -222,13 +222,13 @@ bool run_grouped_conv_fwd_convscale_reduce( ck::tensor_operation::element_wise::Scale{scale_wei}, {}}; auto conv_ok = ConvolutionScale<InDataType, - WeiDataType, - ConvOutDataType, - ConvElementOp, - InLayout, - WeiLayout, - OutLayout, - NumDimSpatial>(in, + WeiDataType, + ConvOutDataType, + ConvElementOp, + InLayout, + WeiLayout, + OutLayout, + NumDimSpatial>(in, wei, conv_out, elementwise_op, @@ -717,15 +717,15 @@ bool TensorFullReduction(SimpleDeviceMem& tensor, { std::cout << "\nReduction of spatial dimensions:" << std::endl; using DeviceOp = ck::tensor_operation::device::DeviceReduce<OutDataType, - OutDataType, - OutDataType, - NumDimSpatial, - NumDimSpatial, - ReduceOperation, - PassThrough, - AccElementwiseOperation, - true, // PropagateNan - false>; // OutputIndex + OutDataType, + OutDataType, + NumDimSpatial, + NumDimSpatial, + ReduceOperation, + PassThrough, + AccElementwiseOperation, + true, // PropagateNan + false>; // OutputIndex const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp index 11e69f5bb2..11f24b39c7 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp @@ -120,14 +120,14 @@ int execute_conv_fwd_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc index 3f6f7b0773..4cf3a4cf82 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc @@ -129,8 +129,8 @@ int execute_conv_fwd_scaleadd_ab() in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp index ceccc5eb8f..f7f893fda2 100644 --- a/client_example/25_wrapper/wrapper_img2col.cpp +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -132,9 +132,9 @@ void PerformImageToColumnPad0(const ck::index_t G, ck::wrapper::size<0>(tile_shape)); const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global), - decltype(output_tensor_global), - decltype(tile_shape), - decltype(thread_layout)>; + decltype(output_tensor_global), + decltype(tile_shape), + decltype(thread_layout)>; const float avg_time =
launch_and_time_kernel(StreamConfig{nullptr, true}, kernel, dim3(grid_size_x, grid_size_y, 1), diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp index 89c1884d2e..81b312ec95 100644 --- a/codegen/include/ck/host/stringutils.hpp +++ b/codegen/include/ck/host/stringutils.hpp @@ -91,8 +91,9 @@ inline auto Transform(const Range& r, F f) -> std::vector -inline auto Transform(const Range1& r1, const Range2& r2, F f) - -> std::vector +inline auto Transform(const Range1& r1, + const Range2& r2, + F f) -> std::vector { std::vector result; assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end())); diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp index 36c9a13b4c..a2f322c50f 100644 --- a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ -142,12 +142,11 @@ std::vector Operation_Conv_Fwd_Xdl_Cshuffle::Cr x.A = TensorDesc{prob.ADataType, prob.ALayout}; x.B = TensorDesc{prob.BDataType, prob.BLayout}; x.E = TensorDesc{prob.EDataType, prob.ELayout}; - x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { - return TensorDesc{dt, lo}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; + x.Ds = Transform( + prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { return TensorDesc{dt, lo}; }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; x.update_prologue(prologue); x.update_epilogue(epilogue); result.push_back(x); diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp index 13035df355..98e78fc148 100644 --- a/codegen/test/batched_gemm_softmax_gemm.cpp +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -55,12 +55,12 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}, - {"o", std::to_string(prob.O)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}, + {"o", std::to_string(prob.O)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index adc8e1ff02..dd908e8b58 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -60,11 +60,11 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}}); auto srcs = 
get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index 2f3b26cc43..f4983debd9 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -16,7 +16,7 @@ struct tmp_dir void execute(const std::string& cmd) const; - tmp_dir(tmp_dir const&) = delete; + tmp_dir(tmp_dir const&) = delete; tmp_dir& operator=(tmp_dir const&) = delete; ~tmp_dir(); diff --git a/docs/install/Composable-Kernel-prerequisites.rst b/docs/install/Composable-Kernel-prerequisites.rst index 10be849ea6..9dc082599a 100644 --- a/docs/install/Composable-Kernel-prerequisites.rst +++ b/docs/install/Composable-Kernel-prerequisites.rst @@ -29,4 +29,4 @@ The following prerequisites are required to build and install Composable Kernel: * zlib1g-dev * libzstd-dev * openssh-server -* clang-format-12 +* clang-format-18 diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 5afb3d1554..b55627f3ee 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -31,15 +31,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl #else < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; #endif - // clang-format on +// clang-format on - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; template std::ostream& show_2d_matrix(std::ostream& os, Tensor& matrix) diff --git a/example/12_reduce/reduce_blockwise_impl.hpp b/example/12_reduce/reduce_blockwise_impl.hpp index f1225d86e4..57a86a9dc4 100644 --- a/example/12_reduce/reduce_blockwise_impl.hpp +++ b/example/12_reduce/reduce_blockwise_impl.hpp @@ -117,7 +117,7 @@ int reduce_blockwise_impl(bool do_verification, using InOutDataTypeInDevice = typename std:: conditional::value, int8_t, InOutDataType>::type; #else - using InOutDataTypeInDevice = InOutDataType; + using InOutDataTypeInDevice = InOutDataType; #endif using DeviceReduceInstance = diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp index 1bea1bcf3e..3e3c586dba 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -175,15 +175,15 @@ auto run_gemm_reduce_max_xdl(ck::index_t M, auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), - {}, + {}, e_device_buf.GetDeviceBuffer(), - {r0_device_buf.GetDeviceBuffer()}, + {r0_device_buf.GetDeviceBuffer()}, M, N, K, StrideA, StrideB, - {}, + {}, StrideE, a_element_op, b_element_op, diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 62295c57eb..42bfea372e 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -207,7 +207,7 
@@ int main(int argc, char* argv[]) auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), nullptr, - {}, + {}, c_device_buf.GetDeviceBuffer(), p_reduces, M, @@ -216,9 +216,9 @@ int main(int argc, char* argv[]) StrideA, StrideB, StrideC, - {}, + {}, gemm_element_ops, - {}, + {}, reduce_in_element_ops, reduce_out_element_ops, BatchCount); diff --git a/example/27_layernorm2d_fwd/run_layernorm_example.inc b/example/27_layernorm2d_fwd/run_layernorm_example.inc index 23608a1eea..02b60fe548 100644 --- a/example/27_layernorm2d_fwd/run_layernorm_example.inc +++ b/example/27_layernorm2d_fwd/run_layernorm_example.inc @@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example() {0, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index cdfd86dff4..c693995140 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -126,10 +126,10 @@ int run(int argc, char* argv[]) if(i < 4) { - std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " - << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", " - << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", " - << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; + std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " << "b0_gs_ns_ks[" + << i << "]: " << b0_gs_ns_ks.mDesc << ", " << "b1_gs_os_ns[" << i + << "]: " << b1_gs_os_ns.mDesc << ", " << "c_gs_ms_os[" << i + << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; } switch(init_method) diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index d2337dcda5..26a03f289d 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -129,11 +129,11 @@ int main() auto argument_ptr = device_instance.MakeArgumentPointer( out_dev.GetDeviceBuffer(), {ck::type_convert(emb_a_dev.GetDeviceBuffer()), - ck::type_convert(emb_b_dev.GetDeviceBuffer()), - ck::type_convert(emb_c_dev.GetDeviceBuffer())}, + ck::type_convert(emb_b_dev.GetDeviceBuffer()), + ck::type_convert(emb_c_dev.GetDeviceBuffer())}, {ck::type_convert(index_a_dev.GetDeviceBuffer()), - ck::type_convert(index_b_dev.GetDeviceBuffer()), - ck::type_convert(index_c_dev.GetDeviceBuffer())}, + ck::type_convert(index_b_dev.GetDeviceBuffer()), + ck::type_convert(index_c_dev.GetDeviceBuffer())}, gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), current_dim, diff --git a/example/39_permute/common.hpp b/example/39_permute/common.hpp index 54f3a78809..b23128a536 100644 --- a/example/39_permute/common.hpp +++ b/example/39_permute/common.hpp @@ -249,8 +249,8 @@ inline auto to_array(Range& range) noexcept } template -inline auto is_valid_axes(const Axes& axes) - -> std::enable_if_t, bool> +inline auto +is_valid_axes(const 
Axes& axes) -> std::enable_if_t, bool> { using std::empty; if(empty(axes)) @@ -357,10 +357,11 @@ auto extend_axes(const Problem::Axes& axes) } template -auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t< - detail::is_bidirectional_range_v && detail::is_sized_range_v && - detail::is_bidirectional_range_v && detail::is_sized_range_v, - bool> +auto advance_indices(const Shape& shape, Indices& indices) + -> std::enable_if_t< + detail::is_bidirectional_range_v && detail::is_sized_range_v && + detail::is_bidirectional_range_v && detail::is_sized_range_v, + bool> { using std::size; if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices))) diff --git a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc index 853ff791a6..ab6f317bc6 100644 --- a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) {0, 0, 0, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 4}, // reduction dimension: [H, W, C] 1e-6, x_dev.GetDeviceBuffer(), diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 9431a8cde4..c40447e1f9 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -152,7 +152,7 @@ int main(int argc, char* argv[]) std::array inputs = {input_dev_buf.GetDeviceBuffer()}; std::array outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(), - output_scaled_casted_dev_buf.GetDeviceBuffer()}; + output_scaled_casted_dev_buf.GetDeviceBuffer()}; std::cout << "Input: " << input.mDesc << std::endl; std::cout << "Scale: " << scale << std::endl; @@ -164,8 +164,8 @@ int main(int argc, char* argv[]) auto launch_transpose_scale = [&]() { auto transposeScale = DeviceElementwisePermuteInstance{}; auto argument = transposeScale.MakeArgumentPointer(dims, - {in_strides}, - {out_strides, in_strides}, + {in_strides}, + {out_strides, in_strides}, inputs, outputs, ScalePassThrough{scale}); diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 8b88e2482d..e7c1d6f0be 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -213,7 +213,7 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b_device_buf.GetDeviceBuffer()}, std::array{d_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp index eaabccdf2a..ec1b2d6018 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp +++ 
b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -194,9 +194,9 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b0_device_buf.GetDeviceBuffer(), - b1_device_buf.GetDeviceBuffer()}, + b1_device_buf.GetDeviceBuffer()}, std::array{}, e_device_buf.GetDeviceBuffer(), std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp index 6940c20695..f521c51d67 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp @@ -265,10 +265,10 @@ bool run_grouped_conv_fwd(bool do_verification, auto device_ew_scale = DeviceElementwiseScale{}; auto scale_invoker = device_ew_scale.MakeInvoker(); auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths, - {e_g_n_k_wos_strides}, - {e_g_n_k_wos_strides}, - {conv_device_buf.GetDeviceBuffer()}, - {out_device_buf.GetDeviceBuffer()}, + {e_g_n_k_wos_strides}, + {e_g_n_k_wos_strides}, + {conv_device_buf.GetDeviceBuffer()}, + {out_device_buf.GetDeviceBuffer()}, scale_convert); if(!device_ew_scale.IsSupportedArgument(scale_argument)) diff --git a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc index 1a0b558e2c..f75c01ec61 100644 --- a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc +++ b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc @@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example() {0, W * C, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 3}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 9e80a2ca35..f78e6e48a5 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -357,7 +357,7 @@ int main(int argc, char* argv[]) int n1 = n % NLane; int k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); + tempk = k % (KLane * KPack); int k1 = tempk / KPack; int k2 = tempk % KPack; diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index b72485222e..bdd5f2da1b 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -191,8 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << std::flush; diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp 
index 28f4c452bc..688f4f3d50 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -333,12 +333,12 @@ struct matrix_core_swizzle_kernel return tmp_1; #else // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, - constexpr index_t kv = Alignment; - constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t waveflatten = kw * nw * kv; - const index_t kr = a_.k / (k1 * k2); - const index_t nr = a_.n / nw; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; auto tmp = make_naive_tensor_view_packed( p_dst, make_tuple(nr, kr, waveflatten), @@ -387,8 +387,8 @@ struct matrix_core_swizzle_kernel constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t waveflatten_tile = kw * nw * kv; - constexpr index_t nr_tile = NPerBlock / nw; - constexpr index_t kr_tile = KPerBlock / (kw * kv); + constexpr index_t nr_tile = NPerBlock / nw; + constexpr index_t kr_tile = KPerBlock / (kw * kv); return make_tile_window(dst_view, make_tuple(number{}, number{}, diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 13924f5fe9..e0a71452ea 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -183,8 +183,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride << ", s:" << USEModelSensitive << ", valid:" << (pass ? 
"y" : "n") << std::flush << std::endl; } diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 049a0cad41..751b868411 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -193,8 +193,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index 06c04b763e..1cd375d0f5 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -105,8 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser) b_buf.ToDevice(b_host.data()); gamma_buf.ToDevice(gamma_host.data()); - std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m + << ", n:" << n << ", stride:" << stride << std::flush; add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX}; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index c43d9c9a2e..449bc17e04 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -256,8 +256,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } diff --git a/example/ck_tile/12_smoothquant/example_smoothquant.cpp b/example/ck_tile/12_smoothquant/example_smoothquant.cpp index 20e1591516..5fcacacee8 100644 --- a/example/ck_tile/12_smoothquant/example_smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/example_smoothquant.cpp @@ -216,10 +216,9 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride - << ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush - << std::endl; + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n + << ", x_stride:" << x_stride << ", y_stride:" << y_stride + << ", valid:" << (pass ? 
"y" : "n") << std::flush << std::endl; } return pass; diff --git a/example/ck_tile/12_smoothquant/smoothquant.cpp b/example/ck_tile/12_smoothquant/smoothquant.cpp index f3ba587132..02ab1cd9b1 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/smoothquant.cpp @@ -93,9 +93,8 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); smscale_buf.ToDevice(smscale_host.data()); - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride - << std::flush; + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", y_stride:" << y_stride << std::flush; smoothquant_traits traits{data_type}; diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index 16fe0ef150..e9b4ea5cd3 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -228,20 +228,26 @@ bool test_moe_sorting(ck_tile::ArgParser args) moe_sorting_trait trait{ index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy}; - moe_sorting_args karg - { - topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), - local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr, - is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, - sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), - sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(), - moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, - workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size, - num_experts, topk, + moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), + weights_dev.GetDeviceBuffer(), + local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() + : nullptr, + is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, + sorted_ids_dev.GetDeviceBuffer(), + sorted_weights_dev.GetDeviceBuffer(), + sorted_expert_ids_dev.GetDeviceBuffer(), + sorted_id_cnt_dev.GetDeviceBuffer(), + moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, + workspace_size != 0 ? 
moe_sorting_ws.GetDeviceBuffer() : nullptr, + tokens, + unit_size, + num_experts, + topk, #if MOE_SORTING_FMOE_2D_BUF - moe_buf_interm_dim, moe_buf_elem_bytes + moe_buf_interm_dim, + moe_buf_elem_bytes #else - static_cast(moe_buf_size * sizeof(float)) + static_cast(moe_buf_size * sizeof(float)) #endif }; diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 037891353e..a71c5e51a6 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -40,11 +40,11 @@ constexpr bool local_expert_masking = local_expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemEx<ms_index_type, \ - ms_weight_type, \ - sub_token_tile, \ - sub_token_onshot, \ - local_expert_masking, \ - local_token>; \ + ms_weight_type, \ + sub_token_tile, \ + sub_token_onshot, \ + local_expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingKernel; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -200,11 +200,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_type, \ - ms_weight_type, \ - mesh_type_, \ - unroll_num, \ - expert_masking, \ - local_token>; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -218,11 +218,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_type, \ - ms_weight_type, \ - mesh_type_, \ - unroll_num, \ - expert_masking, \ - local_token>; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -236,11 +236,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_type, \ - ms_weight_type, \ - mesh_type_, \ - unroll_num, \ - expert_masking, \ - local_token>; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -254,11 +254,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_type, \ - ms_weight_type, \ - mesh_type_, \ - unroll_num, \ - expert_masking, \ - local_token>; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -273,11 +273,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_type, \ - ms_weight_type, \ - mesh_type_, \ - unroll_num, \ - expert_masking, \ - local_token>; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp index dc5b397c85..848fb87dcf 100644 ---
a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp @@ -124,9 +124,9 @@ bool run(const ck_tile::ArgParser& arg_parser) smscale_buf.ToDevice(smscale_host.data()); topk_ids_buf.ToDevice(topk_ids_host.data()); - std::cout << "[" << prec_i << "-" << prec_o << "]" - << " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride - << ", experts:" << experts << ", topk:" << topk << std::flush; + std::cout << "[" << prec_i << "-" << prec_o << "]" << " tokens:" << tokens + << ", hidden_size:" << hidden_size << ", stride:" << stride << ", experts:" << experts + << ", topk:" << topk << std::flush; moe_smoothquant_traits traits{prec_i, prec_o}; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index 78f664a671..43ae5cf677 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -25,27 +25,27 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf }(); auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking}; - auto a0 = fused_moesorting_args - { - a.topk_ids_ptr, // const void* p_topk_ids; - a.topk_weight_ptr, // const void* p_weights; - a.local_expert_mask_ptr, // const void* p_local_expert_mask; - a.local_tokens, - a.sorted_token_ids_ptr, // void* p_sorted_token_ids; - a.sorted_weight_ptr, // void* p_sorted_weights; - a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; - a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; - a.o_ptr, // void* p_moe_buf; - a.ws_ptr, // void* p_ws; - a.num_tokens, // index_t tokens; - a.block_m, // index_t unit_size; - a.num_experts, // index_t num_experts; - a.topk, // index_t topk; + auto a0 = fused_moesorting_args{ + a.topk_ids_ptr, // const void* p_topk_ids; + a.topk_weight_ptr, // const void* p_weights; + a.local_expert_mask_ptr, // const void* p_local_expert_mask; + a.local_tokens, + a.sorted_token_ids_ptr, // void* p_sorted_token_ids; + a.sorted_weight_ptr, // void* p_sorted_weights; + a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; + a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; + a.o_ptr, // void* p_moe_buf; + a.ws_ptr, // void* p_ws; + a.num_tokens, // index_t tokens; + a.block_m, // index_t unit_size; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; #if MOE_SORTING_FMOE_2D_BUF - a.stride_token, o_data_bytes, + a.stride_token, + o_data_bytes, #else - static_cast(a.num_tokens) * - a.stride_token* o_data_bytes // index_t moe_buf_bytes; + static_cast(a.num_tokens) * a.stride_token * + o_data_bytes // index_t moe_buf_bytes; #endif }; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp index 343ddbed13..6e54df9fde 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp @@ -16,11 +16,11 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) { using f_traits = ck_tile::FusedMoeGemmTraits; using f_shape = ck_tile::FusedMoeGemmShape; + typename Ts_::WarpPerBlock_0, + typename Ts_::WarpTile_0, + typename Ts_::BlockTile_1, + typename Ts_::WarpPerBlock_0, + typename Ts_::WarpTile_0>; constexpr auto get_activation_ = []() { if constexpr(Ts_::Activation == 0) diff --git 
a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 83454a3969..5f87393a0a 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -40,11 +40,11 @@ constexpr bool local_expert_masking = local_expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemEx; \ + ms_weight_type, \ + sub_token_tile, \ + sub_token_onshot, \ + local_expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingKernel; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -204,11 +204,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -222,11 +222,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -240,11 +240,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -258,11 +258,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -277,11 +277,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index 35f24c1155..e4d87e5fef 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -218,8 +218,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return std::string(", st:") + std::to_string(stride); }(); - std::cout << "[" << api_str << "|" << prec_str << "]" - << " t:" << tokens; + std::cout << "[" << api_str << "|" 
<< prec_str << "]" << " t:" << tokens; if(is_local_token) { diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 85d75320c5..bb0a0d5840 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -173,10 +173,9 @@ float grouped_gemm(const std::vector& gemm_descs, if(s.log_level_ > 0) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } ave_time = diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp index 4107181520..897952f03c 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp @@ -138,10 +138,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, if(s.log_level_ > 0) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } ave_time = diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 7532923f9a..fa7f1a31c1 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -216,9 +216,9 @@ int run_grouped_gemm_example_with_layouts(int argc, c_m_n_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); - std::cout << "gemm[" << i << "]" - << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc - << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl; + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc + << std::endl; ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index 3debfa7f42..8971871c14 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -170,10 +170,9 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " + << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " + << blocks.y 
<< ", " << blocks.z << "}" << std::endl; } ave_time = ck_tile::launch_kernel( diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc index 9c32e2a11e..637ea2fbfb 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc @@ -161,8 +161,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts( conv_param.conv_filter_dilations_, conv_param.input_left_pads_, conv_param.input_right_pads_); - const ck_tile::index_t GemmK = - weight.get_element_size() / (conv_param.G_ * conv_param.K_); + const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_); const float max_accumulated_value = *std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end()); const auto rtol_atol = diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp index a1ed3c4920..2667cae788 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -87,24 +87,24 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s tail_number_v>; using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - transposed_warp_gemm, - ck_tile::memory_operation_enum::set>>; + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + transposed_warp_gemm, + ck_tile::memory_operation_enum::set>>; using Kernel = ck_tile::AQuantGemmKernel; @@ -195,14 +195,18 @@ int run_gemm_example(int argc, char* argv[]) } else if(data_type == "i4fp8") { - using TypeConfig = decltype( - GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); } else if(data_type == "i4bf8") { - using TypeConfig = decltype( - GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); } else if(data_type == "i4f32fp8") diff --git a/example/ck_tile/remod.py b/example/ck_tile/remod.py index fdc0dcf5d7..b64fac7b06 100644 --- a/example/ck_tile/remod.py +++ b/example/ck_tile/remod.py @@ -13,7 +13,7 @@ for p in sorted(Path("./").rglob("*")): # formatting for x in all_files: subprocess.Popen(f'dos2unix {str(x)}', shell=True) - cmd = f'clang-format-12 -style=file -i {str(x)}' + cmd = f'clang-format-18 -style=file -i {str(x)}' #for xp in x.parents: #print(get_file_base(x)) subprocess.Popen(cmd, shell=True) diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index 0dfd275269..e6e3402e64 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -12,9 +12,8 @@ inline 
 void hip_check_error(hipError_t x)
 {
     if(x != hipSuccess)
     {
         std::ostringstream ss;
-        ss << "HIP runtime error: " << hipGetErrorString(x) << ". "
-           << "hip_check_error.hpp"
-           << ": " << __LINE__ << "in function: " << __func__;
+        ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << "hip_check_error.hpp" << ": "
+           << __LINE__ << "in function: " << __func__;
         throw std::runtime_error(ss.str());
     }
 }
diff --git a/include/ck/library/utility/algorithm.hpp b/include/ck/library/utility/algorithm.hpp
index 57136f8a2a..185a147cce 100644
--- a/include/ck/library/utility/algorithm.hpp
+++ b/include/ck/library/utility/algorithm.hpp
@@ -11,10 +11,10 @@ namespace ck {
 namespace ranges {
 template
-auto copy(InputRange&& range, OutputIterator iter)
-    -> decltype(std::copy(std::begin(std::forward(range)),
-                          std::end(std::forward(range)),
-                          iter))
+auto copy(InputRange&& range,
+          OutputIterator iter) -> decltype(std::copy(std::begin(std::forward(range)),
+                                                     std::end(std::forward(range)),
+                                                     iter))
 {
     return std::copy(std::begin(std::forward(range)),
                      std::end(std::forward(range)),
diff --git a/include/ck/library/utility/fill.hpp b/include/ck/library/utility/fill.hpp
index 4f421b4282..05357b1637 100644
--- a/include/ck/library/utility/fill.hpp
+++ b/include/ck/library/utility/fill.hpp
@@ -138,9 +138,10 @@ struct FillConstant
     }
     template
-    auto operator()(ForwardRange&& range) const -> std::void_t<
-        decltype(std::declval()(std::begin(std::forward(range)),
-                                std::end(std::forward(range))))>
+    auto operator()(ForwardRange&& range) const
+        -> std::void_t()(
+            std::begin(std::forward(range)),
+            std::end(std::forward(range))))>
     {
         (*this)(std::begin(std::forward(range)),
                 std::end(std::forward(range)));
diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp
index 33c918c997..fb8f6e79dc 100644
--- a/include/ck/library/utility/host_tensor.hpp
+++ b/include/ck/library/utility/host_tensor.hpp
@@ -202,7 +202,7 @@ struct joinable_thread : std::thread
     {
     }
-    joinable_thread(joinable_thread&&) = default;
+    joinable_thread(joinable_thread&&) = default;
     joinable_thread& operator=(joinable_thread&&) = default;
     ~joinable_thread()
@@ -320,7 +320,7 @@ struct Tensor
     ~Tensor() = default;
     Tensor& operator=(const Tensor&) = default;
-    Tensor& operator=(Tensor&&) = default;
+    Tensor& operator=(Tensor&&) = default;
     template
     explicit Tensor(const Tensor& other) : Tensor(other.template CopyAsType())
diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp
index 3ffac32469..28974427d7 100644
--- a/include/ck/tensor_description/tensor_adaptor.hpp
+++ b/include/ck/tensor_description/tensor_adaptor.hpp
@@ -108,13 +108,13 @@ struct TensorAdaptor
     __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
     {
-        constexpr auto all_low_dim_ids = unpack(
-            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
-            LowerDimensionHiddenIdss{});
+        constexpr auto all_low_dim_ids =
+            unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
+                   LowerDimensionHiddenIdss{});
-        constexpr auto all_up_dim_ids = unpack(
-            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
-            UpperDimensionHiddenIdss{});
+        constexpr auto all_up_dim_ids =
+            unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
+                   UpperDimensionHiddenIdss{});
         constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
@@ -338,8 +338,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
         TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];
     // sequence in, sequence out
-    constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
-    {
+    constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
         auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
         // shift hidden id so every dim id is unique
@@ -361,8 +360,7 @@
         });
         return low_dim_hidden_ids_1_mod_;
-    }
-    ();
+    }();
     return generate_sequence_v2(
         [&](auto i) constexpr { return Number{}; },
@@ -384,8 +382,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
         TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];
     // sequence in, constexpr tuple out
-    constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
-    {
+    constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
         auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
         // shift hidden id
@@ -394,8 +391,7 @@
         });
         return up_dim_hidden_ids_1_mod_;
-    }
-    ();
+    }();
     // constexpr tuple to sequence
     return generate_sequence_v2(
diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp
index f1df2eedd4..a82f69fb3f 100644
--- a/include/ck/tensor_description/tensor_descriptor.hpp
+++ b/include/ck/tensor_description/tensor_descriptor.hpp
@@ -365,7 +365,7 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
         Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{}));
     constexpr auto up_dim_hidden_idss = generate_tuple(
-        [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
+        [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
             return typename arithmetic_sequence_gen{});
     // new visible dimension's hidden ids
-    constexpr auto unordered_new_visible_dim_hidden_ids = unpack(
-        [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
+    constexpr auto unordered_new_visible_dim_hidden_ids =
+        unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
-    constexpr auto new_visible_dim_unordered2ordered = unpack(
-        [](auto... xs) constexpr { return merge_sequences(xs...); },
-        NewUpperDimensionNewVisibleIdss{});
+    constexpr auto new_visible_dim_unordered2ordered =
+        unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
+               NewUpperDimensionNewVisibleIdss{});
     constexpr auto new_visible_dim_hidden_ids =
         unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp
index 9a326092d2..67da37cc90 100644
--- a/include/ck/tensor_description/tensor_space_filling_curve.hpp
+++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp
@@ -94,10 +94,8 @@ struct SpaceFillingCurve
     // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
    // idim-th element of multidimensional index.
     // All constexpr variables have to be captured by VALUE.
-    constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
-    {
-        constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
-        {
+    constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
+        constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
             auto res = idx_1d.value;
             auto id = 0;
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp
index c929956124..d0a594e2c6 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp
@@ -152,7 +152,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp
index d46c5b737d..6fb62bc677 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp
@@ -93,7 +93,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
     struct Empty
     {
-        __device__ Empty(){};
+        __device__ Empty() {};
         template
         __device__ void GlobalLoad(bool cond)
         {
@@ -119,7 +119,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
                    GridBuffer b_scale_grid_buf_)
             : b_scale_thread_copy(b_scale_thread_copy_),
               b_scale_grid_desc(b_scale_grid_desc_),
-              b_scale_grid_buf(b_scale_grid_buf_){};
+              b_scale_grid_buf(b_scale_grid_buf_) {};
         static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
         static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block;
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
index 438d7d8ac3..231dbf817c 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
@@ -96,9 +96,9 @@ template <
     index_t KPack,
     bool TransposeC = false,
     index_t AMmaKStride =
-        KPack* XdlopsGemm{}.K0PerXdlops,
+        KPack * XdlopsGemm{}.K0PerXdlops,
     index_t BMmaKStride =
-        KPack* XdlopsGemm{}.K0PerXdlops>
+        KPack * XdlopsGemm{}.K0PerXdlops>
 struct BlockwiseGemmXdlops_pipeline_v4
 {
     static constexpr auto I0 = Number<0>{};
@@ -188,7 +188,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
@@ -217,7 +217,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
index 9296b8136f..cd13dbb836 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
@@ -153,7 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_base
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
@@ -182,7 +182,7 @@ struct BlockwiseGemmXdlops_pipeline_base
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
index e9f9b0be7e..90f356987d 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
@@ -110,7 +110,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
         const auto waveId_m = wave_idx[I0];
@@ -138,7 +138,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
         const auto waveId_m = wave_idx[I0];
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
index d3f6344c27..e6bb2d8db3 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
@@ -114,7 +114,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
@@ -143,7 +143,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
@@ -667,9 +667,9 @@ template <
     index_t KPack,
     bool TransposeC = false,
     index_t AMmaKStride =
-        KPack* XdlopsGemm{}.K0PerXdlops,
+        KPack * XdlopsGemm{}.K0PerXdlops,
     index_t BMmaKStride =
-        KPack* XdlopsGemm{}.K0PerXdlops>
+        KPack * XdlopsGemm{}.K0PerXdlops>
 struct BlockwiseGemmXdlops_v2
 {
     static constexpr auto I0 = Number<0>{};
@@ -742,7 +742,7 @@ struct BlockwiseGemmXdlops_v2
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
@@ -771,7 +771,7 @@ struct BlockwiseGemmXdlops_v2
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
index 287c6701c3..84ee096cba 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
@@ -90,7 +90,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     template
     __device__ static auto
-    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
     {
         const auto wave_idx = GetWaveIdx();
diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp
index 98cc149f4d..aa06f8c6c1 100644
--- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp
+++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp
@@ -258,8 +258,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
         src_buf.template DirectCopyToLds, ScalarPerVector>(
             dst_buf, src_offset, dst_offset, is_src_valid);
-        constexpr auto move_on_dim = [&]() constexpr
-        {
+        constexpr auto move_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -271,8 +270,7 @@
             });
             return move_on_dim_;
-        }
-        ();
+        }();
         // Decide whether to move forward or backward.
         constexpr auto forward_sweep = [&]() {
diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp
index 3e9e501126..55dd924f8c 100644
--- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp
+++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp
@@ -281,8 +281,7 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
         src_buf.template DirectCopyToLds, ScalarPerVector>(
             dst_buf, src_offset, dst_offset, true);
-        constexpr auto move_src_on_dim = [&]() constexpr
-        {
+        constexpr auto move_src_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -295,11 +294,9 @@
             });
             return move_on_dim_;
-        }
-        ();
+        }();
-        constexpr auto move_dst_on_dim = [&]() constexpr
-        {
+        constexpr auto move_dst_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -311,8 +308,7 @@
             });
             return move_on_dim_;
-        }
-        ();
+        }();
         // Decide whether to move forward or backward.
         constexpr auto forward_sweep = [&]() {
diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp
index 9285211519..c946abb77d 100644
--- a/include/ck/tensor_operation/gpu/device/device_base.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_base.hpp
@@ -49,8 +49,8 @@ namespace device {
 #ifndef CK_CODE_GEN_RTC
 struct BaseArgument
 {
-    BaseArgument() = default;
-    BaseArgument(const BaseArgument&) = default;
+    BaseArgument() = default;
+    BaseArgument(const BaseArgument&) = default;
     BaseArgument& operator=(const BaseArgument&) = default;
     virtual ~BaseArgument() {}
@@ -60,8 +60,8 @@ struct BaseInvoker
-    BaseInvoker() = default;
-    BaseInvoker(const BaseInvoker&) = default;
+    BaseInvoker() = default;
+    BaseInvoker(const BaseInvoker&) = default;
     BaseInvoker& operator=(const BaseInvoker&) = default;
     virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
@@ -75,8 +75,8 @@ struct BaseOperator
-    BaseOperator() = default;
-    BaseOperator(const BaseOperator&) = default;
+    BaseOperator() = default;
+    BaseOperator(const BaseOperator&) = default;
     BaseOperator& operator=(const BaseOperator&) = default;
 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
     virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
index 267a970ee5..52632785bd 100644
--- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
@@ -70,15 +70,9 @@ struct GroupedGemmKernelArgument
         for(auto sd : StrideDs)
             str << sd << ",";
-        std::cout << "arg {"
-                  << "M:" << M << ", "
-                  << "N:" << N << ", "
-                  << "K:" << K << ", "
-                  << "SA:" << StrideA << ", "
-                  << "SB:" << StrideB << ", "
-                  << "SE:" << StrideE << ", "
-                  << "SDs: {" << str.str() << "}"
-                  << "}" << std::endl;
+        std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
+                  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SE:" << StrideE
+                  << ", " << "SDs: {" << str.str() << "}" << "}" << std::endl;
     }
 };
diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
index 72c011bfb2..1dd143f6a3 100644
--- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
@@ -205,25 +205,25 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
-        AsPointer p_as_grid,
-        BsPointer p_bs_grid,
-        DsPointer p_ds_grid,
-        EDataType* __restrict__ p_e_grid,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CDEElementwiseOperation cde_element_op,
-        const index_t batch_count,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
-        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            e_grid_desc_mblock_mperblock_nblock_nperblock_,
-        const Block2ETileMap block_2_ctile_map,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+    kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
+        AsPointer p_as_grid,
+        BsPointer p_bs_grid,
+        DsPointer p_ds_grid,
+        EDataType* __restrict__ p_e_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const index_t batch_count,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock_,
+        const Block2ETileMap block_2_ctile_map,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
     device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
index fc1a2b995a..c57d5316ba 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
@@ -36,25 +36,25 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_contraction_multiple_d_xdl_cshuffle(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
-        FloatDsPointer p_ds_grid,
-        FloatE* __restrict__ p_e_grid,
-        const index_t batch_count,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CDEElementwiseOperation cde_element_op,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            e_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-        const Block2ETileMap block_2_etile_map)
+    kernel_contraction_multiple_d_xdl_cshuffle(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatDsPointer p_ds_grid,
+        FloatE* __restrict__ p_e_grid,
+        const index_t batch_count,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+        const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
index 0cd1d84a43..c82da32313 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
@@ -58,21 +58,21 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
-                                      const ABDataType* __restrict__ p_b_grid,
-                                      EDataType* __restrict__ p_e_grid,
-                                      const index_t batch_count,
-                                      const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-                                      const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-                                      const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-                                          e_grid_desc_mblock_mperblock_nblock_nperblock,
-                                      const AElementwiseOperation a_element_op,
-                                      const BElementwiseOperation b_element_op,
-                                      const CDEElementwiseOperation cde_element_op,
-                                      const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-                                      const Block2ETileMap block_2_etile_map)
+    kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
+                                      const ABDataType* __restrict__ p_b_grid,
+                                      EDataType* __restrict__ p_e_grid,
+                                      const index_t batch_count,
+                                      const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+                                      const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+                                      const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                                          e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                      const AElementwiseOperation a_element_op,
+                                      const BElementwiseOperation b_element_op,
+                                      const CDEElementwiseOperation cde_element_op,
+                                      const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+                                      const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
index 985752796b..efe8fe92c7 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
@@ -39,26 +39,25 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_gemm_xdl_cshuffle_v1(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
-        const FloatAB* __restrict__ p_b1_grid,
-        FloatC* __restrict__ p_c_grid,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const AccElementwiseOperation acc_element_op,
-        const B1ElementwiseOperation b1_element_op,
-        const CElementwiseOperation c_element_op,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const Block2CTileMap block_2_ctile_map,
-        const index_t batch_count,
-        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+    kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
+                                     const FloatAB* __restrict__ p_b_grid,
+                                     const FloatAB* __restrict__ p_b1_grid,
+                                     FloatC* __restrict__ p_c_grid,
+                                     const AElementwiseOperation a_element_op,
+                                     const BElementwiseOperation b_element_op,
+                                     const AccElementwiseOperation acc_element_op,
+                                     const B1ElementwiseOperation b1_element_op,
+                                     const CElementwiseOperation c_element_op,
+                                     const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+                                     const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+                                     const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+                                     const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                                         c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                     const Block2CTileMap block_2_ctile_map,
+                                     const index_t batch_count,
+                                     const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
index 12085edaae..811924a189 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
@@ -63,24 +63,24 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
-                            const ABDataType* __restrict__ p_b_grid,
-                            DsPointer p_ds_grid,
-                            EDataType* __restrict__ p_e_grid,
-                            const index_t batch_count,
-                            const AElementwiseOperation a_element_op,
-                            const BElementwiseOperation b_element_op,
-                            const CDEElementwiseOperation cde_element_op,
-                            const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
-                            const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
-                            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-                                ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-                                e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-                            const Block2ETileMap block_2_etile_map)
+    kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
+                            const ABDataType* __restrict__ p_b_grid,
+                            DsPointer p_ds_grid,
+                            EDataType* __restrict__ p_e_grid,
+                            const index_t batch_count,
+                            const AElementwiseOperation a_element_op,
+                            const BElementwiseOperation b_element_op,
+                            const CDEElementwiseOperation cde_element_op,
+                            const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+                            const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
+                            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                                ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                                e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+                            const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
index 1b487502f4..a38e0d25e7 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
@@ -52,23 +52,23 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_dl_multiple_d(
-        const ABDataType* __restrict__ p_a_grid,
-        const ABDataType* __restrict__ p_b_grid,
-        DsPointer p_ds_grid,
-        EDataType* __restrict__ p_e_grid,
-        const index_t batch_count,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CDEElementwiseOperation cde_element_op,
-        const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
-        const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
-        const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
-        const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-        const Block2CTileMap block_2_ctile_map)
+    kernel_gemm_dl_multiple_d(
+        const ABDataType* __restrict__ p_a_grid,
+        const ABDataType* __restrict__ p_b_grid,
+        DsPointer p_ds_grid,
+        EDataType* __restrict__ p_e_grid,
+        const index_t batch_count,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CDEElementwiseOperation cde_element_op,
+        const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
+        const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
+        const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
+        const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
+        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
     defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
index d38698af4b..2ae4794d00 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
@@ -42,32 +42,32 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_gemm_xdl_cshuffle_v1(
-        const A0B0B1DataType* __restrict__ p_a0_grid,
-        const A0B0B1DataType* __restrict__ p_b0_grid,
-        D0sPointer p_d0s_grid,
-        const A0B0B1DataType* __restrict__ p_b1_grid,
-        D1sPointer p_d1s_grid,
-        E1DataType* __restrict__ p_e1_grid,
-        const A0ElementwiseOperation a0_element_op,
-        const B0ElementwiseOperation b0_element_op,
-        const CDE0ElementwiseOperation cde0_element_op,
-        const B1ElementwiseOperation b1_element_op,
-        const CDE1ElementwiseOperation cde1_element_op,
-        const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
-        const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
-        const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-        const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            d1s_grid_desc_mblock_mperblock_nblock_nperblock,
-        const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            e1_grid_desc_mblock_mperblock_nblock_nperblock,
-        const Block2E1TileMap block_2_e1tile_map,
-        const index_t batch_count,
-        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+    kernel_batched_gemm_gemm_xdl_cshuffle_v1(
+        const A0B0B1DataType* __restrict__ p_a0_grid,
+        const A0B0B1DataType* __restrict__ p_b0_grid,
+        D0sPointer p_d0s_grid,
+        const A0B0B1DataType* __restrict__ p_b1_grid,
+        D1sPointer p_d1s_grid,
+        E1DataType* __restrict__ p_e1_grid,
+        const A0ElementwiseOperation a0_element_op,
+        const B0ElementwiseOperation b0_element_op,
+        const CDE0ElementwiseOperation cde0_element_op,
+        const B1ElementwiseOperation b1_element_op,
+        const CDE1ElementwiseOperation cde1_element_op,
+        const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
+        const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
+        const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+        const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            d1s_grid_desc_mblock_mperblock_nblock_nperblock,
+        const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e1_grid_desc_mblock_mperblock_nblock_nperblock,
+        const Block2E1TileMap block_2_e1tile_map,
+        const index_t batch_count,
+        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -829,10 +829,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
              is_same_v &&
             CheckDLayout() &&
             (is_same_v ||
-             is_same_v)&&CheckDLayout() &&
+             is_same_v) &&
+            CheckDLayout() &&
             is_same_v))
         {
             return false;
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
index 6624570b27..2e0b5da113 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
@@ -33,9 +33,9 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
+    kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -79,9 +79,9 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
+    kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // Pass two lds pointer is the key to tell compiler that ds_read/write
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
index de7d67f08b..851f6a5f97 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
@@ -39,26 +39,26 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_reduce_xdl_cshuffle_v1(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
-        FloatC* __restrict__ p_c_grid,
-        ReducePtrsGlobal p_reduces_grid,
-        const index_t batch_count,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CElementwiseOperation c_element_op,
-        const ReduceInElementwiseOperations reduce_in_element_ops,
-        const ReduceAccElementwiseOperations reduce_out_element_ops,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
-        const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
-        const Block2CTileMap block_2_ctile_map)
+    kernel_batched_gemm_reduce_xdl_cshuffle_v1(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        ReducePtrsGlobal p_reduces_grid,
+        const index_t batch_count,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CElementwiseOperation c_element_op,
+        const ReduceInElementwiseOperations reduce_in_element_ops,
+        const ReduceAccElementwiseOperations reduce_out_element_ops,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
+        const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
index 1026118381..2e1684adb6 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
@@ -40,21 +40,21 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
-                                                   const B0DataType* __restrict__ p_b0_grid,
-                                                   const B1DataType* __restrict__ p_b1_grid,
-                                                   CDataType* __restrict__ p_c_grid,
-                                                   index_t M,
-                                                   index_t N,
-                                                   index_t K,
-                                                   index_t O,
-                                                   index_t G0,
-                                                   index_t G1,
-                                                   float alpha,
-                                                   bool input_permute,
-                                                   bool output_permute)
+    kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
+                                                   const B0DataType* __restrict__ p_b0_grid,
+                                                   const B1DataType* __restrict__ p_b1_grid,
+                                                   CDataType* __restrict__ p_c_grid,
+                                                   index_t M,
+                                                   index_t N,
+                                                   index_t K,
+                                                   index_t O,
+                                                   index_t G0,
+                                                   index_t G1,
+                                                   float alpha,
+                                                   bool input_permute,
+                                                   bool output_permute)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
@@ -178,15 +178,15 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
-                                       ODataType* __restrict__ p_out_grid,
-                                       index_t batch_size,
-                                       index_t sequence_length,
-                                       index_t head_count,
-                                       index_t head_size,
-                                       float alpha)
+    kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
+                                       ODataType* __restrict__ p_out_grid,
+                                       index_t batch_size,
+                                       index_t sequence_length,
+                                       index_t head_count,
+                                       index_t head_size,
+                                       float alpha)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
@@ -310,17 +310,17 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
-                                        const KVDataType* __restrict__ p_kv_grid,
-                                        ODataType* __restrict__ p_out_grid,
-                                        index_t batch_size,
-                                        index_t q_sequence_length,
-                                        index_t kv_sequence_length,
-                                        index_t head_count,
-                                        index_t head_size,
-                                        float alpha)
+    kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
+                                        const KVDataType* __restrict__ p_kv_grid,
+                                        ODataType* __restrict__ p_out_grid,
+                                        index_t batch_size,
+                                        index_t q_sequence_length,
+                                        index_t kv_sequence_length,
+                                        index_t head_count,
+                                        index_t head_size,
+                                        float alpha)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
index bae5c6019d..18b9e6ce74 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
@@ -43,30 +43,30 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
-        const FloatAB* __restrict__ p_b1_grid,
-        FloatC* __restrict__ p_c_grid,
-        D0sPointer p_d0s_grid,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const C0DEElementwiseOperation c0de_element_op,
-        const B1ElementwiseOperation b1_element_op,
-        const C1DEElementwiseOperation c1de_element_op,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-        const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            c1_grid_desc_mblock_mperblock_nblock_nperblock,
-        const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-        const Block2CTileMap block_2_ctile_map,
-        const index_t batch_count,
-        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-        const C0MatrixMask c0_matrix_mask)
+    kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        const FloatAB* __restrict__ p_b1_grid,
+        FloatC* __restrict__ p_c_grid,
+        D0sPointer p_d0s_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const C0DEElementwiseOperation c0de_element_op,
+        const B1ElementwiseOperation b1_element_op,
+        const C1DEElementwiseOperation c1de_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+        const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c1_grid_desc_mblock_mperblock_nblock_nperblock,
+        const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        const Block2CTileMap block_2_ctile_map,
+        const index_t batch_count,
+        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+        const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
index e846b0630b..ec0fb7b98d 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
@@ -42,27 +42,27 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
-        const FloatAB* __restrict__ p_b1_grid,
-        FloatC* __restrict__ p_c_grid,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const AccElementwiseOperation acc_element_op,
-        const B1ElementwiseOperation b1_element_op,
-        const CElementwiseOperation c_element_op,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const Block2CTileMap block_2_ctile_map,
-        const index_t batch_count,
-        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-        const C0MatrixMask c0_matrix_mask)
+    kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        const FloatAB* __restrict__ p_b1_grid,
+        FloatC* __restrict__ p_c_grid,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const AccElementwiseOperation acc_element_op,
+        const B1ElementwiseOperation b1_element_op,
+        const CElementwiseOperation c_element_op,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
+        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        const Block2CTileMap block_2_ctile_map,
+        const index_t batch_count,
+        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+        const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp
index abd6574d8c..cecd312879 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp
@@ -29,14 +29,13 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_batched_gemm_wmma_cshuffle_v3(
-        typename GridwiseGemm::Argument
-            karg, // This works for now but it actually receives a
-                  // DeviceBatchedGemm_Wmma_CShuffleV3::Argument
-                  // argument through implicit conversion to base class!
-        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
+    kernel_batched_gemm_wmma_cshuffle_v3(
+        typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
+                                              // DeviceBatchedGemm_Wmma_CShuffleV3::Argument
+                                              // argument through implicit conversion to base class!
+        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
 #if defined(__gfx11__)
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
index 494524b6f0..16d5feccf2 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
@@ -48,9 +48,9 @@ namespace device {
 template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
+    kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     const index_t num_blocks_per_batch =
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp
index 7d9555dc82..1419f5ee7c 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp
@@ -33,9 +33,9 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
+    kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -71,9 +71,9 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
+    kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // Pass two lds pointer is the key to tell compiler that ds_read/write
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
index 8843e520a6..4934993693 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
@@ -610,8 +610,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
         if(!parg)
         {
             std::ostringstream err;
-            err << "Provided argument pointer is not of an Argument class!"
-                << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+            err << "Provided argument pointer is not of an Argument class!"
<< " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp index 9482812f75..dee3a51df7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -467,12 +467,12 @@ struct DeviceColumnToImageImpl float elapsed_time = 0.f; const auto kernel = kernel_tensor_rearrange, - GridwiseTensorRearrangeKernel>; + InputDataType, + OutputGridDesc, + OutputDataType, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch<>, + GridwiseTensorRearrangeKernel>; // Execute each set of independent filters for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index df5922a04f..b99032fb9f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -37,23 +37,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, - const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_contraction_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 77974f84ae..de8e524dc3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -35,23 +35,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) 
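
Beyond kernel signatures, clang-format-18 also repacks chained stream inserts: the adjacent string literal and " In " now stay on one line, with the break pushed to the << before __FILE__ or the ":" separator. The diagnostic idiom itself is unchanged by the patch; a self-contained sketch (the check_arg wrapper name is hypothetical):

    #include <sstream>
    #include <stdexcept>

    // Hypothetical wrapper around the pointer check used by the device ops above.
    inline void check_arg(const void* parg)
    {
        if(!parg)
        {
            std::ostringstream err;
            err << "Provided argument pointer is not of an Argument class!" << " In " << __FILE__
                << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
    }
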
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_xdl_cshuffle( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp index 1b0db73fdd..dc07f8b445 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -35,17 +35,15 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vector __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r3_for_conv3d( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const index_t num_batches, - const index_t a_batch_stride, - const index_t b_batch_stride, - const index_t c_batch_stride, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_v2r3_for_conv3d( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t num_batches, + const index_t a_batch_stride, + const index_t b_batch_stride, + const index_t c_batch_stride, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index b9467ac194..9e8c959f98 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -34,21 +34,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_dl_multiple_d( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, - const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \ defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp index 47fb630ea9..8f4c41b69c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -37,31 +37,30 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EMeanVarDataType* __restrict__ p_e_grid, - EMeanVarDataType* __restrict__ p_welford_mean_grid, - EMeanVarDataType* __restrict__ p_welford_var_grid, - int32_t* __restrict__ p_welford_count_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock - mean_var_grid_desc_mblock_mperblock_nblock, - const CountGridDescriptor_MBlock_MPerBlock_NBlock - 
count_grid_desc_mblock_mperblock_nblock, - const Block2ETileMap block_2_etile_map, - index_t NRaw) + kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EMeanVarDataType* __restrict__ p_e_grid, + EMeanVarDataType* __restrict__ p_welford_mean_grid, + EMeanVarDataType* __restrict__ p_welford_var_grid, + int32_t* __restrict__ p_welford_count_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock + mean_var_grid_desc_mblock_mperblock_nblock, + const CountGridDescriptor_MBlock_MPerBlock_NBlock count_grid_desc_mblock_mperblock_nblock, + const Block2ETileMap block_2_etile_map, + index_t NRaw) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()]; @@ -121,26 +120,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_welford_layernorm2d_second_half( - const EMeanVarDataType* __restrict__ p_e_grid, - const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, - const EMeanVarDataType* __restrict__ p_in_welford_var_grid, - const int32_t* __restrict__ p_in_welford_count_grid, - const GammaDataType* __restrict__ p_gamma_grid, - const BetaDataType* __restrict__ p_beta_grid, - HDataType* __restrict__ p_h_grid, - const EHGridDesc_M_N e_grid_desc_m_n, - const EHGridDesc_M_N h_grid_desc_m_n, - const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, - const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, - const GammaBetaGridDesc_N gamma_grid_desc_n, - const GammaBetaGridDesc_N beta_grid_desc_n, - index_t numMeanVarCountBlockTileIteration_N, - index_t NBlockClusterLength, - ComputeDataType epsilon, - HElementwiseOperation h_element_op) + kernel_welford_layernorm2d_second_half( + const EMeanVarDataType* __restrict__ p_e_grid, + const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, + const EMeanVarDataType* __restrict__ p_in_welford_var_grid, + const int32_t* __restrict__ p_in_welford_count_grid, + const GammaDataType* __restrict__ p_gamma_grid, + const BetaDataType* __restrict__ p_beta_grid, + HDataType* __restrict__ p_h_grid, + const EHGridDesc_M_N e_grid_desc_m_n, + const EHGridDesc_M_N h_grid_desc_m_n, + const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, + const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, + const GammaBetaGridDesc_N gamma_grid_desc_n, + const GammaBetaGridDesc_N beta_grid_desc_n, + index_t numMeanVarCountBlockTileIteration_N, + index_t NBlockClusterLength, + ComputeDataType epsilon, + HElementwiseOperation h_element_op) { GridwiseWelfordLayernorm::Run(p_e_grid, p_in_welford_mean_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index c048e7249c..c1b3f98bc9 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -38,27 +38,27 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_multiple_d_multiple_r_xdl_cshuffle( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - FloatRsPointer p_rs_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const QsElementwiseOperation qs_element_op, - const RsElementwiseOperation rs_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, - const Block2ETileMap block_2_etile_map) + kernel_gemm_multiple_d_multiple_r_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + FloatRsPointer p_rs_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const QsElementwiseOperation qs_element_op, + const RsElementwiseOperation rs_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index f193b093d1..e36816df64 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -37,22 +37,22 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const 
Block2ETileMap block_2_etile_map) + kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp index 2554ffea46..0f6457f48e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp @@ -32,20 +32,19 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_waveletmodel_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const EElementwiseOperation e_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const EElementwiseOperation e_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 884175eaca..f32334cd91 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -28,14 +28,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_contraction_multiple_d_xdl_cshuffle( - const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, - const index_t group_count, - const AElementwiseOperation 
a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_contraction_multiple_d_xdl_cshuffle( + const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index db2426518a..fe9e4ff7e8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -80,21 +80,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const std::array gemm_kernel_args, - const index_t gemms_count, - const AElementwiseOp a_element_op, - const BElementwiseOp b_element_op, - const CDEElementwiseOp cde_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n, - const index_t KBatch) + kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const std::array gemm_kernel_args, + const index_t gemms_count, + const AElementwiseOp a_element_op, + const BElementwiseOp b_element_op, + const CDEElementwiseOp cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 0b3f1a0255..3306e311b3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -35,18 +35,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_dlops_bwd_weight( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, - const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_dlops_bwd_weight( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + 
FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, + const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp index a819b91b05..e5872816f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -77,21 +77,21 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl using CElementwiseGridDesc = remove_cvref_t; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<1, ElemsPerBlock>; using GridwiseElementwiseCast = GridwiseElementwise, - Tuple, - Tuple, - Tuple, - Block2TileMapElementwise, - WeiElementwiseOperation, - ElementwiseBlockSize, - I1, - ElemsPerBlock, - I1, - ElemsPerBlock / ElementwiseBlockSize, - Sequence<0, 1>, - Sequence<1>, - Sequence<1>, - I1, - I1>; + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + WeiElementwiseOperation, + ElementwiseBlockSize, + I1, + ElemsPerBlock, + I1, + ElemsPerBlock / ElementwiseBlockSize, + Sequence<0, 1>, + Sequence<1>, + Sequence<1>, + I1, + I1>; struct Argument : public BaseArgument { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 672c7dd2f7..601bf4eb5a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -43,22 +43,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight( - const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, 
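
The batched kernels being reformatted here share one runtime pattern: the grid is split into equal slices of blocks, the batch index is derived from the block id via a readfirstlane, and the global pointers are shifted by a per-batch offset functor before the per-batch work runs. A simplified sketch under assumed names (the real kernels take CK's ComputePtrOffsetOfBatch-style functors):

    #include <hip/hip_runtime.h>
    #include <cstdint>

    using index_t = int32_t; // stand-in for ck::index_t

    // Illustrative offset functor; the CK functors expose Get*PtrOffset
    // accessors in the same spirit.
    struct PtrOffsetOfBatchSketch
    {
        long a_stride, b_stride, c_stride;
        __device__ long GetAPtrOffset(index_t g) const { return g * a_stride; }
        __device__ long GetBPtrOffset(index_t g) const { return g * b_stride; }
        __device__ long GetCPtrOffset(index_t g) const { return g * c_stride; }
    };

    template <typename T>
    __global__ void kernel_batched_sketch(const T* __restrict__ p_a,
                                          const T* __restrict__ p_b,
                                          T* __restrict__ p_c,
                                          const index_t batch_count,
                                          const PtrOffsetOfBatchSketch offsets)
    {
        // contiguous slice of blocks per batch, as in the kernels above
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(gridDim.x / batch_count);
        const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.x / num_blocks_per_batch);

        const T* p_a_batch = p_a + offsets.GetAPtrOffset(g_idx);
        const T* p_b_batch = p_b + offsets.GetBPtrOffset(g_idx);
        T* p_c_batch       = p_c + offsets.GetCPtrOffset(g_idx);
        // ... per-batch GEMM on (p_a_batch, p_b_batch, p_c_batch) goes here
        (void)p_a_batch;
        (void)p_b_batch;
        (void)p_c_batch;
    }
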
+ const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index c7c463f43d..8796f5520e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -44,16 +44,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); @@ -99,16 +99,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 6c53161ded..6f6a3587ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -41,22 +41,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight( - const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index f13a256d6b..bbaa04536c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -42,16 +42,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); @@ -100,16 +100,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const 
BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 3e14f66a09..e7446bb995 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -72,23 +72,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_dl_multiple_d( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, - const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_fwd_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 50e171e503..393ee80881 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -93,18 +93,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - 
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_dl( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - CDataType* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_fwd_dl( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ defined(__gfx11__) || defined(__gfx12__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 6d2988ba24..ac40d363b5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -81,25 +81,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfG compute_ptr_offset_of_groups, - const ComputePtrOffsetOfN compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfG compute_ptr_offset_of_groups, + const ComputePtrOffsetOfN compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) @@ -383,11 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = 
std::conditional_t< - is_NGCHW_NGKHW() && NeedTransposeKernel, - ctc::NHWGC, - std::conditional_t() && NeedTransposeKernel, - ctc::NDHWGC, - ALay>>; + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::NHWGC, + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGC, + ALay>>; const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -403,11 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW() && NeedTransposeKernel, - ctc::GKYXC, - std::conditional_t() && NeedTransposeKernel, - ctc::GKZYXC, - BLay>>; + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::GKYXC, + std::conditional_t() && NeedTransposeKernel, + ctc::GKZYXC, + BLay>>; const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -423,11 +423,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW() && NeedTransposeKernel, - ctc::NHWGK, - std::conditional_t() && NeedTransposeKernel, - ctc::NDHWGK, - ELay>>; + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::NHWGK, + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGK, + ELay>>; const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index e30caf3aac..a938820e6c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -72,15 +72,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const DsGridDesc_M_N ds_grid_desc_m_n, - const EGridDesc_M_N c_grid_desc_m_n, - const ComputePtrOffset compute_ptr_offset_of_groups, - const ComputePtrOffset compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N ds_grid_desc_m_n, + const EGridDesc_M_N c_grid_desc_m_n, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group @@ -151,16 +151,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const DsGridDesc_M_N ds_grid_desc_m_n, - const EGridDesc_M_N c_grid_desc_m_n, - const ComputePtrOffset compute_ptr_offset_of_groups, - const ComputePtrOffset compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( + typename 
GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N ds_grid_desc_m_n, + const EGridDesc_M_N c_grid_desc_m_n, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group @@ -369,11 +369,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKCYX_NGKHW(), - ctc::NHWGC, - std::conditional_t(), - ctc::NDHWGC, - ALay>>; + is_NGCHW_GKCYX_NGKHW(), + ctc::NHWGC, + std::conditional_t(), + ctc::NDHWGC, + ALay>>; const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -399,11 +399,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKCYX_NGKHW(), - ctc::GKYXC, - std::conditional_t(), - ctc::GKZYXC, - BLay>>; + is_NGCHW_GKCYX_NGKHW(), + ctc::GKYXC, + std::conditional_t(), + ctc::GKZYXC, + BLay>>; const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -429,11 +429,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKCYX_NGKHW(), - ctc::NHWGK, - std::conditional_t(), - ctc::NDHWGK, - ELay>>; + is_NGCHW_GKCYX_NGKHW(), + ctc::NHWGK, + std::conditional_t(), + ctc::NDHWGK, + ELay>>; const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -1347,9 +1347,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return false; if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "The MultiABD is not supported!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + std::cout << "The MultiABD is not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } } @@ -1374,8 +1373,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "Current device does not support xdl instructions!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + std::cout << "Current device does not support xdl instructions!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } return false; @@ -1455,9 +1454,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "Unsupported A Layout!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; } return false; } @@ -1488,9 +1486,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "Unsupported A Layout!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; } return false; } @@ -1602,9 +1599,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "Unsupported E Layout!" 
- << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + std::cout << "Unsupported E Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; } return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index ec1a05366e..1e5c67aac7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -131,29 +131,29 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batch_gemm_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - RsPointer p_rs_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const QsElementwiseOperation qs_element_op, - const RsElementwiseOperation rs_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batch_gemm_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + RsPointer p_rs_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const QsElementwiseOperation qs_element_op, + const RsElementwiseOperation rs_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 9988367959..b1494a36bf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -41,16 +41,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - 
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle( - Array gemm_desc_kernel_args, - const index_t gemms_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op, - const ComputePtrOffset compute_ptr_offset_of_groups, - const ComputePtrOffset compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle( + Array gemm_desc_kernel_args, + const index_t gemms_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 21afc06040..7cfc73fab6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -36,14 +36,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const index_t grid_size_grp, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const index_t grid_size_grp, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index 10d8a4a44d..d0d613af8f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -32,13 +32,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ 
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 18872e38ea..7b5dd55a8f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -576,16 +576,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(dev_gemm_args == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } if(dev_gemm_workspace == nullptr) { std::ostringstream err; - err << "The gemm workspace buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm workspace buffer is not allocated!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } @@ -624,16 +624,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(arg.p_dev_gemm_kargs_ == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } if(arg.p_workspace_ == nullptr) { std::ostringstream err; - err << "The gemm workspace buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm workspace buffer is not allocated!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } @@ -711,8 +711,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(not_all_have_kbatch_value_same) { std::ostringstream err; - err << "Not all gemms have same kbatch value (=1 or >1)! " - << "group [" << i << "], kbatch: " << gemm_arg.k_batch + err << "Not all gemms have same kbatch value (=1 or >1)! 
" << "group [" << i + << "], kbatch: " << gemm_arg.k_batch << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 61058dec2b..38bb19b712 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -60,13 +60,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) @@ -600,8 +600,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop if(dev_gemm_args == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } @@ -629,8 +629,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop if(arg.p_dev_gemm_args_ == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" 
<< " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 3fb2c5ae86..1754b542c5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -32,16 +32,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( - const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const AccElementwiseOperation acc_element_op, - const B1ElementwiseOperation b1_element_op, - const CElementwiseOperation c_element_op) + kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( + const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index cbee4e09f4..a528149ecd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -31,13 +31,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 8fe71fb9a2..81134465af 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -38,17 +38,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - uint32_t* barrier_count, - 
const index_t barrier_size_grp, - const index_t group_count, - const index_t grid_size_grp, - const index_t KBatch, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + uint32_t* barrier_count, + const index_t barrier_size_grp, + const index_t group_count, + const index_t grid_size_grp, + const index_t KBatch, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 01f52881f4..ea14087698 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -33,13 +33,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); @@ -416,8 +416,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK1)! " - << "group [" << i << "], kbatch: " << kbatch + err << "Not all gemms have same kbatch value (=1 or >1)! 
" << "group [" << i + << "], kbatch: " << kbatch << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp index 67a100a112..b66ab997bb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp @@ -45,21 +45,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, - const B0DataType* __restrict__ p_b0_grid, - const B1DataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid, - index_t M, // SequenceQ - index_t N, // SequenceK - index_t K, // HeadDim - index_t O, // SequenceK - index_t G0, // Batch - index_t G1, // HeadNum - float alpha, - bool input_permute, - bool output_permute) + kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 48a10f219c..efa85a357c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -100,64 +100,64 @@ struct DeviceMoeGemmBlockScale { static constexpr index_t NumDTensor = DsDataType::Size(); using GridwiseGemm = GridwiseMoeGemmBlockScale< - ALayout, - BLayout, - DsLayout, - CLayout, - ADataType, - BDataType, - GemmAccDataType, - CShuffleDataType, - DsDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - BlockSize, - ScaleBlockM, - ScaleBlockN, - ScaleBlockK, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ActivationOP, - NSwizzle, - IsInputGemm, - MulRoutedWeight, - IndexType, - ComputeTypeA, - ComputeTypeB, - LDSTypeA, - 
LDSTypeB>; + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; using Argument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp index 6dc3a5f881..4bf38d9d1f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp @@ -92,62 +92,62 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle; + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB>; using Argument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp index cc88c1a104..e196ed5e3a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -44,21 +44,21 @@ template __global__ void #if 
CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, - const B0DataType* __restrict__ p_b0_grid, - const B1DataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid, - index_t M, // SequenceQ - index_t N, // SequenceK - index_t K, // HeadDim - index_t O, // SequenceK - index_t G0, // Batch - index_t G1, // HeadNum - float alpha, - bool input_permute, - bool output_permute) + kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index 63b49d9aa0..c1d3aa43de 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -36,25 +36,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_xdl_cshuffle( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - const index_t batch_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1, - const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2ETileMap block_2_etile_map) + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp 
b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index 9fe2f0d976..cc500bb9cb 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -33,7 +33,7 @@ struct MaskDisabledPredicate }; __host__ __device__ constexpr bool - IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const + IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const { return false; } diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 8f829496da..4a87e8a277 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -266,7 +266,7 @@ struct DequantPack8 dst.template AsType()(Number<3>{}) = type_convert(src.template AsType()[Number<3>{}]); - y = dst.template AsType()[Number<0>{}]; + y = dst.template AsType()[Number<0>{}]; #endif } diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index 02dba97430..36dc8aa6ba 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -527,11 +527,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ABDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -997,9 +997,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) { const auto c_ds_src_data_refs = concat_tuple_of_reference( tie(e_thread_buf[i]), - generate_tie( - [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, - Number{})); + generate_tie([&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, + Number{})); auto e_dst_data_refs = tie(e_thread_buf(i)); unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs); }); @@ -1124,7 +1123,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle }); } // shuffle C + Ds + welford + write out - } // run + } // run }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp index e3c50ef06c..cc3306e1bd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp @@ -228,9 +228,8 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d static_for<0, MThreadSliceSize, 1>{}([&](auto I) { const auto c_ds_buf_refs = concat_tuple_of_reference( tie(accu_value_buf[I]), - generate_tie( - [&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; }, - Number{})); + generate_tie([&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; }, + Number{})); unpack2(out_elementwise_op, 
tie(out_value_buf(I)), c_ds_buf_refs); }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 53a45c7f16..e8f8caa10d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -372,11 +372,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle : false; constexpr auto is_scale_mfma = false; constexpr auto mfma = MfmaSelector::selected_mfma; + Gemm0MPerXdl, + Gemm0NPerXdl, + A0B0B1DataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma; constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto N5 = mfma.group_size; return transform_tensor_descriptor( @@ -669,11 +669,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_A0K1_B0K1, MfmaSelector::selected_mfma.k_per_blk); + Gemm0MPerXdl, + Gemm0NPerXdl, + A0B0B1DataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm0 = BlockwiseGemmXdlops_v2< BlockSize, @@ -1176,18 +1176,16 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle // tuple of reference to C/Ds tensor descriptors const auto c1_d1s_desc_refs = concat_tuple_of_reference( tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c1_d1s_buf_refs = concat_tuple_of_reference( tie(c1_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return d1s_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return d1s_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c1_d1s_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index 1326c5d62d..839a68a978 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -24,14 +24,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, - const OutGridDescTuple out_grid_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const Block2TileMap block_2_tile_map, - const ElementwiseOperation elementwise_op) + kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op) { GridwiseElementwiseFunctor::Run(in_grid_desc_tuple, 
out_grid_desc_tuple, @@ -56,20 +56,20 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, - const InBGridDescTuple in_grid_desc_tuple_b, - const OutAGridDescTuple out_grid_desc_tuple_a, - const OutBGridDescTuple out_grid_desc_tuple_b, - const InADataTypePointerTuple p_in_global_tuple_a, - const InBDataTypePointerTuple p_in_global_tuple_b, - const OutADataTypePointerTuple p_out_global_tuple_a, - const OutBDataTypePointerTuple p_out_global_tuple_b, - const Block2TileMapA block_2_tile_map_a, - const Block2TileMapB block_2_tile_map_b, - const ElementwiseOperation elementwise_op, - const index_t a_grid_size) + kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size) { if(get_block_1d_id() < a_grid_size) { @@ -112,27 +112,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise_batched_dual( - const InAGridDescTuple in_grid_desc_tuple_a, - const InBGridDescTuple in_grid_desc_tuple_b, - const OutAGridDescTuple out_grid_desc_tuple_a, - const OutBGridDescTuple out_grid_desc_tuple_b, - const InADataTypePointerTuple p_in_global_tuple_a, - const InBDataTypePointerTuple p_in_global_tuple_b, - const OutADataTypePointerTuple p_out_global_tuple_a, - const OutBDataTypePointerTuple p_out_global_tuple_b, - const Block2TileMapA block_2_tile_map_a, - const Block2TileMapB block_2_tile_map_b, - const ElementwiseOperation elementwise_op, - const index_t a_grid_size, - const index_t batch_count_a, - const index_t batch_count_b, - const std::array input_batch_strides_a, - const std::array input_batch_strides_b, - const std::array output_batch_strides_a, - const std::array output_batch_strides_b) + kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size, + const index_t batch_count_a, + const index_t batch_count_b, + const std::array input_batch_strides_a, + const std::array input_batch_strides_b, + const std::array output_batch_strides_a, + const std::array output_batch_strides_b) { static_assert(InAGridDescTuple::Size() == NumInputsA && InADataTypePointerTuple::Size() == NumInputsA); @@ -217,17 +216,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 
CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, - const OutGridDescTuple out_grid_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const Block2TileMap block_2_tile_map, - const ElementwiseOperation elementwise_op, - const index_t batch_count, - const std::array input_batch_strides, - const std::array output_batch_strides) + kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op, + const index_t batch_count, + const std::array input_batch_strides, + const std::array output_batch_strides) { static_assert(InGridDescTuple::Size() == NumInputs && InDataTypePointerTuple::Size() == NumInputs); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 21dac6f9e9..fab0fbab1d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -34,21 +34,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - const ScaleDataType* __restrict__ p_scale_grid, - CDataType* __restrict__ p_c_grid, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const ScaleGridDesc scale_grid_desc, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const ScaleGridDesc scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index f406bfb95a..6e73f0955b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -40,31 +40,31 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( - const FloatAB* 
__restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC0* __restrict__ p_bias_grid, - const FloatC1* __restrict__ p_d0_grid, - ReducePtrsGlobal p_reduces_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const C1ElementwiseOperation c1_element_op, - const ReduceInElementwiseOperations reduce_in_element_ops, - const ReduceAccElementwiseOperations reduce_out_element_ops, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c0_grid_desc_mblock_mperblock_nblock_nperblock, - const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c1_grid_desc_mblock_mperblock_nblock_nperblock, - const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const C1ElementwiseOperation c1_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 562b9b8ffa..5e779b2881 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -28,15 +28,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) { 
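Every __global__ kernel in this patch gets the same two-part reflow shown above: the guarded __launch_bounds__ attribute moves to column zero and the kernel name is re-indented one level beneath it. A minimal sketch of the resulting shape (GridwiseGemm, kernel_example_xdl, and the body are illustrative stand-ins; CK_USE_LAUNCH_BOUNDS, CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU, and ignore are the repo's own definitions, used here as in the hunks):

template <typename GridwiseGemm>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_example_xdl(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // LDS sized by the gridwise policy, as in the hunks above
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    GridwiseGemm::Run(karg, p_shared);
#else
    ignore = karg; // unsupported targets consume the arguments, as above
#endif
}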
constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp index b473d7cbf2..7deda48f7b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -21,12 +21,12 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif #if CK_USE_WAVES_PER_EU - __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) #endif - kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -154,17 +154,10 @@ struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 054aca2936..c37ffb6263 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -687,11 +687,11 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle static constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + BComputeDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -863,18 +863,16 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of 
C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 127d889572..df5c8b10f3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -952,7 +952,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 }); // copy c, d, e + reduction } // shuffle C + Ds + reduction + write out - } // Run + } // Run }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index de6c9c1601..36eb4489e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -34,25 +34,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_multiple_d_wmma_cshuffle( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc, - const BGridDesc_BK0_N_BK1 b_grid_desc, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc, + const BGridDesc_BK0_N_BK1 b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group @@ -127,25 +127,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_wmma_cshuffle( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const index_t batch_count, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const 
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2CTileMap block_2_etile_map) + kernel_contraction_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // printf("entry kernel launch"); @@ -219,23 +219,22 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_mupltipe_d_wmma_cshuffle( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_mupltipe_d_wmma_cshuffle(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index acbccf1889..318ff59383 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -657,11 +657,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + BComputeDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = 
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize,
@@ -856,18 +856,16 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+ Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_buf[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_buf[i]; },
+ Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
index 1e79d67f93..769bc5b877 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
@@ -38,23 +38,23 @@ template
__global__ void
#if CK_USE_LAUNCH_BOUNDS
- __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
- kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load(
- const ADataType* __restrict__ p_a_grid,
- const BDataType* __restrict__ p_b_grid,
- DsPointer p_ds_grid,
- EDataType* __restrict__ p_e_grid,
- const AElementwiseOperation a_element_op,
- const BElementwiseOperation b_element_op,
- const CDEElementwiseOperation cde_element_op,
- const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
- const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
- const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
- ds_grid_desc_mblock_mperblock_nblock_nperblock,
- const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
- e_grid_desc_mblock_mperblock_nblock_nperblock,
- const Block2ETileMap block_2_etile_map)
+ kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load(
+ const ADataType* __restrict__ p_a_grid,
+ const BDataType* __restrict__ p_b_grid,
+ DsPointer p_ds_grid,
+ EDataType* __restrict__ p_e_grid,
+ const AElementwiseOperation a_element_op,
+ const BElementwiseOperation b_element_op,
+ const CDEElementwiseOperation cde_element_op,
+ const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+ const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+ const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ ds_grid_desc_mblock_mperblock_nblock_nperblock,
+ const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ e_grid_desc_mblock_mperblock_nblock_nperblock,
+ const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -73,18 +73,18 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
- ignore = p_a_grid;
- ignore = p_b_grid;
- ignore = p_ds_grid;
- ignore = p_e_grid;
- ignore = a_element_op;
- ignore = b_element_op;
- ignore = cde_element_op;
- ignore = a_grid_desc_ak0_m_ak1;
- ignore = b_grid_desc_bk0_n_bk1;
- ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
- ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
- ignore = block_2_etile_map;
+ ignore = p_a_grid;
+ ignore = p_b_grid;
+ ignore = p_ds_grid;
+ ignore = p_e_grid;
+ ignore = a_element_op;
+ ignore = b_element_op;
+ ignore = cde_element_op;
+ ignore = a_grid_desc_ak0_m_ak1;
+ ignore = b_grid_desc_bk0_n_bk1;
+ ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
+ ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
+ ignore = block_2_etile_map;
#endif
}
@@ -814,18 +814,16 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
// A tuple of reference to C/Ds tensor descriptors.
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+ Number<NumDTensor>{}));
// A tuple of reference to C/Ds grid buffers.
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_buf[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_buf[i]; },
+ Number<NumDTensor>{}));
// A tuple of starting index of C/Ds blockwise copy.
const auto idx_c_ds_block_begin = container_concat(
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
index 5815eb5b0b..85b5b5faab 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
@@ -611,11 +611,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk);
+ MPerXdl,
+ NPerXdl,
+ AComputeType,
+ is_single_rate_mfma,
+ is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize,
@@ -855,18 +855,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+ Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_buf[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_buf[i]; },
+ Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
index db227bb7ef..b257fa4aa3 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
@@ -35,24 +35,24 @@ template
__global__ void
#if CK_USE_LAUNCH_BOUNDS
- __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
- kernel_gemm_reduce_xdl_cshuffle_v1(
- const FloatAB* __restrict__ p_a_grid,
- const FloatAB* __restrict__ p_b_grid,
- FloatC* __restrict__ p_c_grid,
- ReducePtrsGlobal p_reduces_grid,
- const AElementwiseOperation a_element_op,
- const BElementwiseOperation b_element_op,
- const CElementwiseOperation c_element_op,
- const ReduceInElementwiseOperations reduce_in_element_ops,
- const ReduceAccElementwiseOperations reduce_out_element_ops,
- const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
- const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
- const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
- c_grid_desc_mblock_mperblock_nblock_nperblock,
- const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
- const Block2CTileMap block_2_ctile_map)
+ kernel_gemm_reduce_xdl_cshuffle_v1(
+ const FloatAB* __restrict__ p_a_grid,
+ const FloatAB* __restrict__ p_b_grid,
+ FloatC* __restrict__ p_c_grid,
+ ReducePtrsGlobal p_reduces_grid,
+ const AElementwiseOperation a_element_op,
+ const BElementwiseOperation b_element_op,
+ const CElementwiseOperation c_element_op,
+ const ReduceInElementwiseOperations reduce_in_element_ops,
+ const ReduceAccElementwiseOperations reduce_out_element_ops,
+ const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+ const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+ const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ c_grid_desc_mblock_mperblock_nblock_nperblock,
+ const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
+ const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
index 70301c326a..b4848c7077 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
@@ -593,11 +593,11 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk);
+ MPerXdl,
+ NPerXdl,
+ ABDataType,
+ is_single_rate_mfma,
+ is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize,
@@ -769,18 +769,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+ Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_buf[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_buf[i]; },
+ Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
@@ -1032,11 +1030,11 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk);
+ MPerXdl,
+ NPerXdl,
+ ABDataType,
+ is_single_rate_mfma,
+ is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize,
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
index f64838ea4e..1b4c2666ab 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
@@ -607,11 +607,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk);
+ MPerXdl,
+ NPerXdl,
+ ComputeType,
+ is_single_rate_mfma,
+ is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize,
@@ -845,18 +845,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+ Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
- generate_tie(
- [&](auto i) -> const auto& // return type should be reference
- { return ds_grid_buf[i]; },
- Number<NumDTensor>{}));
+ generate_tie([&](auto i) -> const auto& // return type should be reference
+ { return ds_grid_buf[i]; },
+ Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
index 4458b9356d..51cd5ada91 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
@@ -31,19 +31,19 @@ template
__global__ void
#if CK_USE_LAUNCH_BOUNDS
- __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
- kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
- const BDataType* __restrict__ p_b_grid,
- CDataType* __restrict__ p_c_grid,
- const AGridDesc a_grid_desc,
- const BGridDesc b_grid_desc,
- const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-
c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 75f12d094e..9a8d09e5e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -337,20 +337,11 @@ struct GridwiseGemm_wmma_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp index 7b6ad5ca3e..37ffbf1c51 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -240,22 +240,12 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "SScaleB:" << StrideScaleB << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded + << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; diff 
--git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 5a4a41e507..fc01866ddf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -24,9 +24,9 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 63d40f6ff8..68112489ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -217,20 +217,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index d45ed79ae3..9089bd2ce2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -33,9 +33,9 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -54,9 +54,9 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // Pass two lds pointer is the key to tell compiler that ds_read/write @@ -538,24 +538,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 __host__ void Print() const { - std::cout << "problem {" - << 
"M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << ", " - << "Stream-K Selection:" << Streamk_sel << ", " - << "Grid size:" << Grid_size << ", " - << "Reduction Strategy:" + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << ", " << "Stream-K Selection:" << Streamk_sel + << ", " << "Grid size:" << Grid_size << ", " << "Reduction Strategy:" << (reduction_strategy == StreamKReductionStrategy::Atomic ? "Atomic" : "Reduction") << "}" << std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 7edcd7270f..c22229a183 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -20,9 +20,9 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -42,12 +42,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - typename GridwiseGemm::Problem problem) + kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -436,20 +436,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t 
M; @@ -822,11 +813,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ComputeTypeB, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index f92268265f..48c577b2e0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -20,7 +20,7 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) @@ -46,12 +46,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) #endif - kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, - const FloatB* p_b_grid, - FloatC* p_c_grid, - typename GridwiseGemm::Problem problem) + kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -475,20 +475,11 @@ struct GridwiseGemm_xdl_cshuffle_v2 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; @@ -881,11 +872,11 @@ struct GridwiseGemm_xdl_cshuffle_v2 constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ComputeTypeA, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< // BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 6270d0c4dc..5f3950b29e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -30,7 +30,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename 
GridwiseGemm::Argument karg) @@ -58,7 +58,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) @@ -666,20 +666,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 8d5c844103..91f08413af 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -30,7 +30,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg) @@ -58,7 +58,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) @@ -155,11 +155,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle static constexpr bool is_single_rate_mfma = true; static constexpr auto is_scale_mfma = false; static constexpr auto mfma = MfmaSelector{}; + MPerXdl, + NPerXdl, + ComputeTypeA, + is_single_rate_mfma, + is_scale_mfma>{}; static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); @@ -575,20 +575,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " 
<< "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index 93c1779a80..d8c697823a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -30,7 +30,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -60,7 +60,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) @@ -563,22 +563,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "SScaleB:" << StrideScaleB << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded + << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index 97d0e2a4eb..9f442906f5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -29,7 +29,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -59,7 +59,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) @@ -589,18 +589,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << 
", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead + << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 + << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" + << std::endl; } index_t M; @@ -1757,18 +1750,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -2340,18 +2331,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index a3694e3767..17b4cd7c68 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -33,7 +33,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg) @@ -65,7 +65,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg) @@ -577,20 +577,11 @@ struct 
GridwiseGemmMultiD_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1636,18 +1627,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -2170,18 +2159,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 64fbda7a44..b41f1220fb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -33,7 +33,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // 
__attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -538,20 +538,11 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1556,18 +1547,16 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 3553a1d040..27926e5290 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -33,7 +33,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg) @@ -65,7 +65,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) @@ -174,11 +174,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle : false; static constexpr auto is_scale_mfma = false; static constexpr auto mfma = MfmaSelector{}; + MPerXdl, + NPerXdl, + ComputeTypeA, + is_single_rate_mfma, + is_scale_mfma>{}; static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); static constexpr index_t KGroup = []() { if 
constexpr(is_same_v, f8_t>) @@ -599,20 +599,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1414,18 +1405,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -1855,18 +1844,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index 909376e5f7..20711f0c5e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -33,7 +33,7 @@ template __global__ void #if 
CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle( @@ -66,7 +66,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds( @@ -555,20 +555,11 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1446,18 +1437,16 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -1948,18 +1937,16 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const 
auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index ca3902188e..bc87559c43 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -34,7 +34,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -66,7 +66,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -422,8 +422,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_assert(!((is_same_v, f6x16_pk_t> || is_same_v, bf6x16_pk_t> || is_same_v, f6x32_pk_t> || - is_same_v, bf6x32_pk_t>)&&GemmSpec != - GemmSpecialization::Default), + is_same_v, bf6x32_pk_t>) && + GemmSpec != GemmSpecialization::Default), "Packed F6 types do not support padding"); if constexpr(GemmSpec == GemmSpecialization::NKPadding || @@ -648,23 +648,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 6691c63484..7902a16fb3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -34,7 +34,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -66,7 +66,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 
1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -674,23 +674,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 67fb4d651e..80ce6a1bc4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -36,26 +36,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_layernorm_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, // MxN - const FloatC0* __restrict__ p_c0_bias_grid, // 1xN - const FloatC0* __restrict__ p_c0_add_grid, // MxN - const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN - const FloatC0* __restrict__ p_c0_beta_grid, // 1xN - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const AccElementwiseOperation acc_element_op, - const CElementwiseOperation c_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_layernorm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, // MxN + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, + const Block2CTileMap 
block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index b7947309e4..697d0f90d9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -152,19 +152,19 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -182,16 +182,16 @@ __global__ void c_element_op, c_block_cluster_adaptor); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -752,11 +752,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + MPerXDL, + NPerXDL, + FloatBAdjusted, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_skip_b_lds_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation 
a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_skip_b_lds_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index 3e23008a5f..0c5f8de1e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -30,13 +30,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, - const Block2CTileMap& b2c_map, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); @@ -168,17 +168,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load void Print() const { - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "K0Padded:" << K0Padded << ", " + std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "K0Padded:" << K0Padded << ", " << "KB:" << k_batch << "}" << std::endl; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index e9190dee29..104632d3f0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -23,19 +23,19 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, - const typename GridwiseGemm::FloatAB* p_b_grid, - typename GridwiseGemm::FloatC* p_c_grid, - void* p_workspace, - index_t M, - index_t N, - 
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
index e9190dee29..104632d3f0 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
@@ -23,19 +23,19 @@ namespace ck {
 template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
-                                   const typename GridwiseGemm::FloatAB* p_b_grid,
-                                   typename GridwiseGemm::FloatC* p_c_grid,
-                                   void* p_workspace,
-                                   index_t M,
-                                   index_t N,
-                                   index_t K,
-                                   index_t StrideA,
-                                   index_t StrideB,
-                                   index_t StrideC,
-                                   typename GridwiseGemm::Block2CTileMap block_mapping)
+    kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
+                               const typename GridwiseGemm::FloatAB* p_b_grid,
+                               typename GridwiseGemm::FloatC* p_c_grid,
+                               void* p_workspace,
+                               index_t M,
+                               index_t N,
+                               index_t K,
+                               index_t StrideA,
+                               index_t StrideB,
+                               index_t StrideC,
+                               typename GridwiseGemm::Block2CTileMap block_mapping)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
@@ -174,13 +174,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
 
         void Print() const
         {
-            std::cout << "arg {"
-                      << "M:" << M << ", "
-                      << "N:" << N << ", "
-                      << "K:" << K << ", "
-                      << "SA:" << StrideA << ", "
-                      << "SB:" << StrideB << ", "
-                      << "SC:" << StrideC << std::endl;
+            std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
+                      << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
+                      << std::endl;
         }
     };
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
index 5c3d9b7ba4..dc9429ea6e 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
@@ -26,17 +26,17 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
 #if CK_USE_WAVES_PER_EU
-    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
+    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
 #endif
-        kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
-                                const FloatAB* __restrict__ p_b_grid,
-                                FloatC* __restrict__ p_c_grid,
-                                const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-                                const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-                                const CGridDesc_M_N c_grid_desc_m_n)
+    kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
+                            const FloatAB* __restrict__ p_b_grid,
+                            FloatC* __restrict__ p_c_grid,
+                            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+                            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+                            const CGridDesc_M_N c_grid_desc_m_n)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
@@ -50,24 +50,24 @@ __global__ void
                                                       b_grid_desc_k0_n_k1,
                                                       c_grid_desc_m_n);
 #else
-    ignore = p_a_grid;
-    ignore = p_b_grid;
-    ignore = p_c_grid;
-    ignore = a_grid_desc_k0_m_k1;
-    ignore = b_grid_desc_k0_n_k1;
-    ignore = c_grid_desc_m_n;
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_m_n;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 
 template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
 #if CK_USE_WAVES_PER_EU
-    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
+    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
 #endif
-        kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
+    kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
@@ -90,7 +90,7 @@ __global__ void
                                                       b_grid_desc_k0_n_k1,
                                                       c_grid_desc_m_n);
 #else
-    ignore = karg;
+    ignore = karg;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 
@@ -200,16 +200,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     __host__ void Print() const
     {
-        std::cout << "problem {"
-                  << "M:" << M << ", "
-                  << "N:" << N << ", "
-                  << "K:" << K << ", "
-                  << "SA:" << StrideA << ", "
-                  << "SB:" << StrideB << ", "
-                  << "SC:" << StrideC << ", "
-                  << "MP:" << MPadded << ", "
-                  << "NP:" << NPadded << ", "
-                  << "K0:" << K0 << "}" << std::endl;
+        std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
+                  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
+                  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "K0:" << K0
+                  << "}" << std::endl;
     }
 
     index_t M;
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
index 7d8e94c001..978f08ad4a 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
@@ -29,18 +29,18 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
-                                const FloatAB* __restrict__ p_b_grid,
-                                FloatC* __restrict__ p_c_grid,
-                                const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
-                                const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
-                                const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                                const AElementwiseOperation a_element_op,
-                                const BElementwiseOperation b_element_op,
-                                const CElementwiseOperation c_element_op,
-                                const CBlockClusterAdaptor c_block_cluster_adaptor)
+    kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
+                            const FloatAB* __restrict__ p_b_grid,
+                            FloatC* __restrict__ p_c_grid,
+                            const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
+                            const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
+                            const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                            const AElementwiseOperation a_element_op,
+                            const BElementwiseOperation b_element_op,
+                            const CElementwiseOperation c_element_op,
+                            const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
index 256b495c6e..a546b471bf 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
@@ -28,13 +28,13 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
-                                             const Block2CTileMap& b2c_map,
-                                             const AElementwiseOperation a_element_op,
-                                             const BElementwiseOperation b_element_op,
-                                             const CElementwiseOperation c_element_op)
+    kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
+                                         const Block2CTileMap& b2c_map,
+                                         const AElementwiseOperation a_element_op,
+                                         const BElementwiseOperation b_element_op,
+                                         const CElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
@@ -175,17 +175,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 
         void Print() const
         {
-            std::cout << "arg {"
-                      << "M:" << M << ", "
-                      << "N:" << N << ", "
-                      << "K:" << K << ", "
-                      << "SA:" << StrideA << ", "
-                      << "SB:" << StrideB << ", "
-                      << "SC:" << StrideC << ", "
-                      << "MP:" << MPadded << ", "
-                      << "NP:" << NPadded << ", "
-                      << "KP:" << KPadded << ", "
-                      << "K0Padded:" << K0Padded << ", "
+            std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
+                      << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
+                      << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
+                      << "KP:" << KPadded << ", " << "K0Padded:" << K0Padded << ", "
                       << "KB:" << k_batch << "}" << std::endl;
         }
     };
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
index 15c2da9d32..66a3fef4eb 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
@@ -31,20 +31,20 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_v3r1(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            FloatC* __restrict__ p_c_grid,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-            const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-                c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CElementwiseOperation c_element_op,
-            const Block2CTileMap block_2_ctile_map)
+    kernel_gemm_xdlops_v3r1(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CElementwiseOperation c_element_op,
+        const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
index e22bfb6439..eb4e7d3db3 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
@@ -31,23 +31,23 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_v3r2(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            FloatC* __restrict__ p_c_grid,
-            const FloatC* __restrict__ p_c0_grid,
-            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-            const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
-
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_v3r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 3da5e66018..5bd5f75fa9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -32,26 +32,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v3r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const FloatC* __restrict__ p_c1_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_v3r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const 
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 3d5066d52d..ca68fe9f86 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -40,7 +40,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) @@ -75,7 +75,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) @@ -619,22 +619,12 @@ struct GridwiseMoeGemm __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1714,18 +1704,16 @@ struct GridwiseMoeGemm // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -1746,40 +1734,40 @@ struct GridwiseMoeGemm const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto 
cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -2436,18 +2424,16 @@ struct GridwiseMoeGemm // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + 
generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2468,40 +2454,40 @@ struct GridwiseMoeGemm const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index f092c9c1eb..7145efbd97 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -40,7 +40,7 @@ template __global__ void #if 
CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) @@ -77,7 +77,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) @@ -626,22 +626,12 @@ struct GridwiseMoeGemmBlockScale __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1764,18 +1754,16 @@ struct GridwiseMoeGemmBlockScale // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -1796,40 +1784,40 @@ struct GridwiseMoeGemmBlockScale const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = IsInputGemm ? 
1 : 1; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -2506,18 +2494,16 @@ struct GridwiseMoeGemmBlockScale // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return 
ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2538,40 +2524,40 @@ struct GridwiseMoeGemmBlockScale const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index 5f8e524fb2..6731a7dda6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -81,7 +81,7 @@ template 
__global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) @@ -678,25 +678,14 @@ struct GridwiseMoeGemmMX __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t NumTokens; @@ -2769,18 +2758,16 @@ struct GridwiseMoeGemmMX // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2801,41 +2788,41 @@ struct GridwiseMoeGemmMX const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - 
CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 9ccd334262..d8d77ae388 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -42,7 +42,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) @@ -205,11 +205,11 @@ struct GridwiseMoeGemmMXBNS static constexpr bool is_single_rate_mfma = false; static constexpr auto is_scale_mfma = true; using mfma_selector = MfmaSelector; + MPerXdl, + NPerXdl, + ComputeTypeB, + is_single_rate_mfma, + is_scale_mfma>; static constexpr index_t KPack = math::max( math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize); @@ -611,25 +611,14 @@ struct GridwiseMoeGemmMXBNS __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << 
"MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t NumTokens; @@ -1956,18 +1945,16 @@ struct GridwiseMoeGemmMXBNS // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -1988,41 +1975,41 @@ struct GridwiseMoeGemmMXBNS const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + 
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index be85528f28..7c3dbceeaa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -42,7 +42,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) @@ -79,7 +79,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) @@ -708,25 +708,14 @@ struct GridwiseMoeGemmMX_BPreshuffle __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t NumTokens; @@ -2588,18 +2577,16 @@ struct GridwiseMoeGemmMX_BPreshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( 
- [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2620,41 +2607,41 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + 
          make_tuple(make_multi_index(0, 0, block_n_id, 0)),
+          c_element_op};
 
         auto c_grid_buf = make_dynamic_buffer(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
index 61d0f9e0d5..fa9b5fb2ce 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
@@ -86,7 +86,7 @@ struct GridwisePermute
         ~Block2TileMap() = default;
 
         Block2TileMap& operator=(const Block2TileMap&) = delete;
-        Block2TileMap& operator=(Block2TileMap&&) = delete;
+        Block2TileMap& operator=(Block2TileMap&&) = delete;
 
         explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
index ddf0b4a58d..bffc3c696c 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
@@ -25,15 +25,15 @@ template
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
-                                const InputDataType* __restrict__ p_in_global,
-                                const OutputGridDesc out_grid_desc,
-                                OutputDataType* __restrict__ p_out_global,
-                                const index_t batch_count,
-                                const Block2ETileMap block_2_tile_map,
-                                const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
+    kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
+                            const InputDataType* __restrict__ p_in_global,
+                            const OutputGridDesc out_grid_desc,
+                            OutputDataType* __restrict__ p_out_global,
+                            const index_t batch_count,
+                            const Block2ETileMap block_2_tile_map,
+                            const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
     defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
index 8a0e16d7f6..e399499cc8 100644
--- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
+++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
@@ -399,7 +399,7 @@ struct GridwiseNormalizationBwdData_mk_to_mk
                                                        dx_grid_desc_m_k,
                                                        dx_global_val_buf);
-        } // end of sweep once
+        } // end of sweep once
         else // Sweep Twice pipeline
         {
             constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
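The threadwise-transfer hunks that follow all reformat one idiom: an immediately-invoked constexpr lambda. clang-format 12 broke before the lambda's opening brace and left the invoking () on its own line; 18 attaches the brace to [&]() constexpr and closes with }();. A minimal self-contained sketch, with a toy array and a hypothetical predicate standing in for CK's StaticallyIndexedArray and static_for:

#include <array>
#include <cstddef>

constexpr std::size_t nDim = 4;

constexpr auto make_move_flags()
{
    // clang-format 12:                clang-format 18:
    //   auto f = [&]() constexpr        auto f = [&]() constexpr {
    //   {                                   ...
    //       ...                         }();
    //   }
    //   ();
    auto flags = [&]() constexpr {
        std::array<bool, nDim> f{};
        for(std::size_t i = 0; i < nDim; ++i)
            f[i] = (i % 2 == 0); // hypothetical per-dimension predicate
        return f;
    }();
    return flags;
}

static_assert(make_move_flags()[0] && !make_move_flags()[1]);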
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
index 4e4c92de40..2305997f70 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -823,8 +823,7 @@ struct ThreadwiseTensorSliceTransfer_v3
                 buffer_(Number{}) = src_tmp_vector.template AsType()[i];
             });
 
-        constexpr auto move_on_dim = [&]() constexpr
-        {
+        constexpr auto move_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;
 
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -837,8 +836,7 @@ struct ThreadwiseTensorSliceTransfer_v3
             });
 
             return move_on_dim_;
-        }
-        ();
+        }();
 
         // move
         static_for<0, nDim, 1>{}([&](auto i) {
@@ -983,8 +981,7 @@ struct ThreadwiseTensorSliceTransfer_v3
                                            is_dst_valid,
                                            dst_tmp_vector.template AsType()[Number<0>{}]);
 
-        constexpr auto move_on_dim = [&]() constexpr
-        {
+        constexpr auto move_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;
 
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -997,8 +994,7 @@ struct ThreadwiseTensorSliceTransfer_v3
             });
 
             return move_on_dim_;
-        }
-        ();
+        }();
 
         // move
         static_for<0, nDim, 1>{}([&](auto i) {
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
index 79e22018a6..4a6ed62c0e 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
@@ -246,22 +246,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         using dst_elem_op_vec_t = typename vector_type::type;
 
         using VectorSizeLookupTable = Tuple,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence,
-                                            Sequence>;
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence,
+                                            Sequence>;
 
         using VectorOffsetsLookupTable = Tuple,
                                                Sequence,
                                                Sequence,
@@ -308,8 +308,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 .template SetAsType(src_data_idx_seq,
                                     op_r_v.template AsType()[I0]);
 
-        constexpr auto move_on_dim = [&]() constexpr
-        {
+        constexpr auto move_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;
 
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -322,8 +321,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             });
 
             return move_on_dim_;
-        }
-        ();
+        }();
 
         // move src coord
         static_for<0, nDim, 1>{}([&](auto i) {
@@ -636,8 +634,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                            is_dst_valid,
                                            dst_vector_container.template AsType()[I0]);
 
-        constexpr auto move_on_dim = [&]() constexpr
-        {
+        constexpr auto move_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;
 
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -650,8 +647,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             });
 
             return move_on_dim_;
-        }
-        ();
+        }();
 
         // move dst coord
         static_for<0, nDim, 1>{}([&](auto i) {
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
index 174b82f870..8af6a2148b 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
@@ -229,8 +229,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                 .template SetAsType(
                     src_data_idx_seq, src_vector_container.template AsType()[I0]);
 
-        constexpr auto move_on_dim = [&]() constexpr
-        {
+        constexpr auto move_on_dim = [&]() constexpr {
             StaticallyIndexedArray move_on_dim_;
 
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -243,8 +242,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
             });
 
             return move_on_dim_;
-        }
-        ();
+        }();
 
         // move src coord
         static_for<0, nDim, 1>{}([&](auto i) {
@@ -376,8 +374,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
             scale_thread_scratch_.template SetAsType(
scale_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -391,8 +388,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant }); return move_on_dim_; - } - (); + }(); // move scale coord static_for<0, nDim, 1>{}([&](auto i) { @@ -666,8 +662,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant is_dst_valid, dst_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -680,8 +675,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant }); return move_on_dim_; - } - (); + }(); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index 50f1e21beb..8574fd055c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -277,8 +277,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather .template SetAsType(src_data_idx_seq, op_r_v.template AsType()[I0]); - auto move_on_dim = [&]() constexpr - { + auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -292,8 +291,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather }); return move_on_dim_; - } - (); + }(); // move src coord static_for<0, nDim, 1>{}([&](auto i) { if(move_on_dim[i]) @@ -603,8 +601,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather is_dst_valid, dst_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -617,8 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather }); return move_on_dim_; - } - (); + }(); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp index f0d793456d..9383e3f829 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -229,8 +229,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 src_data_idx_seq, src_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -245,8 +244,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }); return move_on_dim_; - } - (); + }(); // move src coord static_for<0, nDim, 1>{}([&](auto i) { @@ -438,8 +436,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 is_dst_valid, dst_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -454,8 +451,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }); return move_on_dim_; - } - (); + }(); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { diff --git 
a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index 40ebdeff08..4e9c188115 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -198,8 +198,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1 src_vector.template AsType()[Number{}]; }); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -212,8 +211,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1 }); return move_on_dim_; - } - (); + }(); // move static_for<0, nDim, 1>{}([&](auto i) { @@ -368,8 +366,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1 is_dst_valid, dst_vector.template AsType()[Number<0>{}]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -382,8 +379,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1 }); return move_on_dim_; - } - (); + }(); // move static_for<0, nDim, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 9b1ff3dbf8..65e63993a6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -421,8 +421,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter { constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); - auto forward_step_scatter = [&]() constexpr - { + auto forward_step_scatter = [&]() constexpr { Index step_; static_for<0, nDim, 1>{}([&](auto i) { @@ -430,8 +429,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter }); return step_; - } - (); + }(); static_for<0, nDst, 1>{}([&](auto i) { move_tensor_coordinate( dst_descs[i], @@ -493,8 +491,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter { constexpr auto reset_step = DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - auto reset_step_scatter = [&]() constexpr - { + auto reset_step_scatter = [&]() constexpr { Index step_; static_for<0, nDim, 1>{}([&](auto i) { step_(i) = @@ -502,8 +499,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter }); return step_; - } - (); + }(); return reset_step_scatter; } } diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index b7af32d3dc..2edbb7c789 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -1400,7 +1400,7 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); #endif // #ifndef CK_CODE_GEN_RTC @@ -1426,7 +1426,7 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); #endif // #ifndef CK_CODE_GEN_RTC @@ -1503,7 +1503,7 @@ __device__ static inline 
fp8x2_storage_t cvt_float_to_fp8(const float2_t f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f[0]); + rng = prand_generator(reinterpret_cast(&f), f[0]); #else rng = prand_generator(reinterpret_cast(&f), f[0]); #endif // #ifndef CK_CODE_GEN_RTC @@ -1704,7 +1704,7 @@ __host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&x), + rng = prand_generator(reinterpret_cast(&x), static_cast(x[0])); #else rng = prand_generator(reinterpret_cast(&x), @@ -1734,7 +1734,7 @@ using bf8_t = bf8_ocp_t; #define CK_FP8_TYPE_FNUZ 0 #define CK_FP8_TYPE_OCP 1 #else -using f8_t = f8_fnuz_t; +using f8_t = f8_fnuz_t; using bf8_t = bf8_fnuz_t; #define CK_FP8_TYPE_FNUZ 1 #define CK_FP8_TYPE_OCP 0 diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index bd0ca42ecd..d6524283db 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -19,7 +19,7 @@ __host__ __device__ constexpr auto container_push_back(const Array { Array r; - static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; }); r(Number{}) = x; diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index ed42b22daf..027290dbf8 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -232,7 +232,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_LOAD bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else - bool constexpr use_amd_buffer_addressing = false; + bool constexpr use_amd_buffer_addressing = false; #endif #if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp index a700fcfff1..8cb37b68b2 100644 --- a/include/ck/utility/is_detected.hpp +++ b/include/ck/utility/is_detected.hpp @@ -25,8 +25,8 @@ struct detector>, Op, Args...> struct nonesuch { - ~nonesuch() = delete; - nonesuch(nonesuch const&) = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; void operator=(nonesuch const&) = delete; }; diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 7b079c541c..993b70a3fb 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -75,7 +75,7 @@ struct MagicDivision // integral_constant template __host__ __device__ static constexpr auto - CalculateMagicNumbers(integral_constant) + CalculateMagicNumbers(integral_constant) { constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor}); @@ -88,7 +88,7 @@ struct MagicDivision template __host__ __device__ static constexpr auto - CalculateMagicMultiplier(integral_constant) + CalculateMagicMultiplier(integral_constant) { constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor}); @@ -97,7 +97,7 @@ struct MagicDivision template __host__ __device__ static constexpr auto - CalculateMagicShift(integral_constant) + CalculateMagicShift(integral_constant) { constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor}); @@ -107,21 +107,21 @@ struct MagicDivision // integral_constant template __host__ __device__ static constexpr auto - CalculateMagicNumbers(integral_constant) + CalculateMagicNumbers(integral_constant) { return CalculateMagicNumbers(integral_constant{}); 
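Note: the container_helper hunk in this span fixes a clang-format-12 quirk around constexpr-qualified lambdas: the old output padded the capture list with a stray space before the closing bracket ('[&r, &a ]'); clang-format-18 emits the conventional '[&r, &a]'. A compilable sketch of the pattern; static_for_demo is a hypothetical stand-in for ck's static_for:

    #include <cstddef>

    template <typename F>
    constexpr void static_for_demo(std::size_t n, F&& f) // stand-in for static_for<0, N, 1>
    {
        for(std::size_t i = 0; i < n; ++i)
            f(i);
    }

    constexpr int sum4()
    {
        int r    = 0;
        int a[4] = {1, 2, 3, 4};
        // clang-format-18: no space before ']' even when the lambda is constexpr.
        static_for_demo(4, [&r, &a](auto i) constexpr { r += a[i]; });
        return r;
    }

    static_assert(sum4() == 10);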
} template __host__ __device__ static constexpr auto - CalculateMagicMultiplier(integral_constant) + CalculateMagicMultiplier(integral_constant) { return CalculateMagicMultiplier(integral_constant{}); } template __host__ __device__ static constexpr auto - CalculateMagicShift(integral_constant) + CalculateMagicShift(integral_constant) { return CalculateMagicShift(integral_constant{}); } diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 497625f7e2..75f0c92c58 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -342,8 +342,8 @@ struct sequence_reverse using seq_split = sequence_split; using type = typename sequence_merge< - typename sequence_reverse::type, - typename sequence_reverse::type>::type; + typename sequence_reverse::type, + typename sequence_reverse::type>::type; }; template diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index e9fd1ea88f..99538ac78c 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -259,7 +259,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif // #ifndef CK_CODE_GEN_RTC @@ -327,7 +327,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif // #ifndef CK_CODE_GEN_RTC @@ -1495,7 +1495,7 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif @@ -1520,7 +1520,7 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #endif @@ -1565,7 +1565,7 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #endif @@ -1817,7 +1817,7 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif @@ -2155,7 +2155,7 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif diff --git 
a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 8dabb58451..26cfcaa2f0 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -407,17 +407,17 @@ struct Tensor ElementSpaceSize, true /*InvalidElementUseNumericalZeroValue*/>; using StaticBufferType = std::conditional_t< - is_scalar_type::value, - StaticBuffer, - StaticBufferTupleOfVector>::vector_size, - scalar_type>::vector_size, - true /*InvalidElementUseNumericalZeroValue*/>>; + is_scalar_type::value, + StaticBuffer, + StaticBufferTupleOfVector>::vector_size, + scalar_type>::vector_size, + true /*InvalidElementUseNumericalZeroValue*/>>; // If register use static buffer, else use dynamic buffer using Buffer = std::conditional_t; diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index aaa7db2574..f7f9489f4c 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1259,7 +1259,7 @@ struct slice : public base_transform<1, 1> printf("}"); } // namespace ck -}; // namespace ck +}; // namespace ck /* * \brief lower_idx = upper_idx % modulus. diff --git a/include/ck_tile/core/algorithm/space_filling_curve.hpp b/include/ck_tile/core/algorithm/space_filling_curve.hpp index 6591acddb9..648a1251be 100644 --- a/include/ck_tile/core/algorithm/space_filling_curve.hpp +++ b/include/ck_tile/core/algorithm/space_filling_curve.hpp @@ -100,10 +100,8 @@ struct space_filling_curve // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the // idim-th element of multidimensional index. // All constexpr variables have to be captured by VALUE. - constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr - { - constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr - { + constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr { + constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr { auto res = idx_1d.value; auto id = 0; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index add6b1dbdc..0932f39ca7 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -302,12 +302,12 @@ struct buffer_load_if<16, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); if constexpr(pre_nop) asm volatile("s_nop 4\n" @@ -336,12 +336,12 @@ struct buffer_load_if<8, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -369,12 +369,12 @@ struct buffer_load_if<4, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { 
static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -402,12 +402,12 @@ struct buffer_load_if<2, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -435,12 +435,12 @@ struct buffer_load_if<1, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -624,7 +624,7 @@ struct buffer_store_if<16> { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = fp32x4_t; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -681,7 +681,7 @@ struct buffer_store_if<4> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -709,7 +709,7 @@ struct buffer_store_if<2> { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = short; + using mbuf_t = short; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -737,7 +737,7 @@ struct buffer_store_if<1> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index e2a73e6242..0723026836 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -13,7 +13,7 @@ #define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111 #define CK_TILE_VMCNT(cnt) \ ([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \ - ((cnt)&0b1111) | (((cnt)&0b110000) << 10)) + ((cnt) & 0b1111) | (((cnt) & 0b110000) << 10)) #define CK_TILE_EXPCNT(cnt) \ ([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4)) #define CK_TILE_LGKMCNT(cnt) \ diff --git a/include/ck_tile/core/container/container_helper.hpp b/include/ck_tile/core/container/container_helper.hpp index 474eda80d1..1a631bd95e 100644 --- a/include/ck_tile/core/container/container_helper.hpp +++ b/include/ck_tile/core/container/container_helper.hpp @@ -16,7 +16,7 @@ template CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array& a, const TData& x) { array r; - static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr 
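Note: the arch.hpp hunk above only adds spaces around binary '&'. clang-format-12 rendered '((cnt)&0b1111)', which reads like a cast followed by a reference; clang-format-18 spaces it as '((cnt) & 0b1111)'. A sketch with the new spacing; DEMO_VMCNT mirrors the CK_TILE_VMCNT macro from the hunk, and the usage check is illustrative:

    #define DEMO_VMCNT(cnt) \
        ([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \
         ((cnt) & 0b1111) | (((cnt) & 0b110000) << 10))

    // 0b101010 -> low nibble 0b1010, high bits 0b100000 shifted into place
    static_assert(DEMO_VMCNT(0b101010) == (0b1010 | (0b100000 << 10)));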
{ r(i) = a[i]; }); r[number{}] = x; return r; } diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index b187b71830..94309dd5dd 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -1236,9 +1236,8 @@ constexpr auto reverse_slice_sequence(Seq, template ::type> -constexpr auto slice_sequence(Seq, - number, - Mask = typename uniform_sequence_gen::type{}) +constexpr auto +slice_sequence(Seq, number, Mask = typename uniform_sequence_gen::type{}) { constexpr auto r = reverse_slice_sequence(Seq{}.reverse(), number{}, Mask{}.reverse()); diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index b5da468319..a3ce614f84 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -75,7 +75,7 @@ struct alignas(1) float8_e4m3_t #if CK_TILE_USE_OCP_FP8 static constexpr int bias = 7; // OCP #else - static constexpr int bias = 8; // FNUZ + static constexpr int bias = 8; // FNUZ #endif using raw_type = uint8_t; raw_type data; diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 8176fe551c..b8a31ba8fc 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -31,8 +31,8 @@ struct scales CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {} template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const - -> decltype(std::declval() * rhs) + CK_TILE_HOST_DEVICE constexpr auto + operator()(const Right& rhs) const -> decltype(std::declval() * rhs) { return lhs_ * rhs; } @@ -43,13 +43,13 @@ struct scales /// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ scales(Scale)->scales; +__host__ __device__ scales(Scale) -> scales; template struct plus { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs + rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs + rhs) { return lhs + rhs; } @@ -59,21 +59,21 @@ template <> struct plus { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs + rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs + rhs) { return lhs + rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ plus()->plus; +__host__ __device__ plus() -> plus; template struct minus { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs - rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs - rhs) { return lhs - rhs; } @@ -83,21 +83,21 @@ template <> struct minus { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs - rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs - rhs) { return lhs - rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ minus()->minus; +__host__ __device__ minus() -> minus; template struct multiplies { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs * rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& 
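Note: the deduction-guide lines in math.hpp ('scales(Scale)->scales', 'plus()->plus', and so on) gain a space on each side of the arrow: clang-format-18 formats '->' in a deduction guide like a trailing return type rather than a member access. A minimal host-only sketch; the Scale semantics follow the hunk, the __host__ __device__ markers are dropped and the assert is illustrative:

    template <typename Scale>
    struct scales_demo // hypothetical stand-in for ck_tile::scales
    {
        constexpr explicit scales_demo(Scale lhs) : lhs_(lhs) {}

        template <typename Right>
        constexpr auto operator()(const Right& rhs) const -> decltype(Scale{} * rhs)
        {
            return lhs_ * rhs;
        }

        Scale lhs_;
    };

    // clang-format-18: "scales_demo(Scale) -> scales_demo<Scale>;"
    // clang-format-12: "scales_demo(Scale)->scales_demo<Scale>;"
    template <typename Scale>
    scales_demo(Scale) -> scales_demo<Scale>;

    static_assert(scales_demo(2)(21) == 42);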
lhs, + const Right& rhs) const -> decltype(lhs * rhs) { return lhs * rhs; } @@ -107,15 +107,15 @@ template <> struct multiplies { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs * rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs * rhs) { return lhs * rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ multiplies()->multiplies; +__host__ __device__ multiplies() -> multiplies; template struct maximize @@ -327,8 +327,8 @@ CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys) template struct equal { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs == rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs == rhs) { return lhs == rhs; } @@ -338,15 +338,15 @@ template <> struct equal { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs == rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs == rhs) { return lhs == rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ equal()->equal; +__host__ __device__ equal() -> equal; template <> struct equal @@ -369,8 +369,8 @@ struct equal template struct less { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs < rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs < rhs) { return lhs < rhs; } @@ -380,21 +380,21 @@ template <> struct less { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs < rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs < rhs) { return lhs < rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less()->less; +__host__ __device__ less() -> less; template struct less_equal { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs <= rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs <= rhs) { return lhs <= rhs; } @@ -404,15 +404,15 @@ template <> struct less_equal { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs <= rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs <= rhs) { return lhs <= rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less_equal()->less_equal; +__host__ __device__ less_equal() -> less_equal; template <> struct less_equal diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index ceb7e18556..1535250722 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -117,8 +117,8 @@ struct DefaultTranspose struct ValidationTraitsImpl { using QuadEncoding = std::conditional_t, - QuadInputEncoding>; + QuadOutputEncoding, + QuadInputEncoding>; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto input_hs = 
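Note: across math.hpp, clang-format-18 no longer breaks before a trailing return type; instead it breaks inside the parameter list so that ') const -> decltype(...)' stays on one line. The before/after shape, on a compilable stand-in; plus_demo mirrors the plus functor from the hunk with CK_TILE_HOST_DEVICE dropped for brevity:

    // clang-format-12:
    //   constexpr auto operator()(const Left& lhs, const Right& rhs) const
    //       -> decltype(lhs + rhs)
    // clang-format-18:
    //   constexpr auto operator()(const Left& lhs,
    //                             const Right& rhs) const -> decltype(lhs + rhs)
    template <typename Left, typename Right>
    struct plus_demo // hypothetical stand-in for ck_tile::plus
    {
        constexpr auto operator()(const Left& lhs,
                                  const Right& rhs) const -> decltype(lhs + rhs)
        {
            return lhs + rhs;
        }
    };

    static_assert(plus_demo<int, int>{}(40, 2) == 42);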
InDstrEncode::hs_lengthss_; @@ -396,9 +396,9 @@ template < index_t NumCoord, typename Policy = DefaultTranspose, typename = std::enable_if_t::distr_encoding_valid, - Policy>> + typename BottomTensorView_::DataType, + Policy>::distr_encoding_valid, + Policy>> CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution::type> -CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper; +CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper; } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index 6bcba4019c..e2a6ae6555 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -81,7 +81,7 @@ struct tensor_adaptor template CK_TILE_HOST_DEVICE static constexpr auto - get_transform_and_its_upper_dimension(number) + get_transform_and_its_upper_dimension(number) { // FIXME: length of bottom dimension is not known, since info about lower dim length are not // saved in transformation @@ -119,13 +119,13 @@ struct tensor_adaptor CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension() { - constexpr auto all_low_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - LowerDimensionHiddenIdss{}); + constexpr auto all_low_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); - constexpr auto all_up_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - UpperDimensionHiddenIdss{}); + constexpr auto all_up_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); @@ -461,7 +461,7 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor, sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus{}, number<0>{})); constexpr auto up_dim_hidden_idss = generate_tuple( - [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr { return typename arithmetic_sequence_gen{}); // new top dimension's hidden ids - constexpr auto unordered_new_top_dim_hidden_ids = unpack( - [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + constexpr auto unordered_new_top_dim_hidden_ids = + unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); constexpr auto new_top_dim_unordered2ordered = unpack( [](auto... 
xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{}); @@ -595,8 +595,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::get_lower_dimension_hidden_idss()[itran]; // sequence in, sequence out - constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr { auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); // shift hidden id so every dim id is unique @@ -619,8 +618,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return low_dim_hidden_ids_1_mod_; - } - (); + }(); return generate_sequence_v2( [&](auto i) constexpr { return number{}; }, @@ -643,8 +641,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::get_upper_dimension_hidden_idss()[itran]; // sequence in, constexpr tuple out - constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr { auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); // shift hidden id @@ -653,8 +650,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return up_dim_hidden_ids_1_mod_; - } - (); + }(); // constexpr tuple to sequence return generate_sequence_v2( diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index d7be5957c6..11e6b35c39 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -202,7 +202,7 @@ struct tile_distribution // FIXME: it's hacky to get Y index from Distributed-Index template CK_TILE_HOST_DEVICE static constexpr auto - get_y_indices_from_distributed_indices(DistributedIndices) + get_y_indices_from_distributed_indices(DistributedIndices) { constexpr auto ys_idx_arr = [] { array ys_idx; @@ -266,7 +266,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t // this returns a constexpr encoding of tile_distribution template CK_TILE_HOST_DEVICE constexpr auto - make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_) +make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_) { using RsLengths = typename StaticTileDistributionEncoding_::RsLengths; using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss; @@ -614,8 +614,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( constexpr auto src_y_maps = src_y_info[number<1>{}]; constexpr auto src_y_prefix_sum = src_y_info[number<2>{}]; - constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr - { + constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr { auto y_slice_sorted_origins = make_zero_multi_index(); auto y_slice_lengths = Encoding::detail::ys_lengths_; constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks(); @@ -685,8 +684,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps); return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths); - } - (); + }(); constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}]; constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}]; diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index d2b24ad54e..284efd5d70 100644 --- 
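Note: the tile_elementwise.hpp hunk untangles another clang-format-12 artifact: when the closing '>' of std::is_same_v abutted '&&', the old tool glued them into ')&&std::is_same_v' and split the condition at unreadable points; clang-format-18 spaces every logical operator and breaks between clauses. A sketch of the new shape; the types are placeholders for the SrcTensor/fp8 checks in the hunk:

    #include <cstdint>
    #include <type_traits>

    template <typename SrcT, typename DstT, int BufSize>
    constexpr bool use_packed_cast() // hypothetical predicate mirroring the if constexpr above
    {
        // clang-format-18 keeps ") &&" spaced and one clause per line, where
        // clang-format-12 produced ")&&std::is_same_v<...>".
        if constexpr((std::is_same_v<SrcT, float> || std::is_same_v<SrcT, double>) &&
                     std::is_same_v<DstT, std::int8_t> && (BufSize % 4 == 0))
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    static_assert(use_packed_cast<float, std::int8_t, 8>());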
a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -327,9 +327,8 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors) template CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) { - if constexpr((std::is_same_v || - std::is_same_v)&&std::is_same_v && + if constexpr((std::is_same_v || std::is_same_v) && + std::is_same_v && (SrcTensor::get_thread_buffer_size() % 4 == 0)) { return impl::cast_tile_pk_fp8_fp32(src_tensor); diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index c4b24fba93..b5a89e5f51 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -74,8 +74,9 @@ struct tile_window_linear static constexpr auto get_num_non_linear_access() { constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; - using ys_to_rhs_major = typename decltype( - typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = + typename decltype(typename Base::TileDstr{} + .get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear = [&]() { index_t cnt = 1; @@ -109,8 +110,9 @@ struct tile_window_linear static constexpr auto get_non_linear_access_map() { constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; - using ys_to_rhs_major = typename decltype( - typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = + typename decltype(typename Base::TileDstr{} + .get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear_map = [&]() { array m_{0}; index_t cumulative_len_ = 1; @@ -244,8 +246,9 @@ struct tile_window_linear { using SFC_Ys = typename Base::Traits::SFC_Ys; constexpr auto idx_ys = SFC_Ys::get_index(number{}); - using ys_to_rhs_major = typename decltype( - typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = + typename decltype(typename Base::TileDstr{} + .get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto modified_idx_ys = generate_tuple( [&](auto i_dim_y) { diff --git a/include/ck_tile/core/utility/debug.hpp b/include/ck_tile/core/utility/debug.hpp index 261bf50148..15f0718dc2 100644 --- a/include/ck_tile/core/utility/debug.hpp +++ b/include/ck_tile/core/utility/debug.hpp @@ -48,7 +48,7 @@ struct str_literal template constexpr std::tuple...> - makeTuple(std::index_sequence) noexcept +makeTuple(std::index_sequence) noexcept { return {}; } @@ -113,8 +113,8 @@ struct CK_PRINTF) const { using FMT1 = std::conditional_t()), - str_literal>; + decltype(default_format()), + str_literal>; constexpr auto fmt_v = FMT1::template duplicate_n(make_str_literal(" ")); constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix(); diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 95fb1bd834..c43a64edaa 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -58,8 +58,8 @@ struct detector>, Op, Args...> struct nonesuch { - ~nonesuch() = delete; - nonesuch(nonesuch const&) = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; void operator=(nonesuch const&) = delete; }; diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp index ed3b464660..6bd6e33bd3 100644 
--- a/include/ck_tile/core/utility/unary_element_function.hpp +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -49,7 +49,7 @@ struct composes /// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ composes(Ts&&...)->composes...>; +__host__ __device__ composes(Ts&&...) -> composes...>; template struct saturates @@ -57,8 +57,8 @@ struct saturates // NOTE: this function does not return SaturateType value // it is user's responsiblity to do further cast or not template - CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const - -> std::enable_if_t, AccType> + CK_TILE_HOST_DEVICE constexpr auto + operator()(const AccType& a_) const -> std::enable_if_t, AccType> { return clamp(a_, type_convert(numeric::lowest()), diff --git a/include/ck_tile/host/concat.hpp b/include/ck_tile/host/concat.hpp index c68b908149..e9ba9a7d7b 100644 --- a/include/ck_tile/host/concat.hpp +++ b/include/ck_tile/host/concat.hpp @@ -33,13 +33,14 @@ struct IsCharArray : std::true_type }; template -inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v || - IsCharArray::value || - std::is_same_v)&&...); +inline constexpr bool AllConvertibleToStringView = + ((std::is_convertible_v || IsCharArray::value || + std::is_same_v) && + ...); template -[[nodiscard]] auto concat(const Ts&... xs) - -> std::enable_if_t, std::string> +[[nodiscard]] auto +concat(const Ts&... xs) -> std::enable_if_t, std::string> { using ::operator<<; thread_local std::ostringstream oss; @@ -78,8 +79,8 @@ template } template -auto concatInto(std::string& result, const Ts&... xs) - -> std::enable_if_t, void> +auto concatInto(std::string& result, + const Ts&... xs) -> std::enable_if_t, void> { const std::size_t space = (1 + ... + getSize(xs)); result.reserve(result.size() + space); @@ -87,8 +88,8 @@ auto concatInto(std::string& result, const Ts&... xs) } template -[[nodiscard]] auto concat(const Ts&... xs) - -> std::enable_if_t, std::string> +[[nodiscard]] auto +concat(const Ts&... xs) -> std::enable_if_t, std::string> { std::string result; concatInto(result, xs...); diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 9b31a7889d..e03881a1c7 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -64,7 +64,7 @@ struct FillUniformDistribution return; // need to make each thread unique, add an offset to current seed std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin) - : std::random_device{}()); + : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); @@ -242,7 +242,7 @@ struct FillNormalDistribution return; // need to make each thread unique, add an offset to current seed std::mt19937 gen(seed_.has_value() ? 
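Note: in concat.hpp the variable template AllConvertibleToStringView is a fold over '&&', and clang-format-12 again fused the pack expansion into ')&&...)'; clang-format-18 spaces it as ') && ...)' and re-wraps the clauses. A reduced sketch; the trait set is trimmed to two checks and the name is a hypothetical cousin of the real one:

    #include <string_view>
    #include <type_traits>

    template <typename... Ts>
    inline constexpr bool AllStringLikeDemo = // illustrative cousin of AllConvertibleToStringView
        ((std::is_convertible_v<const Ts&, std::string_view> ||
          std::is_same_v<Ts, char>) &&
         ...);

    static_assert(AllStringLikeDemo<std::string_view, const char*, char>);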
(*seed_ + iw_begin) - : std::random_device{}()); + : std::random_device{}()); std::normal_distribution dis(mean_, std::sqrt(variance_)); std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); @@ -407,9 +407,10 @@ struct FillStepRange } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); @@ -428,9 +429,10 @@ struct FillConstant } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); @@ -512,9 +514,10 @@ struct FillTrigValue } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index ecbc009b85..c3f1b7d221 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -378,7 +378,7 @@ struct HostTensor ~HostTensor() = default; HostTensor& operator=(const HostTensor&) = default; - HostTensor& operator=(HostTensor&&) = default; + HostTensor& operator=(HostTensor&&) = default; template explicit HostTensor(const HostTensor& other) : HostTensor(other.template CopyAsType()) diff --git a/include/ck_tile/host/joinable_thread.hpp b/include/ck_tile/host/joinable_thread.hpp index a822f967dc..a42b567fb4 100644 --- a/include/ck_tile/host/joinable_thread.hpp +++ b/include/ck_tile/host/joinable_thread.hpp @@ -15,7 +15,7 @@ struct joinable_thread : std::thread { } - joinable_thread(joinable_thread&&) = default; + joinable_thread(joinable_thread&&) = default; joinable_thread& operator=(joinable_thread&&) = default; ~joinable_thread() diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index 1e877b9933..b7615d0478 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -9,7 +9,7 @@ namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ - static_cast(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) + static_cast(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24)) template CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 7ae63e17a7..d42f144baa 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -284,8 +284,8 @@ struct CShuffleEpilogue {0, 0}); using SFC = space_filling_curve, - sequence<0, 1>, - sequence>; + sequence<0, 1>, + sequence>; constexpr index_t num_access = SFC::get_num_of_access(); 
static_assert(std::is_same_v, @@ -336,8 +336,8 @@ struct CShuffleEpilogue const auto c_ds_tiles = concat_tuple_of_reference( tie(c_out_tensor, c_out_tensor), - generate_tie( - [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; }, + number{})); tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index edb5853c7f..54f2a777bf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -458,7 +458,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_flat_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 837aeb13e3..cc00000efc 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -431,12 +431,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::BDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< typename Problem::ADataType, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 0b8e5836cd..3489d6f9a1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -509,7 +509,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto - MakeKLdsStoreBlockDescriptor(number = number<0>{}) + MakeKLdsStoreBlockDescriptor(number = number<0>{}) { // K is always k-major, we use async-copy to load into LDS constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 76ba34115f..570cff8bf0 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -60,8 +60,8 @@ struct TileFmhaShape // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; using VLayout = std::conditional_t; + ck_tile::tensor_layout::gemm::RowMajor, + ck_tile::tensor_layout::gemm::ColumnMajor>; }; template (kargs.o_ptr); auto o_view_ = make_naive_tensor_view( + memory_operation_enum::atomic_add>( o_ptr, make_tuple(kargs.num_tokens, kargs.hidden_size), make_tuple(kargs.stride_token, 1), diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp 
b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index db85fae643..a5f9f31d6a 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -13,7 +13,7 @@ namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ - static_cast(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) + static_cast(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24)) #ifndef MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_USE_EX_KERNEL 1 diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp index e9577e2304..17c38a2632 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp @@ -267,8 +267,7 @@ struct FusedMoeGemmPipeline_FlatmmEx statically_indexed_array as; auto gld_a = [&]>( - auto& a_store_, auto i_access, PreNop = {}) - { + auto& a_store_, auto i_access, PreNop = {}) { async_load_tile_raw(a_store_, a_win, i_access, PreNop{}); }; auto move_a = [&]() { @@ -278,43 +277,40 @@ struct FusedMoeGemmPipeline_FlatmmEx load_tile_raw(a_, win_, i_access); }; - auto gld_g = [&]>( - auto& g_, auto i_access, PreNop = {}) - { - if constexpr(IsGateOnly) - { - // TODO: hack! - if constexpr(i_access.value == 0) + auto gld_g = + [&]>(auto& g_, auto i_access, PreNop = {}) { + if constexpr(IsGateOnly) { - g_win.bottom_tensor_view_ = g_view; + // TODO: hack! + if constexpr(i_access.value == 0) + { + g_win.bottom_tensor_view_ = g_view; + } + else if constexpr(i_access.value == issues_g / 2) + { + g_win.bottom_tensor_view_ = u_view; + } } - else if constexpr(i_access.value == issues_g / 2) - { - g_win.bottom_tensor_view_ = u_view; - } - } - load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); - }; + load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); + }; auto move_g = [&]() { move_tile_window(g_win, {number<0>{}, number{}, number<0>{}}); }; statically_indexed_array ds; - auto gld_d = [&]>( - auto& d_, auto i_access, PreNop = {}) - { - load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); - }; + auto gld_d = + [&]>(auto& d_, auto i_access, PreNop = {}) { + load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); + }; auto move_d = [&]() { // d move along gemm-n move_tile_window(d_win, {number{}, number<0>{}}); }; - auto atomic_add_o = [&]>( - auto& o_, auto i_access, PreNop = {}) - { - update_tile_raw(o_win, o_, i_access, TRUE, PreNop{}); - }; + auto atomic_add_o = + [&]>(auto& o_, auto i_access, PreNop = {}) { + update_tile_raw(o_win, o_, i_access, TRUE, PreNop{}); + }; auto acc_0 = Policy::template MakeCBlockTile_Gemm0(); auto acc_1s = generate_tuple( diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 28e8bee908..0a6bacdc42 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -69,8 +69,8 @@ struct GemmTile2DPartitioner * @param blockIdy WGP's Y index. * @return const tuple Tuple containing 2D output C-tile index. 
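Note: the fused-MoE pipeline hunks above reformat C++20 lambdas that carry an explicit template parameter list (the '[&]<typename PreNop = ...>(...)' loaders): clang-format-12 pushed the opening brace onto its own line, clang-format-18 attaches it to the parameter list and re-indents the whole body. A compilable C++20 sketch of the shape; bool_constant_demo and the loader body are invented, the real lambdas issue tile loads:

    #include <type_traits>

    template <bool B>
    using bool_constant_demo = std::integral_constant<bool, B>; // stand-in for ck_tile::bool_constant

    int demo()
    {
        int issued = 0;
        // clang-format-18: brace attached at the end of the parameter list, body
        // indented one level, instead of the brace on its own line.
        auto gld = [&]<typename PreNop = bool_constant_demo<false>>(auto& dst, auto i_access,
                                                                    PreNop = {}) {
            if constexpr(PreNop::value)
                ++issued; // pre-nop path (placeholder)
            dst += static_cast<int>(i_access);
            ++issued;
        };

        int acc = 0;
        gld(acc, 3);
        gld(acc, 4, bool_constant_demo<true>{});
        return acc + issued; // 7 + 3 = 10
    }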
*/ - CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept - -> const tuple + CK_TILE_DEVICE static auto + GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple { const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx); const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy); @@ -137,8 +137,8 @@ struct GemmTile1DPartitioner * @param blockIdx WGP's index. * @return const tuple Tuple containing 2D output C-tile index. */ - CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept - -> const tuple + CK_TILE_DEVICE static auto + GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple { const index_t NBlocks = integer_divide_ceil(N_, NPerBlock); @@ -188,9 +188,8 @@ struct OffsettedTile1DPartitioner * @param [in] N Gemm's N dimension. * @return Returns a `tuple` [Im, In] with shifted index. */ - [[nodiscard]] CK_TILE_DEVICE static auto - GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept - -> const tuple + [[nodiscard]] CK_TILE_DEVICE static auto GetOffsetedTileIndex( + index_t block_start, index_t M, index_t N) noexcept -> const tuple { const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start); return make_tuple(iM, iN); @@ -271,8 +270,8 @@ struct GemmSpatiallyLocalTilePartitioner * @param [in] block_1d_id WGP's index. * @return const tuple Tuple containing 2D output C-tile index. */ - CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept - -> const tuple + CK_TILE_DEVICE auto + GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple { const auto M0 = integer_divide_ceil(M, MPerBlock); const auto N0 = integer_divide_ceil(N, NPerBlock); diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 8716475869..921ea11720 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -144,8 +144,8 @@ struct GroupedGemmKernel // clang-format on } - CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector& gemm_descs) - -> std::size_t + CK_TILE_HOST static auto + GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } @@ -185,8 +185,8 @@ struct GroupedGemmKernel return dim3(grid_size, 1, 1); } - CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) - -> std::vector + CK_TILE_HOST static auto + MakeKargs(const std::vector& gemm_descs) -> std::vector { std::vector gemm_kernel_args_; index_t group_count = ck_tile::type_convert(gemm_descs.size()); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index 4e9a70140e..7d88c804f3 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -28,20 +28,20 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy (DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) == (WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size()); constexpr auto wg_attr_num_access = - ((is_a_load_tr || is_b_load_tr)&&!single_load_tr_length) + ((is_a_load_tr || is_b_load_tr) && !single_load_tr_length) ? 
WGAttrNumAccessEnum::Double : WGAttrNumAccessEnum::Single; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::BDataType, + typename Problem::CDataType, // AccDataType + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC, + false, + false, + wg_attr_num_access>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + typename Problem::BDataType, + typename Problem::CDataType, // AccDataType + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + typename Problem::ComputeDataType, + AccDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy(); using TileEncodingPattern = TileDistributionEncodingPattern2D; + KPerBlock, + NPerBlock, + VecLoadSize, + BTileAccessPattern>; constexpr auto BK0 = number{}; constexpr auto BK1 = number{}; @@ -636,15 +636,15 @@ struct UniversalGemmPipelineAgBgCrPolicy : WGAttrNumAccessEnum::Invalid; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::ComputeDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC, + false, + Problem::UseStructuredSparsity, + wg_attr_num_access>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + typename Problem::BDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy // 2. bf8, fp32, bf8 -> f32 // 3. i4, (fp8/fp32) fp8 -> f32 // 4. i4, (fp8/fp32) bf8 -> f32 - static_assert( - (std::is_same_v || std::is_same_v || - std::is_same_v< - ADataType, - bf8_t>)&&(std::is_same_v || - std::is_same_v< - BDataType, - bf8_t>)&&(std::is_same_v || - std::is_same_v || - std::is_same_v< - AQDataType, - ck_tile::bf8_t>)&&(std::is_same_v || - std::is_same_v)&&std:: - is_same_v); + static_assert((std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v) && + (std::is_same_v || + std::is_same_v || + std::is_same_v) && + (std::is_same_v || + std::is_same_v) && + std::is_same_v); static constexpr index_t InterWaveSchedulingMacClusters = 1; diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 83b61e23fc..2004f7d90e 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -44,12 +44,12 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t VecLoadSize = GetVectorSizeAQ(); using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::ComputeDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + false>; static_assert(std::is_same_v); using TileEncodingPattern = TileDistributionEncodingPatternAQ(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - 
static_cast(args.filter_spatial_lengths_[0])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -106,15 +106,15 @@ struct GroupedConvBwdWeightKernelArgs CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -122,13 +122,13 @@ struct GroupedConvBwdWeightKernelArgs static_cast(args.output_spatial_lengths_[1])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1])}; + static_cast(args.conv_filter_strides_[1])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1])}; + static_cast(args.input_left_pads_[1])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1])}; + static_cast(args.input_right_pads_[1])}; k_batch = args.k_batch; @@ -182,17 +182,17 @@ struct GroupedConvBwdWeightKernelArgs CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1]), - static_cast(args.input_spatial_lengths_[2])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1]), - static_cast(args.filter_spatial_lengths_[2])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -201,17 +201,17 @@ struct GroupedConvBwdWeightKernelArgs static_cast(args.output_spatial_lengths_[2])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1]), - static_cast(args.conv_filter_strides_[2])}; + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1]), static_cast(args.conv_filter_dilations_[2])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1]), - 
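Nearly all of the kernel-argument hunks in this file (and in the forward-convolution kernel below) are one mechanical change: clang-format-18 aligns the continuation lines of a braced initializer under its first element instead of indenting them one extra level. A minimal sketch, with a made-up `HostArgsSketch` standing in for `GroupedConvBwdWeightHostArgs`:

#include <array>
#include <cstdint>

using index_t = std::int32_t;

struct HostArgsSketch // hypothetical stand-in for the host-args struct
{
    long G_, N_, C_, Wi_;
};

struct KernelArgsSketch
{
    std::array<index_t, 4> in_g_n_c_wis_lengths;

    explicit KernelArgsSketch(const HostArgsSketch& args)
    {
        // clang-format-18 anchors every continuation line under `static_cast`:
        in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
                                static_cast<index_t>(args.N_),
                                static_cast<index_t>(args.C_),
                                static_cast<index_t>(args.Wi_)};
    }
};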
static_cast(args.input_left_pads_[2])}; + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1]), - static_cast(args.input_right_pads_[2])}; + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; k_batch = args.k_batch; @@ -254,8 +254,9 @@ struct GroupedConvBwdWeightKernelArgs GemmBatch = args.G_; } - using ABCGridDescs = remove_cvref_t; + using ABCGridDescs = + remove_cvref_t; using AGridDescMK = remove_cvref_t{}])>; using BGridDescNK = remove_cvref_t{}])>; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index f979d96326..8cd1710043 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -37,13 +37,13 @@ struct GroupedConvFwdKernelArgs CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -107,15 +107,15 @@ struct GroupedConvFwdKernelArgs CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -123,13 +123,13 @@ struct GroupedConvFwdKernelArgs static_cast(args.output_spatial_lengths_[1])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1])}; + static_cast(args.conv_filter_strides_[1])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1])}; + static_cast(args.input_left_pads_[1])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1])}; + static_cast(args.input_right_pads_[1])}; k_batch = args.k_batch; @@ -184,17 +184,17 @@ struct GroupedConvFwdKernelArgs CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - 
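The `ABCGridDescs` hunk above and the `AGridDescMK`/`BGridDescNK`/`CGridDescMN` hunks just below re-wrap a recurring CK idiom: naming the otherwise unutterable return type of a descriptor factory with `decltype` plus `remove_cvref_t`. A self-contained sketch of the idiom using `std::remove_cvref_t` (CK's own `remove_cvref_t` plays the same role); `TransformerSketch` is hypothetical:

#include <type_traits>
#include <utility>

struct TransformerSketch // stand-in for ConvToGemmFwdTransformer
{
    static auto MakeADescriptor_M_K() { return std::pair<int, int>{0, 0}; }
};

// Name the heavily templated return type once, then reuse it; clang-format-18
// now breaks right after `remove_cvref_t<` when the decltype does not fit.
using AGridDescMK =
    std::remove_cvref_t<decltype(TransformerSketch::MakeADescriptor_M_K())>;

static_assert(std::is_same_v<AGridDescMK, std::pair<int, int>>);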
static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1]), - static_cast(args.input_spatial_lengths_[2])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1]), - static_cast(args.filter_spatial_lengths_[2])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -203,17 +203,17 @@ struct GroupedConvFwdKernelArgs static_cast(args.output_spatial_lengths_[2])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1]), - static_cast(args.conv_filter_strides_[2])}; + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1]), static_cast(args.conv_filter_dilations_[2])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1]), - static_cast(args.input_left_pads_[2])}; + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1]), - static_cast(args.input_right_pads_[2])}; + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; k_batch = args.k_batch; @@ -259,15 +259,15 @@ struct GroupedConvFwdKernelArgs group_stride_c = args.K_; } - using AGridDescMK = remove_cvref_t())>; - using BGridDescNK = remove_cvref_t())>; - using CGridDescMN = remove_cvref_t())>; + using AGridDescMK = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeADescriptor_M_K())>; + using BGridDescNK = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeBDescriptor_N_K())>; + using CGridDescMN = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeCDescriptor_M_N())>; static constexpr index_t NonSpatialDims = 3; array in_g_n_c_wis_lengths; diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 48aaed3aae..b173ab25a1 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -67,11 +67,11 @@ struct GroupedConvTraits using DsLayout = DsLayout_; using OutLayout = OutLayout_; using GroupedConvImplicitGemmTraits = TileGemmTraits; + true, + true, + ck_tile::tensor_layout::gemm::RowMajor, + ck_tile::tensor_layout::gemm::ColumnMajor, + ck_tile::tensor_layout::gemm::RowMajor>; static constexpr index_t NumDTensor = DsLayout::size(); using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); }; diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index c93329bfbe..434be9f84a 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -380,6 +380,6 @@ struct 
BlockReduce2D // deduction guide template -CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D; +CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D; } // namespace ck_tile diff --git a/include/ck_tile/ref/naive_attention.hpp b/include/ck_tile/ref/naive_attention.hpp index 98ceab6992..172fcee2e3 100644 --- a/include/ck_tile/ref/naive_attention.hpp +++ b/include/ck_tile/ref/naive_attention.hpp @@ -695,18 +695,18 @@ struct naive_attention_fwd_kernel static_cast(variation_), \ static_cast(quant_algo_)>; \ using k_ = naive_attention_fwd_kernel; \ + k_type_, \ + v_type_, \ + o_type_, \ + acc_type_, \ + kvscale_type_, \ + q_layout_, \ + k_layout_, \ + v_layout_, \ + o_layout_, \ + k_scale_layout_, \ + v_scale_layout_, \ + ktraits_>; \ dim3 grids = k_::get_grid_size(a); \ r = ck_tile::launch_kernel(s, \ ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \ diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index 9f2ef3389f..6f5a425207 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -1,14 +1,8 @@ -from datetime import datetime -import pathlib -from pathlib import Path -import subprocess -import os -import copy +from datetime import datetime import pathlib from pathlib import Path import subprocess import os + import copy -NS = 'ck_tile' -OPS = 'ops' -REF = 'ref' -OPS_COMMON = 'common' # common header will be duplicated into ops/* other module + NS = 'ck_tile' OPS = 'ops' REF = 'ref' OPS_COMMON = + 'common' #common header will be duplicated into ops/* other module HEADER_COMMON = f"""// SPDX-License-Identifier: MIT // Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n @@ -82,7 +76,7 @@ submodule = submodule_t() # formatting for x in all_files: subprocess.Popen(f'dos2unix {str(x)}', shell=True) - cmd = f'clang-format-12 -style=file -i {str(x)}' + cmd = f'clang-format-18 -style=file -i {str(x)}' #for xp in x.parents: #print(get_file_base(x)) subprocess.Popen(cmd, shell=True) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index 120bf7484a..59dfd76ede 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -116,7 +116,7 @@ struct ReferenceMoeGemm : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp index eedd687bde..9f04cf3e3d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp @@ -110,7 +110,7 @@ struct ReferenceMoeGemm1BlockScale : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index 2c2cac77e3..28274a5154 100644 --- 
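The one-line `BlockReduce2D` hunk above is another systematic clang-format-18 change: `->` in deduction guides (and in the trailing return types rewrapped throughout this patch) now always gets surrounding spaces, so `)->BlockReduce2D` becomes `) -> BlockReduce2D`. A minimal deduction-guide sketch with a hypothetical `Reduce2DSketch`:

template <typename T>
struct Reduce2DSketch // stand-in for BlockReduce2D
{
    Reduce2DSketch(const T&, const typename T::DataType&) {}
};

struct DistSketch
{
    using DataType = float;
};

// clang-format-12 printed this guide as `...)->Reduce2DSketch<T>;`
template <typename T>
Reduce2DSketch(const T&, const typename T::DataType&) -> Reduce2DSketch<T>;

int main()
{
    Reduce2DSketch r{DistSketch{}, 1.0f}; // deduces Reduce2DSketch<DistSketch>
    return 0;
}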
a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -25,17 +25,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - naive_gemm_kernel(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - CDataType* __restrict__ p_c_grid, - index_t m, - index_t n, - index_t k, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + naive_gemm_kernel(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + index_t m, + index_t n, + index_t k, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { using RowMajor = ck::tensor_layout::gemm::RowMajor; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp index 681f466677..2f0c6113de 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp @@ -23,8 +23,9 @@ template using S = ck::Sequence; template -using device_column_to_image_bf16_instances = std::tuple< - // clang-format off +using device_column_to_image_bf16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -39,12 +40,13 @@ using device_column_to_image_bf16_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_column_to_image_f16_instances = std::tuple< - // clang-format off +using device_column_to_image_f16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -59,12 +61,13 @@ using device_column_to_image_f16_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_column_to_image_f32_instances = std::tuple< - // clang-format off +using device_column_to_image_f32_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -76,12 +79,13 @@ using device_column_to_image_f32_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4> - // clang-format on - >; + // clang-format on + >; template -using device_column_to_image_i8_instances = std::tuple< - // 
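Every `device_column_to_image_*_instances` hunk above (and the `device_image_to_column_*_instances` hunks that follow) is the same move: clang-format-18 wraps the alias target onto its own line, so the `// clang-format off` / `// clang-format on` guards and the closing `>;` shift to match, while the guarded instance table itself stays byte-identical. The shape, sketched with a trivial two-entry list; `DeviceOpSketch` is hypothetical:

#include <tuple>

template <int NDimSpatial>
struct DeviceOpSketch // stand-in for DeviceColumnToImageImpl
{
};

template <int NDimSpatial>
using device_op_instances =
    std::tuple<
        // clang-format off
        // Block| MPer| KPer|  <- header comments stay untouched because the
        //  Size| Block| Block|   guards shield them from the formatter bump
        DeviceOpSketch<NDimSpatial>,
        DeviceOpSketch<NDimSpatial>
        // clang-format on
        >;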
clang-format off +using device_column_to_image_i8_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -97,8 +101,8 @@ using device_column_to_image_i8_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 8>, DeviceColumnToImageImpl, 16> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp index 74a2155a04..2d2798b667 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp @@ -23,8 +23,9 @@ template using S = ck::Sequence; template -using device_image_to_column_bf16_instances = std::tuple< - // clang-format off +using device_image_to_column_bf16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -38,12 +39,13 @@ using device_image_to_column_bf16_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_image_to_column_f16_instances = std::tuple< - // clang-format off +using device_image_to_column_f16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -58,12 +60,13 @@ using device_image_to_column_f16_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_image_to_column_f32_instances = std::tuple< - // clang-format off +using device_image_to_column_f32_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -75,12 +78,13 @@ using device_image_to_column_f32_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4> - // clang-format on - >; + // clang-format on + >; template -using device_image_to_column_i8_instances = std::tuple< - // clang-format off +using device_image_to_column_i8_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -96,8 +100,8 @@ using device_image_to_column_i8_instances = std::tuple< DeviceImageToColumnImpl, 4>, 
DeviceImageToColumnImpl, 8>, DeviceImageToColumnImpl, 16> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp index 0c44ca6613..1da94059b0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp @@ -38,8 +38,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_km_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_km_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -56,8 +57,8 @@ using device_gemm_xdl_universal_km_kn_mn_comp_instances = std::tuple< DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_dl_f32_instances = std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_dl_f32_instances = + std::tuple< + // clang-format off //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_dl_f16_instances = std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_dl_f16_instances = + std::tuple< + // clang-format off //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| 
SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_dl_bf16_instances = std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_dl_bf16_instances = + std::tuple< + // clang-format off //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp index 40c4d558b8..47cb9a88a4 100644 
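A reading aid for the instance tables on both sides of this point: the `S<...>` packs threaded through every `DeviceGroupedConvBwdWeight_Dl` and `DeviceBatchedGemm_Wmma_CShuffleV3` row are the `using S = ck::Sequence` shorthand declared at the top of each instance header, a compile-time integer list encoding thread-cluster lengths, access orders, and the like. A self-contained sketch with a minimal stand-in for `ck::Sequence`:

#include <cstdint>

using index_t = std::int32_t;

template <index_t... Is>
struct Sequence // minimal stand-in for ck::Sequence
{
    static constexpr index_t size() { return sizeof...(Is); }
};

template <index_t... Is>
using S = Sequence<Is...>; // the shorthand used throughout the tables

static_assert(S<4, 64, 1>::size() == 3); // e.g. an ABlockTransfer cluster shape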
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp @@ -37,9 +37,8 @@ template -using device_grouped_conv_bwd_weight_wmma_f16_instances = - std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_wmma_f16_instances = std::tuple< + // clang-format off //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| @@ -71,17 +70,16 @@ using device_grouped_conv_bwd_weight_wmma_f16_instances = DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_wmma_i8_instances = - std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_wmma_i8_instances = std::tuple< + // clang-format off //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| @@ -110,8 +108,8 @@ using device_grouped_conv_bwd_weight_wmma_i8_instances = DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp index 659d6a99a9..34b580cf75 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp @@ -31,9 +31,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -47,8 +46,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, 
S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances = std::tuple< + // clang-format off 
//################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -52,8 +51,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| 
ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -55,8 +54,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -47,8 +46,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, 
GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances = std::tuple< + 
// clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -52,8 +51,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| 
ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -55,8 +54,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( std::vector; // FIXME: retire dedicated 2D version -using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = - std::tuple< - // clang-format off +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off //#####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //#####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#####################################################################| | | | | Operation| Operation| Operation| 
Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| @@ -122,8 +121,8 @@ using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instan DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( std::vector -using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = - std::tuple< - // clang-format off +using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = std::tuple< + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| @@ -51,8 +50,8 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( 
std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> - // clang-format on - >; + // clang-format on + >; // double rate mfma instances on gfx950 -using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances_2x = - std::tuple< - // clang-format off +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances_2x = std::tuple< + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, 
Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 256, 64, 64, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp index 5a0c52c2df..ab5f40e81d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp index 59ffb80bd4..6f368a44d3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[k, 
n] = c[m, n]
-using Instances =
- std::tuple<
+using Instances = std::tuple<
// clang-format off
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
@@ -20,8 +19,8 @@ using Instances =
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
- // clang-format on
- >;
+ // clang-format on
+ >;
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_opt_instances(
OwnerList& instances)
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
index a64424e8ac..7049732e41 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
@@ -9,8 +9,7 @@
namespace device {
namespace instance {
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
-using Instances =
- std::tuple<
+using Instances = std::tuple<
// clang-format off
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
@@ -27,8 +26,8 @@ using Instances =
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
- // clang-format on
- >;
+ // clang-format on
+ >;
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_interwave_pipeline_v1_instances(
OwnerList& instances)
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
index a0dd60c0f5..eef7e728d2 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
@@ -9,9 +9,8 @@
namespace device {
namespace instance {
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
-using Instances =
- std::tuple<
- // clang-format off
+using Instances = std::tuple<
+ // clang-format off
// pipeline v1, 1 wave
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline|
//##########|
Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | @@ -25,8 +24,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp index 122fff4960..e966b3ec49 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp index 9f459aabfc..e090b157b3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp @@ 
-9,8 +9,7 @@
namespace device {
namespace instance {
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
-using Instances =
- std::tuple<
+using Instances = std::tuple<
// clang-format off
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
@@ -20,8 +19,8 @@ using Instances =
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
- // clang-format on
- >;
+ // clang-format on
+ >;
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_opt_instances(
OwnerList& instances)
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
index 3671bea7a3..811358a3d3 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
@@ -9,8 +9,7 @@
namespace device {
namespace instance {
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
-using Instances =
- std::tuple<
+using Instances = std::tuple<
// clang-format off
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
@@ -27,8 +26,8 @@ using Instances =
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
- // clang-format on
- >;
+ // clang-format on
+ >;
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_interwave_pipeline_v1_instances(
OwnerList& instances)
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
index 98db8bad1c..a9ee03ca49 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
@@ -9,9 +9,8 @@
namespace device {
namespace instance {
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
-using Instances =
- std::tuple<
- // clang-format off
+using Instances = std::tuple<
+ // clang-format off
// pipeline v1, 1 wave
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer|
BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | @@ -34,8 +33,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp index 532c348b7e..d4e5ab8014 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -36,8 +35,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp index b931b8fdfd..03fdf13bc4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -20,8 +19,8 @@ using Instances = //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_opt_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp index fa53a3bf0f..c3ab756f3b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -36,8 +35,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp index a590413acc..aa895fc0cd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -28,9 +28,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = - std::tuple< - // clang-format off +using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| 
NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| @@ -43,8 +42,8 @@ using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 
1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( std::vector -using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< - // clang-format off +using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = + std::tuple< + // clang-format off //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| 
AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -46,8 +47,8 @@ using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = st DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> - // clang-format on - >; + // clang-format on + >; template using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp index 430daae3ab..06d6780227 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto 
Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp index 9b876f5430..fd938f502f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using 
device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -51,8 +50,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 65261235b6..87300fa871 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = - std::tuple< - // clang-format off +using 
device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -54,8 +53,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp index dc770d8d9a..902e349492 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off 
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -57,8 +56,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp
index 266e6b1a5d..a439cf27f5 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp
@@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -49,8 +48,8 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp
index 1674b2de6c..55e0362018 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp
@@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -51,8 +50,8 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp
index 758420ca37..e51de0556c 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp
@@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -54,8 +53,8 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp
index dad402dff4..722a0bae55 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp
@@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -57,8 +56,8 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp
index ee15dfa94e..d10b9facd5 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -50,8 +49,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp
index 93039a5008..d9d16ede65 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -51,8 +50,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp
index 1dc9678c5b..9277e5e901 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp
index e4682c27d3..e97a649c19 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp
index 0c601b3823..c8f1b85ddb 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -51,8 +50,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp
index 8d11b6f9d9..fc0220a502 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp
index d389da5ee8..b87cf64b0f 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp
index 001330eabb..31ad66409e 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp
@@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm|
 //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
@@ -52,8 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances =
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp
index 59154f3439..a6b6465128 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp
@@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -53,8 +54,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tu
         DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;

 template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
 using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
index b962d75b12..e0bbe7dff0 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
@@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -55,8 +56,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu
         DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
index 9f142ad831..5cb767ab0f 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
@@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -48,8 +49,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tu
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
index 7d141a47e1..ac29d1ba9c 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
@@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -48,8 +49,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tu
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp
index 8d109d1346..1a8227279d 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp
@@ -53,9 +53,8 @@ using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_comp_instances = std::tupl
 #endif

 template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
-using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances =
-    std::tuple<
-        // clang-format off
+using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances = std::tuple<
+    // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | |
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| | | | |
@@ -79,8 +78,8 @@ using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances =
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>
-        // clang-format on
-        >;
+    // clang-format on
+    >;
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp
index 940da94e70..a160f84175 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp
@@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -51,8 +52,8 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
index d83014d5e8..2f043cef03 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
@@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -63,8 +64,8 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
index ff13de1d6a..0d72da9e6e 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
@@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -48,8 +49,8 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 // instances not working on gfx950
 template <GemmSpecialization GemmSpec>
 using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances_part2 = std::tuple<
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp
index bb10da37f4..c763b5048c 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp
@@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
 static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

 template <GemmSpecialization GemmSpec>
-using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple<
-    // clang-format off
+using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances =
+    std::tuple<
+        // clang-format off
 //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
 //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
@@ -45,8 +46,8 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp index 680788d668..63300d2c37 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp @@ -53,8 +53,9 @@ using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_comp_instances = std::tuple< #endif template -using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | | //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| | | | | @@ -78,8 +79,8 @@ using 
device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp index 5c525244e1..783606ef9d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp @@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| 
Verision| @@ -50,8 +51,8 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< // We prefer following instance, however, existing compiler bug cause it failed to generate sanity code. // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp index af4008c91d..bece6b4c30 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp @@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -47,8 +48,8 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 
256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp index b4554fc6a9..f03dc4fc8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -54,8 +55,8 @@ using 
device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp index b6a60a1f31..7f1976f220 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -59,8 +60,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp index 5353fe16b5..93ac0d7dcc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| 
MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -51,8 +52,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp index 959c1c0992..b2e3252e4d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| 
Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -61,8 +62,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 16, 1, 16>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp index 282cea7563..a318627bea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -54,8 +55,8 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp index 7335a9851f..92e5c86343 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| 
_MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -62,8 +63,8 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp index d03002af5c..f83b0a47c9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -51,8 +52,8 @@ using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, 
Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp index 7736f38cb2..2de3ed35b0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -49,8 +50,8 @@ using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_instances = std::tuple< diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp index 57b6ab3ae2..a38eef7294 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -52,8 +53,8 @@ using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_instances = std // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp index 14bd36d29f..d2e15f01da 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using 
device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -49,8 +50,8 @@ using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp index 839d3559f7..2344108576 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp @@ -80,9 +80,8 @@ template -using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = - std::tuple< - // clang-format off +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| 
         //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
@@ -99,8 +98,8 @@ using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances =
         // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
         // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
         // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
-    // clang-format on
-    >;
+        // clang-format on
+        >;
 
 } // namespace instance
 } // namespace device
diff --git a/library/src/utility/convolution_parameter.cpp b/library/src/utility/convolution_parameter.cpp
index a71f8a4fa1..634b7f0890 100644
--- a/library/src/utility/convolution_parameter.cpp
+++ b/library/src/utility/convolution_parameter.cpp
@@ -215,9 +215,8 @@ ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, ch
 
 std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p)
 {
-    os << "ConvParam {"
-       << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_ << "\nN: " << p.N_
-       << "\nK: " << p.K_ << "\nC: " << p.C_
+    os << "ConvParam {" << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_
+       << "\nN: " << p.N_ << "\nK: " << p.K_ << "\nC: " << p.C_
        << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_
        << "\ninput_spatial_lengths: " << p.input_spatial_lengths_
        << "\nconv_filter_strides: " << p.conv_filter_strides_
diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp
index b70dd9538d..5ea1a78094 100644
--- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp
+++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp
@@ -260,9 +260,9 @@ bool profile_conv_bwd_data_impl(int do_verification,
         }
     }
 
-    std::cout << "Best configuration parameters:"
-              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
-              << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
+    std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
+              << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
+              << "\nGB/s: " << best_gb_per_sec << std::endl;
 
     return pass;
 }
diff --git a/profiler/include/profiler/profile_conv_fwd_impl.hpp b/profiler/include/profiler/profile_conv_fwd_impl.hpp
index 917e4c07fc..37366821c4 100644
--- a/profiler/include/profiler/profile_conv_fwd_impl.hpp
+++ b/profiler/include/profiler/profile_conv_fwd_impl.hpp
@@ -233,9 +233,9 @@ bool profile_conv_fwd_impl(int do_verification,
         }
     }
 
-    std::cout << "Best configuration parameters:"
-              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
-              << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
+    std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
+              << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
+              << "\nGB/s: " << best_gb_per_sec << std::endl;
 
     return pass;
 }
diff --git a/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp b/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp
index fa0a771962..14182bb7b0 100644
--- a/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp
+++ b/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp
@@ -288,9 +288,8 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
         }
     }
 
-    std::cout << "Best configuration parameters:"
-              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
-              << "\nGB/s: " << best_gb_per_sec << std::endl;
+    std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
+              << "\navg_time: " << best_avg_time << "\nGB/s: " << best_gb_per_sec << std::endl;
 
     return is_supporting_instance && pass;
 }
diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
index 12f6ad606f..0aeefaabfb 100644
--- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
@@ -287,10 +287,9 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
         }
     }
 
-    std::cout << "Best configuration parameters:"
-              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
-              << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK "
-              << best_split_k << std::endl;
+    std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
+              << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
+              << "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl;
 
     return pass;
 }
diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
index c1bb90dd9c..84acb53425 100644
--- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
@@ -92,12 +92,12 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
     if(do_verification)
     {
         auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight{};
+                                                                           InDataType,
+                                                                           WeiDataType,
+                                                                           OutDataType,
+                                                                           InElementOp,
+                                                                           WeiElementOp,
+                                                                           OutElementOp>{};
         auto ref_invoker  = ref_conv.MakeInvoker();
         auto ref_argument = ref_conv.MakeArgument(input,
                                                   weight_host_result,
@@ -302,10 +302,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
         }
     }
 
-    std::cout << "Best configuration parameters:"
-              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
-              << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK "
-              << best_split_k << std::endl;
+    std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
+              << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
+              << "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl;
 
     return all_pass;
 }
diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp
index c12fa75e34..d0e1cf2611 100644
--- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp
@@ -178,8 +178,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
                                                         in_element_op,
                                                         wei_element_op,
                                                         out_element_op,
-            {},
-            {},
+                                                        {},
+                                                        {},
                                                         d_tensors);
 
     // init host output to zero
@@ -312,9 +312,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
         run_impl(op_ptr, argument_ptr);
     }
 
-    std::cout << "Best configuration parameters:"
-              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
-              << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
+    std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
+              << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
+              << "\nGB/s: " << best_gb_per_sec << std::endl;
 
     return pass;
 }
diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
index a1f9ee1528..2dcee4c1fc 100644
--- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
@@ -250,9 +250,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
         run_impl(op_ptr, argument_ptr);
     }
 
-    std::cout << "Best configuration parameters:"
-              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
-              << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
+    std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
+              << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
+              << "\nGB/s: " << best_gb_per_sec << std::endl;
 
     return pass;
 }
diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp
index bd756eb825..b553e07735 100644
--- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp
@@ -342,9 +342,9 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
         run_impl(op_ptr, argument_ptr);
     }
 
-    std::cout << "Best configuration parameters:"
-              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
-              << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
+    std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
+              << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
+              << "\nGB/s: " << best_gb_per_sec << std::endl;
 
     return pass;
 }
diff --git a/profiler/include/profiler/profile_softmax_impl.hpp b/profiler/include/profiler/profile_softmax_impl.hpp
index daaf565149..83913d8398 100644
--- a/profiler/include/profiler/profile_softmax_impl.hpp
+++ b/profiler/include/profiler/profile_softmax_impl.hpp
@@ -103,12 +103,12 @@ bool profile_softmax_impl(int do_verification,
     // add device softmax instances
     using PassThrough = ck::tensor_operation::element_wise::PassThrough;
     using DeviceOp    = tensor_operation::device::DeviceSoftmax;
+                                                                AccDataType,
+                                                                OutDataType,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                Rank,
+                                                                NumReduceDim>;
 
     // get device op instances
     const auto instances = tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -141,8 +141,7 @@ bool profile_softmax_impl(int do_verification,
         {
             std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
             LogRange(std::cout << "input lengths = [", in_length, ", ")
-                << "], "
-                << "scaler = [" << alpha << ", " << beta << "]";
+                << "], " << "scaler = [" << alpha << ", " << beta << "]";
             LogRange(std::cout << ", reduce dims = [", reduce_dims, ", ") << "]." << std::endl;
             instance_pass.push_back(true);
             continue;
@@ -202,8 +201,7 @@ bool profile_softmax_impl(int do_verification,
         {
             std::cout << inst_ptr->GetTypeString() << " failed verification: ";
             LogRange(std::cout << "input lengths = [", in_length, ", ")
-                << "], "
-                << "scaler = [" << alpha << ", " << beta << "]." << std::endl;
+                << "], " << "scaler = [" << alpha << ", " << beta << "]." << std::endl;
         }
         instance_pass.push_back(pass);
     }
@@ -215,9 +213,8 @@ bool profile_softmax_impl(int do_verification,
         LogRange(std::cout << "length = ", in_tensor_lengths, ",") << ", ";
         LogRange(std::cout << "stride = ", in_tensor_strides, ",") << ", ";
         LogRange(std::cout << "reduce dims ", reduce_dims, ",") << ", ";
-        std::cout << "alpha = " << alpha << ", "
-                  << "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec
-                  << " GB/s, " << best_instance_name << std::endl;
+        std::cout << "alpha = " << alpha << ", " << "beta = " << beta << ", " << best_avg_time
+                  << " ms, " << best_gb_per_sec << " GB/s, " << best_instance_name << std::endl;
     }
     return std::all_of(
         std::begin(instance_pass), std::end(instance_pass), [](bool p) { return p; });
diff --git a/profiler/src/profile_contraction_bilinear.cpp b/profiler/src/profile_contraction_bilinear.cpp
index 990e1e1196..a64555fc66 100644
--- a/profiler/src/profile_contraction_bilinear.cpp
+++ b/profiler/src/profile_contraction_bilinear.cpp
@@ -29,8 +29,7 @@ static void print_helper_msg()
          << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
             "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
          << "arg6: verification (0: no; 1: yes)\n"
-         << "arg7: initialization (0: no init; 1: integer value; 2: decimal "
-         << "value)\n"
+         << "arg7: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n"
          << "arg8: print tensor value (0: no; 1: yes)\n"
          << "arg9: time kernel (0: no, 1: yes)\n"
          << "arg10: alpha\n"
diff --git a/profiler/src/profile_contraction_scale.cpp b/profiler/src/profile_contraction_scale.cpp
index 85252eaa37..a168c09bcf 100644
--- a/profiler/src/profile_contraction_scale.cpp
+++ b/profiler/src/profile_contraction_scale.cpp
@@ -29,8 +29,7 @@ static void print_helper_msg()
          << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
            "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
          << "arg6: verification (0: no; 1: yes)\n"
-         << "arg7: initialization (0: no init; 1: integer value; 2: decimal "
-         << "value)\n"
+         << "arg7: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n"
         << "arg8: print tensor value (0: no; 1: yes)\n"
         << "arg9: time kernel (0: no, 1: yes)\n"
         << "arg10: alpha\n"
diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh
index 728b8c1092..53de05a7d8 100755
--- a/script/clang-format-overwrite.sh
+++ b/script/clang-format-overwrite.sh
@@ -1,2 +1,2 @@
-find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
-git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
+find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}'
+git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}'
diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc
index b7cf891862..116d3798b9 100644
--- a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc
+++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc
@@ -110,8 +110,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     b_buf.ToDevice(b_host.data());
     gamma_buf.ToDevice(gamma_host.data());
 
-    std::cout << "[" << input_data_type << ", " << quantized_data_type << "]"
-              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
+    std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m
+              << ", n:" << n << ", stride:" << stride << std::flush;
 
     add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX};
diff --git a/test/ck_tile/data_type/test_pk_int4.cpp b/test/ck_tile/data_type/test_pk_int4.cpp
index 4e9fb20efc..1ccae88112 100644
--- a/test/ck_tile/data_type/test_pk_int4.cpp
+++ b/test/ck_tile/data_type/test_pk_int4.cpp
@@ -36,8 +36,8 @@ TEST(PackedInt4, ConvertToHalf)
     const half_t first_input_val  = ck_tile::type_convert<half_t>(7.f);
     const half_t second_input_val = ck_tile::type_convert<half_t>(-1.f);
 #else
-    const half_t first_input_val = ck_tile::type_convert<half_t>(-1.f);
-    const half_t second_input_val = ck_tile::type_convert<half_t>(7.f);
+    const half_t first_input_val  = ck_tile::type_convert<half_t>(-1.f);
+    const half_t second_input_val = ck_tile::type_convert<half_t>(7.f);
 #endif
     uint8_t data = 0b11110111; // {-1, 7}
     pk_int4_t in = ck_tile::bit_cast<pk_int4_t>(data);
@@ -53,8 +53,8 @@ TEST(PackedInt4, ConvertToBHalf)
     const bf16_t first_input_val  = ck_tile::type_convert<bf16_t>(7.f);
     const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(-1.f);
 #else
-    const bf16_t first_input_val = ck_tile::type_convert<bf16_t>(-1.f);
-    const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(7.f);
+    const bf16_t first_input_val  = ck_tile::type_convert<bf16_t>(-1.f);
+    const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(7.f);
 #endif
     uint8_t data = 0b11110111; // {-1, 7}
     pk_int4_t in = ck_tile::bit_cast<pk_int4_t>(data);
diff --git a/test/ck_tile/elementwise/test_elementwise_1d.cpp b/test/ck_tile/elementwise/test_elementwise_1d.cpp
index 5f327c7097..7013792335 100644
--- a/test/ck_tile/elementwise/test_elementwise_1d.cpp
+++ b/test/ck_tile/elementwise/test_elementwise_1d.cpp
@@ -36,11 +36,9 @@ struct elementwise_op_traits
 template <std::size_t D, typename F>
 auto make_uniform_array_with_factory(F&& factory)
 {
-    return [&]<std::size_t... Is>(std::index_sequence<Is...>)
-    {
+    return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
         return std::array, D>{factory(Is)...};
-    }
-    (std::make_index_sequence<D>{});
+    }(std::make_index_sequence<D>{});
 }
 
 template
@@ -87,12 +85,10 @@ class TestCkTileElementwise : public ::testing::Test
         ck_tile::DeviceMem d_y_mem(h_y);
         d_y_mem.SetZero();
 
-        auto d_x_ptrs_tuple = [&]<std::size_t... Is>(std::index_sequence<Is...>)
-        {
+        auto d_x_ptrs_tuple = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
             return ck_tile::make_tuple(
                 static_cast(d_xs_mems_owner[Is].GetDeviceBuffer())...);
-        }
-        (std::make_index_sequence{});
+        }(std::make_index_sequence{});
 
         YDataType* p_y_device = static_cast<YDataType*>(d_y_mem.GetDeviceBuffer());
@@ -142,11 +138,9 @@ class TestCkTileElementwise : public ::testing::Test
         ElementwiseOpType op_host;
         for(ck_tile::index_t i = 0; i < total_m_elements; ++i)
         {
-            auto get_host_op_args = [&]<std::size_t... Is>(std::index_sequence<Is...>)
-            {
+            auto get_host_op_args = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
                 return ck_tile::make_tuple(static_cast(h_xs[Is](i))...);
-            }
-            (std::make_index_sequence{});
+            }(std::make_index_sequence{});
 
             YDataType temp_y_val;
             ck_tile::apply(
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
index 9adf9ec185..70aa161881 100644
--- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
@@ -218,10 +218,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
         if(s.log_level_ > 0)
         {
-            std::cout << "Launching kernel with args:"
-                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
-                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
-                      << "}" << std::endl;
+            std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
+                      << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
+                      << blocks.y << ", " << blocks.z << "}" << std::endl;
         }
 
         ck_tile::launch_kernel(
diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc
index 3d2c9a82e0..a63a58b473 100644
--- a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc
+++ b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc
@@ -90,24 +90,24 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
                                  tail_number_v>;
     using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3;
     using GemmEpilogue = ck_tile::CShuffleEpilogue<
-        ck_tile::CShuffleEpilogueProblem,
-        AccDataType,
-        CDataType,
-        ck_tile::tuple<>,
-        CLayout,
-        ck_tile::element_wise::PassThrough,
-        CodegenPipelineProblem::kBlockSize,
-        TilePartitioner::MPerBlock,
-        TilePartitioner::NPerBlock,
-        M_Warp,
-        N_Warp,
-        M_Warp_Tile,
-        N_Warp_Tile,
-        K_Warp_Tile,
-        transposed_warp_gemm,
-        ck_tile::memory_operation_enum::set>>;
+            ck_tile::CShuffleEpilogueProblem,
+            AccDataType,
+            CDataType,
+            ck_tile::tuple<>,
+            CLayout,
+            ck_tile::element_wise::PassThrough,
+            CodegenPipelineProblem::kBlockSize,
+            TilePartitioner::MPerBlock,
+            TilePartitioner::NPerBlock,
+            M_Warp,
+            N_Warp,
+            M_Warp_Tile,
+            N_Warp_Tile,
+            K_Warp_Tile,
+            transposed_warp_gemm,
+            ck_tile::memory_operation_enum::set>>;
 
     using Kernel = ck_tile::AQuantGemmKernel;
@@ -449,14 +449,18 @@ bool run_gemm_test(int argc, char* argv[])
     }
     else if(data_type == "i4fp8")
     {
-        using TypeConfig = decltype(
-            GemmQuantTypeConfig{});
+        using TypeConfig = decltype(GemmQuantTypeConfig{});
         return run_gemm_test_prec_type(a_layout, b_layout, argc, argv);
     }
     else if(data_type == "i4bf8")
     {
-        using TypeConfig = decltype(
-            GemmQuantTypeConfig{});
+        using TypeConfig = decltype(GemmQuantTypeConfig{});
         return run_gemm_test_prec_type(a_layout, b_layout, argc, argv);
     }
     else if(data_type == "i4f32fp8")
diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp
index 4d6a1b42b1..af229aad29 100644
--- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp
+++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp
@@ -215,10 +215,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
         if(s.log_level_ > 0)
         {
-            std::cout << "Launching kernel with args:"
-                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
-                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
-                      << "}" << std::endl;
+            std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
+                      << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
+                      << blocks.y << ", " << blocks.z << "}" << std::endl;
         }
 
         ck_tile::launch_kernel(
diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
index 79e29f8b99..cededd38f9 100644
--- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
+++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
@@ -82,11 +82,11 @@ class TestCkTileGroupedGemm : public ::testing::Test
             GemmSpatiallyLocalTilePartitioner;
 
         using Traits = ck_tile::TileGemmTraits;
+                                               GroupedGemKernelParam::kPadN,
+                                               GroupedGemKernelParam::kPadK,
+                                               ALayout,
+                                               BLayout,
+                                               CLayout>;
         using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits
 
             if(s.log_level_ > 0)
            {
-                std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
-                          << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
-                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
-                          << "}" << std::endl;
+                std::cout << "Launching kernel: " << Kernel::GetName()
+                          << " with args:" << " grid: {" << grids.x << ", "
+                          << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
+                          << blocks.z << "}" << std::endl;
             }
 
             ave_time = ck_tile::launch_kernel(
@@ -284,10 +284,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
             if(s.log_level_ > 0)
             {
-                std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
-                          << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
-                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
-                          << "}" << std::endl;
+                std::cout << "Launching kernel: " << Kernel::GetName()
+                          << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
+                          << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
+                          << blocks.z << "}" << std::endl;
             }
 
             ck_tile::launch_kernel(s,
@@ -412,8 +412,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
             c_m_n_tensors.push_back(ck_tile::HostTensor(
                 f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
 
-            std::cout << "gemm[" << i << "]"
-                      << " a_m_k: " << a_m_k_tensors[i].mDesc
+            std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
                       << " b_k_n: " << b_k_n_tensors[i].mDesc
                       << " c_m_n: " << c_m_n_tensors[i].mDesc << " KBatch: " << kbatch
                       << std::endl;
diff --git a/test/ck_tile/layernorm2d/layernorm2d_fwd.inc b/test/ck_tile/layernorm2d/layernorm2d_fwd.inc
index 8070815b7e..a0295eafeb 100644
--- a/test/ck_tile/layernorm2d/layernorm2d_fwd.inc
+++ b/test/ck_tile/layernorm2d/layernorm2d_fwd.inc
@@ -194,8 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return base_str;
     }();
 
-    std::cout << "[" << prec_str << "]"
-              << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
+    std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
               << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
               << ", yr_stride:" << yr_stride << std::flush;
diff --git a/test/ck_tile/moe_smoothquant/moe_smoothquant.inc b/test/ck_tile/moe_smoothquant/moe_smoothquant.inc
index ff23c99e74..9e181a9d8c 100644
--- a/test/ck_tile/moe_smoothquant/moe_smoothquant.inc
+++ b/test/ck_tile/moe_smoothquant/moe_smoothquant.inc
@@ -128,9 +128,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     smscale_buf.ToDevice(smscale_host.data());
     topk_ids_buf.ToDevice(topk_ids_host.data());
 
-    std::cout << "[" << prec_i << "-" << prec_o << "]"
-              << " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride
-              << ", experts:" << experts << ", topk:" << topk << std::flush;
+    std::cout << "[" << prec_i << "-" << prec_o << "]" << " tokens:" << tokens
+              << ", hidden_size:" << hidden_size << ", stride:" << stride << ", experts:" << experts
+              << ", topk:" << topk << std::flush;
 
     moe_smoothquant_traits traits{prec_i, prec_o};
diff --git a/test/ck_tile/moe_sorting/moe_sorting_api.cpp b/test/ck_tile/moe_sorting/moe_sorting_api.cpp
index 0e8998e254..0f25e17867 100644
--- a/test/ck_tile/moe_sorting/moe_sorting_api.cpp
+++ b/test/ck_tile/moe_sorting/moe_sorting_api.cpp
@@ -40,11 +40,11 @@
         constexpr bool local_expert_masking = local_expert_masking_; \
         constexpr bool local_token = local_token_; \
         using ms_problem = ck_tile::MoeSortingProblemEx; \
+                                                         ms_weight_type, \
+                                                         sub_token_tile, \
+                                                         sub_token_onshot, \
+                                                         local_expert_masking, \
+                                                         local_token>; \
         using kernel = ck_tile::MoeSortingKernel; \
         auto kargs = kernel::MakeKargs(a); \
         const dim3 grids = kernel::GridSize(a); \
@@ -200,11 +200,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
        constexpr bool expert_masking = expert_masking_; \
        constexpr bool local_token = local_token_; \
        using ms_problem = ck_tile::MoeSortingProblemMp; \
+                                                        ms_weight_type, \
+                                                        mesh_type_, \
+                                                        unroll_num, \
+                                                        expert_masking, \
+                                                        local_token>; \
        using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \
        auto kargs = kernel::MakeKargs(a); \
        const dim3 grids = kernel::GridSize(a); \
@@ -218,11 +218,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
        constexpr bool expert_masking = expert_masking_; \
        constexpr bool local_token = local_token_; \
        using ms_problem = ck_tile::MoeSortingProblemMp; \
+                                                        ms_weight_type, \
+                                                        mesh_type_, \
+                                                        unroll_num, \
+                                                        expert_masking, \
+                                                        local_token>; \
        using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \
        auto kargs = kernel::MakeKargs(a); \
        const dim3 grids = kernel::GridSize(a); \
@@ -236,11 +236,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
        constexpr bool expert_masking = expert_masking_; \
        constexpr bool local_token = local_token_; \
        using ms_problem = ck_tile::MoeSortingProblemMp; \
+                                                        ms_weight_type, \
+                                                        mesh_type_, \
+                                                        unroll_num, \
+                                                        expert_masking, \
+                                                        local_token>; \
        using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \
        auto kargs = kernel::MakeKargs(a); \
        const dim3 grids = kernel::GridSize(a); \
@@ -254,11 +254,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
        constexpr bool expert_masking = expert_masking_; \
        constexpr bool local_token = local_token_; \
        using ms_problem = ck_tile::MoeSortingProblemMp; \
+                                                        ms_weight_type, \
+                                                        mesh_type_, \
+                                                        unroll_num, \
+                                                        expert_masking, \
+                                                        local_token>; \
        using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \
        auto kargs = kernel::MakeKargs(a); \
        const dim3 grids = kernel::GridSize(a); \
@@ -273,11 +273,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
        constexpr bool expert_masking = expert_masking_; \
        constexpr bool local_token = local_token_; \
        using ms_problem = ck_tile::MoeSortingProblemMp; \
+                                                        ms_weight_type, \
+                                                        mesh_type_, \
+                                                        unroll_num, \
+                                                        expert_masking, \
+                                                        local_token>; \
        using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \
        auto kargs = kernel::MakeKargs(a); \
        const dim3 grids = kernel::GridSize(a); \
diff --git a/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp b/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp
index cc511984fe..8a300dd890 100644
--- a/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp
+++ b/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp
@@ -226,20 +226,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
     moe_sorting_trait trait{
         index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
 
-    moe_sorting_args karg
-    {
-        topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
-        local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr,
-        is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
-        sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
-        sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(),
-        moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
-        workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size,
-        num_experts, topk,
+    moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
+                          weights_dev.GetDeviceBuffer(),
+                          local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
+                                               : nullptr,
+                          is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
+                          sorted_ids_dev.GetDeviceBuffer(),
+                          sorted_weights_dev.GetDeviceBuffer(),
+                          sorted_expert_ids_dev.GetDeviceBuffer(),
+                          sorted_id_cnt_dev.GetDeviceBuffer(),
+                          moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
+                          workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
+                          tokens,
+                          unit_size,
+                          num_experts,
+                          topk,
 #if MOE_SORTING_FMOE_2D_BUF
-        moe_buf_interm_dim, moe_buf_elem_bytes
+                          moe_buf_interm_dim,
+                          moe_buf_elem_bytes
 #else
-        static_cast(moe_buf_size * sizeof(float))
+                          static_cast(moe_buf_size * sizeof(float))
 #endif
     };
diff --git a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp
index 518a9a8889..c94adc24c3 100644
--- a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp
+++ b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp
@@ -333,12 +333,12 @@ struct matrix_core_swizzle_kernel
         return tmp_1;
 #else
         // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
-        constexpr index_t kv = Alignment;
-        constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
-        constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
+        constexpr index_t kv          = Alignment;
+        constexpr index_t nw          = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
+        constexpr index_t kw          = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
         constexpr index_t waveflatten = kw * nw * kv;
-        const index_t kr = a_.k / (k1 * k2);
-        const index_t nr = a_.n / nw;
+        const index_t kr              = a_.k / (k1 * k2);
+        const index_t nr              = a_.n / nw;
         auto tmp = make_naive_tensor_view_packed(
             p_dst,
             make_tuple(nr, kr, waveflatten),
@@ -387,8 +387,8 @@ struct matrix_core_swizzle_kernel
         constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
         constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
         constexpr index_t waveflatten_tile = kw * nw * kv;
-        constexpr index_t nr_tile = NPerBlock / nw;
-        constexpr index_t kr_tile = KPerBlock / (kw * kv);
+        constexpr index_t nr_tile          = NPerBlock / nw;
+        constexpr index_t kr_tile          = KPerBlock / (kw * kv);
         return make_tile_window(dst_view,
                                 make_tuple(number{},
                                            number{},
diff --git a/test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc b/test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc
index 19abf10f3c..bf8ee8b0cc 100644
--- a/test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc
+++ b/test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc
@@ -194,8 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return base_str;
     }();
 
-    std::cout << "[" << prec_str << "]"
-              << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
+    std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
               << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
               << ", yr_stride:" << yr_stride << std::flush;
diff --git a/test/ck_tile/smoothquant/smoothquant.inc b/test/ck_tile/smoothquant/smoothquant.inc
index afda7de4eb..23dba27e88 100644
--- a/test/ck_tile/smoothquant/smoothquant.inc
+++ b/test/ck_tile/smoothquant/smoothquant.inc
@@ -96,9 +96,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     x_buf.ToDevice(x_host.data());
     smscale_buf.ToDevice(smscale_host.data());
 
-    std::cout << "[" << data_type << "]"
-              << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
-              << std::flush;
+    std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
+              << ", y_stride:" << y_stride << std::flush;
 
     smoothquant_traits traits{data_type};
diff --git a/test/data_type/test_pk_i4.cpp b/test/data_type/test_pk_i4.cpp
index d8d4d0e36d..52273d45de 100644
--- a/test/data_type/test_pk_i4.cpp
+++ b/test/data_type/test_pk_i4.cpp
@@ -31,8 +31,8 @@ TEST(PackedInt4, ConvertToFloat)
     constexpr float first_input_val  = 7.f;
     constexpr float second_input_val = -1.f;
 #else
-    constexpr float first_input_val = -1.f;
-    constexpr float second_input_val = 7.f;
+    constexpr float first_input_val  = -1.f;
+    constexpr float second_input_val = 7.f;
 #endif
     uint8_t data = 0b11110111; // {-1, 7}
     pk_i4_t in   = ck::bit_cast<pk_i4_t>(data);
@@ -65,8 +65,8 @@ TEST(PackedInt4, ConvertToBHalf)
     const bhalf_t first_input_val  = ck::type_convert<bhalf_t>(7.f);
     const bhalf_t second_input_val = ck::type_convert<bhalf_t>(-1.f);
 #else
-    const bhalf_t first_input_val = ck::type_convert<bhalf_t>(-1.f);
-    const bhalf_t second_input_val = ck::type_convert<bhalf_t>(7.f);
+    const bhalf_t first_input_val  = ck::type_convert<bhalf_t>(-1.f);
+    const bhalf_t second_input_val = ck::type_convert<bhalf_t>(7.f);
 #endif
     uint8_t data = 0b11110111; // {-1, 7}
     pk_i4_t in   = ck::bit_cast<pk_i4_t>(data);
diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp
index 5e2aedd35e..9decfe14ac 100644
--- a/test/mx_mfma_op/mx_mfma_op.cpp
+++ b/test/mx_mfma_op/mx_mfma_op.cpp
@@ -67,12 +67,12 @@ TEST(MFMA, FP8MFMA16x16x128)
     using CLayout = ck::tensor_layout::gemm::ColumnMajor;
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mfma_test(AB_init);
+                              BLayout,
+                              CLayout,
+                              f8_t,
+                              f8_t,
+                              half_t,
+                              ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -83,12 +83,12 @@ TEST(MFMA, BF8MFMA16x16x128)
     using CLayout = ck::tensor_layout::gemm::ColumnMajor;
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mfma_test(AB_init);
+                              BLayout,
+                              CLayout,
+                              bf8_t,
+                              bf8_t,
+                              half_t,
+                              ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -126,12 +126,12 @@ TEST(MFMA, BF6MFMA16x16x128)
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mfma_test(AB_init);
+                              BLayout,
+                              CLayout,
+                              bf6_t,
+                              bf6_t,
+                              float,
+                              ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -156,12 +156,12 @@ TEST(MFMA, BF8MFMA32x32x64)
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mfma_test(AB_init);
+                              BLayout,
+                              CLayout,
+                              bf8_t,
+                              bf8_t,
+                              float,
+                              ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -199,12 +199,12 @@ TEST(MFMA, BF6MFMA32x32x64)
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mfma_test(AB_init);
+                              BLayout,
+                              CLayout,
+                              bf6_t,
+                              bf6_t,
+                              half_t,
+                              ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -274,12 +274,12 @@ TEST(MXMFMA, MXFP8MFMA16x16x128)
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mxmfma_test(AB_init);
+                                BLayout,
+                                CLayout,
+                                f8_t,
+                                f8_t,
+                                float,
+                                ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -291,12 +291,12 @@ TEST(MXMFMA, MXFP8MFMA32x32x64)
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mxmfma_test(AB_init);
+                                BLayout,
+                                CLayout,
+                                f8_t,
+                                f8_t,
+                                half_t,
+                                ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -308,12 +308,12 @@ TEST(MXMFMA, MXBF8MFMA16x16x128)
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mxmfma_test(AB_init);
+                                BLayout,
+                                CLayout,
+                                bf8_t,
+                                bf8_t,
+                                float,
+                                ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -325,12 +325,12 @@ TEST(MXMFMA, MXBF8MFMA32x32x64)
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mxmfma_test(AB_init);
+                                BLayout,
+                                CLayout,
+                                bf8_t,
+                                bf8_t,
+                                half_t,
+                                ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -342,12 +342,12 @@ TEST(MXMFMA, MXFP6MFMA16x16x128)
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mxmfma_test(AB_init);
+                                BLayout,
+                                CLayout,
+                                f6_t,
+                                f6_t,
+                                float,
+                                ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -359,12 +359,12 @@ TEST(MXMFMA, MXFP6MFMA32x32x64)
     auto AB_init = (common_init < 0) ? 5 : common_init;
     auto pass = run_mxmfma_test(AB_init);
+                                BLayout,
+                                CLayout,
+                                f6_t,
+                                f6_t,
+                                half_t,
+                                ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
 
     EXPECT_TRUE(pass);
 }
@@ -376,12 +376,12 @@ TEST(MXMFMA, MXBF6MFMA16x16x128)
    auto AB_init = (common_init < 0) ? 5 : common_init;
    auto pass = run_mxmfma_test(AB_init);
+                               BLayout,
+                               CLayout,
+                               bf6_t,
+                               bf6_t,
+                               float,
+                               ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
 
    EXPECT_TRUE(pass);
 }
@@ -393,12 +393,12 @@ TEST(MXMFMA, MXBF6MFMA32x32x64)
    auto AB_init = (common_init < 0) ? 5 : common_init;
    auto pass = run_mxmfma_test(AB_init);
+                               BLayout,
+                               CLayout,
+                               bf6_t,
+                               bf6_t,
+                               half_t,
+                               ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
 
    EXPECT_TRUE(pass);
 }
@@ -410,12 +410,12 @@ TEST(MXMFMA, MXFP4MFMA16x16x128)
    auto AB_init = (common_init < 0) ? 5 : common_init;
    auto pass = run_mxmfma_test(AB_init);
+                               BLayout,
+                               CLayout,
+                               f4_t,
+                               f4_t,
+                               float,
+                               ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
 
    EXPECT_TRUE(pass);
 }
@@ -427,11 +427,11 @@ TEST(MXMFMA, MXFP4MFMA32x32x64)
    auto AB_init = (common_init < 0) ? 5 : common_init;
    auto pass = run_mxmfma_test(AB_init);
+                               BLayout,
+                               CLayout,
+                               f4_t,
+                               f4_t,
+                               half_t,
+                               ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
 
    EXPECT_TRUE(pass);
 }
diff --git a/test/pool/test_max_pool2d_fwd.cpp b/test/pool/test_max_pool2d_fwd.cpp
index 2179242754..bb6fc96cb1 100644
--- a/test/pool/test_max_pool2d_fwd.cpp
+++ b/test/pool/test_max_pool2d_fwd.cpp
@@ -57,9 +57,9 @@ using true_t = std::integral_constant;
 using false_t = std::integral_constant;
 
 using MaxPool2D_F32_Types = ::testing::Types,
-                                              std::tuple>;
+                                            std::tuple>;
 using MaxPool2D_F16_Types = ::testing::Types,
-                                              std::tuple>;
+                                            std::tuple>;
 using MaxPool2D_BF16_Types = ::testing::Types,
                                               std::tuple>;
 using MaxPool2D_I8_Types =
diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp
index b3328e4b36..45345cccfa 100644
--- a/test/reference_conv_fwd/reference_conv_fwd.cpp
+++ b/test/reference_conv_fwd/reference_conv_fwd.cpp
@@ -58,12 +58,12 @@ run_reference_convolution_forward(const ck::utils::conv::ConvParam& conv_param,
     ck::ranges::fill(host_output, 0.f);
 
     auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd();
+                                                                InDataType,
+                                                                WeiDataType,
+                                                                OutDataType,
+                                                                InElementOp,
+                                                                WeiElementOp,
+                                                                OutElementOp>();
 
     auto ref_invoker  = ref_conv.MakeInvoker();
     auto ref_argument = ref_conv.MakeArgument(input,
                                               weights,
diff --git a/tile_engine/ops/gemm/benchmark_gemm.hpp b/tile_engine/ops/gemm/benchmark_gemm.hpp
index bbb9c1d715..ce8a6e8234 100644
--- a/tile_engine/ops/gemm/benchmark_gemm.hpp
+++ b/tile_engine/ops/gemm/benchmark_gemm.hpp
@@ -105,10 +105,8 @@ struct KernelInstance
     friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
     {
         os << "{\n"
-           << " \"name\": \""
-           << "{\n"
-           << obj.name_ << "\n}"
-           << "\",\n"
+           << " \"name\": \"" << "{\n"
+           << obj.name_ << "\n}" << "\",\n"
            << " \"problem\": \"" << obj.problem_ << "\",\n"
            << " \"perf_result\": " << obj.perf_result_ << "\n"
            << "}";
diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp
index fdad363f7c..634e19de6e 100644
--- a/tile_engine/ops/gemm/gemm_profiler.hpp
+++ b/tile_engine/ops/gemm/gemm_profiler.hpp
@@ -218,10 +218,8 @@ class GemmProfiler
        {
            file << "rocm_version,device_name,"
                 << "split_k,m,n,k,stride_a,stride_b,stride_c,"
-                << "dtype_a,dtype_b,dtype_acc,dtype_c,"
-                << "layout_a,layout_b,layout_c,"
-                << "structured_sparsity,"
-                << "name,"
+                << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
+                << "structured_sparsity," << "name,"
                 << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
        }
@@ -251,7 +249,7 @@ class GemmProfiler
         return kernel_instance;
     }
 
-    GemmProfiler(const GemmProfiler&) = delete;
+    GemmProfiler(const GemmProfiler&)            = delete;
     GemmProfiler& operator=(const GemmProfiler&) = delete;
 
     private: